Merge pull request #23 from tradewelltech/simplify-kafka-internals
Simplify kafka internals
0x26res authored Aug 21, 2023
2 parents a109054 + 86b924f commit d0df4e6
Showing 2 changed files with 44 additions and 34 deletions.
63 changes: 38 additions & 25 deletions beavers/kafka.py
@@ -4,7 +4,7 @@
import logging
import time
from enum import Enum
from typing import Any, AnyStr, Generic, Optional, Protocol, TypeVar
from typing import Any, AnyStr, Generic, Optional, Protocol, Sequence, TypeVar

import confluent_kafka
import confluent_kafka.admin
@@ -22,7 +22,7 @@
class KafkaMessageDeserializer(Protocol[T]):
"""Interface for converting incoming kafka messages to custom data."""

def __call__(self, messages: list[confluent_kafka.Message]) -> T:
def __call__(self, messages: Sequence[confluent_kafka.Message]) -> T:
"""Convert batch of messages to data."""


@@ -38,7 +38,7 @@ class KafkaProducerMessage:
class KafkaMessageSerializer(Protocol[T]):
"""Interface for converting custom data to outgoing kafka messages."""

def __call__(self, value: T) -> list[KafkaProducerMessage]:
def __call__(self, value: T) -> Sequence[KafkaProducerMessage]:
"""Convert batch of custom data to `KafkaProducerMessage`."""


@@ -384,29 +384,14 @@ def _update_partition_info(self, new_messages: list[confluent_kafka.Message]):
)


@dataclasses.dataclass(frozen=True)
class _RuntimeSinkTopic:
nodes: list[Node]
serializer: KafkaMessageSerializer

def flush(self, cycle_id: int, producer_manger: _ProducerManager):
for node in self.nodes:
if node.get_cycle_id() == cycle_id:
node_value = node.get_sink_value()
# TODO: capture serialization time in metrics
messages = self.serializer(node_value)
for message in messages:
producer_manger.produce_one(
message.topic, message.key, message.value
)


@dataclasses.dataclass
class ExecutionMetrics:
"""Metrics for the execution of a dag."""

serialization_ns: int = 0
serialization_count: int = 0
deserialization_ns: int = 0
deserialization_count: int = 0
execution_ns: int = 0
execution_count: int = 0

@@ -419,6 +404,15 @@ def measure_serialization_time(self):
self.serialization_ns += time.time_ns() - before
self.serialization_count += 1

@contextlib.contextmanager
def measure_deserialization_time(self):
before = time.time_ns()
try:
yield
finally:
self.deserialization_ns += time.time_ns() - before
self.deserialization_count += 1

@contextlib.contextmanager
def measure_execution_time(self):
before = time.time_ns()
@@ -429,6 +423,20 @@ def measure_execution_time(self):
self.execution_count += 1
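
A minimal illustration of the new deserialization counters, assuming only what the dataclass above defines (the sleep merely makes the elapsed time non-zero):

    import time

    metrics = ExecutionMetrics()
    with metrics.measure_deserialization_time():
        time.sleep(0.001)  # stand-in for deserializing a batch of messages
    assert metrics.deserialization_ns > 0
    assert metrics.deserialization_count == 1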


@dataclasses.dataclass(frozen=True)
class _RuntimeSinkTopic:
nodes: list[Node]
serializer: KafkaMessageSerializer

def serialize(self, cycle_id: int) -> list[KafkaProducerMessage]:
messages = []
for node in self.nodes:
if node.get_cycle_id() == cycle_id:
node_value = node.get_sink_value()
messages.extend(self.serializer(node_value))
return messages
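
Splitting serialization out of the old flush keeps _RuntimeSinkTopic free of any producer dependency, so serialization can be timed on its own and unit-tested without a mock producer (see the updated test below). The caller-side pattern is roughly the following sketch, where producer_manager stands in for the driver's _ProducerManager:

    messages = runtime_sink_topic.serialize(cycle_id)  # empty if no node updated this cycle
    for message in messages:
        producer_manager.produce_one(message.topic, message.key, message.value)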


class KafkaDriver:
"""Control the execution of a dag, using data from kafka."""

@@ -519,13 +527,19 @@ def _process_message(self, message: confluent_kafka.Message):
self._source_topics[message.topic()].append(message)

def _produce_records(self, cycle_id: int):
for sink_topic in self._sink_topics:
sink_topic.flush(cycle_id, self._producer_manager)
messages = []
with self._metrics.measure_serialization_time():
for sink_topic in self._sink_topics:
messages.extend(sink_topic.serialize(cycle_id))
for message in messages:
self._producer_manager.produce_one(
message.topic, message.key, message.value
)

def _run_cycle(self, messages: list[confluent_kafka.Message]) -> bool:
has_messages = False
with self._metrics.measure_serialization_time():
self._process_messages(messages)
self._process_messages(messages)
with self._metrics.measure_deserialization_time():
for handler in self._source_topics.values():
has_messages = handler.flush() or has_messages
cycle_time = (
@@ -652,7 +666,6 @@ def _get_previous_start_of_day(
if (local_now - local_now.normalize()) > start_of_day_time:
return (local_now.normalize() + start_of_day_time).tz_convert("UTC")
else:
# TODO: consider adding calendar?
return (
local_now.normalize() - pd.to_timedelta("1d") + start_of_day_time
).tz_convert("UTC")
15 changes: 6 additions & 9 deletions tests/test_kafka.py
@@ -572,8 +572,10 @@ def test_kafka_driver_word_count(log_helper: LogHelper):
assert len(log_helper.flush()) == 1

metrics = kafka_driver.flush_metrics()
assert metrics.deserialization_ns > 0
assert metrics.deserialization_count == 6
assert metrics.serialization_ns > 0
assert metrics.serialization_count == 6
assert metrics.serialization_count == 3
assert metrics.execution_ns > 0
assert metrics.execution_count == 3

@@ -1308,19 +1310,14 @@ def test_runtime_sink_topic():
sink = dag.sink("sink", node)
runtime_sink_topic = _RuntimeSinkTopic([sink], WorldCountSerializer("topic-1"))

producer_manager = MockProducerManager()
dag.execute()
runtime_sink_topic.flush(dag.get_cycle_id(), producer_manager)
assert producer_manager.messages == []
assert runtime_sink_topic.serialize(dag.get_cycle_id()) == []

node.set_stream({"foo": "bar"})
dag.execute()
runtime_sink_topic.flush(dag.get_cycle_id(), producer_manager)
assert producer_manager.messages == [
assert runtime_sink_topic.serialize(dag.get_cycle_id()) == [
KafkaProducerMessage(topic="topic-1", key=b"foo", value=b"bar")
]
producer_manager.messages.clear()

dag.execute()
runtime_sink_topic.flush(dag.get_cycle_id(), producer_manager)
assert producer_manager.messages == []
assert runtime_sink_topic.serialize(dag.get_cycle_id()) == []
