Skip to content

Commit

Permalink
Add kafka json to arrow support (#50)
Browse files Browse the repository at this point in the history
  • Loading branch information
0x26res authored Dec 5, 2023
1 parent 9a060fd commit 120c116
Show file tree
Hide file tree
Showing 4 changed files with 79 additions and 10 deletions.
49 changes: 49 additions & 0 deletions beavers/pyarrow_kafka.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import dataclasses
import io
import json

import confluent_kafka
import pyarrow as pa
import pyarrow.json

from beavers.kafka import (
KafkaMessageDeserializer,
KafkaMessageSerializer,
KafkaProducerMessage,
)


@dataclasses.dataclass(frozen=True)
class JsonDeserializer(KafkaMessageDeserializer[pa.Table]):
schema: pa.Schema

def __call__(self, messages: confluent_kafka.Message) -> pa.Table:
if messages:
with io.BytesIO() as buffer:
for message in messages:
buffer.write(message.value())
buffer.write(b"\n")
buffer.seek(0)
return pyarrow.json.read_json(
buffer,
parse_options=pyarrow.json.ParseOptions(
explicit_schema=self.schema
),
)
else:
return self.schema.empty_table()


@dataclasses.dataclass(frozen=True)
class JsonSerializer(KafkaMessageSerializer[pa.Table]):
topic: str

def __call__(self, table: pa.Table):
return [
KafkaProducerMessage(
self.topic,
key=None,
value=json.dumps(message, default=str).encode("utf-8"),
)
for message in table.to_pylist()
]
18 changes: 18 additions & 0 deletions tests/test_pyarrow_kafka.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from beavers.pyarrow_kafka import JsonDeserializer, JsonSerializer
from tests.test_kafka import mock_kafka_message
from tests.test_util import TEST_TABLE


def test_json_deserializer_empty():
deserializer = JsonDeserializer(TEST_TABLE.schema)
assert deserializer([]) == TEST_TABLE.schema.empty_table()


def test_end_to_end():
deserializer = JsonDeserializer(TEST_TABLE.schema)
serializer = JsonSerializer("topic-1")
out_messages = serializer(TEST_TABLE)
in_messages = [
mock_kafka_message(topic=m.topic, value=m.value) for m in out_messages
]
assert deserializer(in_messages) == TEST_TABLE
11 changes: 1 addition & 10 deletions tests/test_pyarrow_replay.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,7 @@

from beavers.engine import UTC_MAX
from beavers.pyarrow_replay import ArrowTableDataSink, ArrowTableDataSource

TEST_TABLE = pa.table(
{
"timestamp": [
pd.to_datetime("2023-01-01T00:00:00Z"),
pd.to_datetime("2023-01-02T00:00:00Z"),
],
"value": [1, 2],
}
)
from tests.test_util import TEST_TABLE


def test_arrow_table_data_source():
Expand Down
11 changes: 11 additions & 0 deletions tests/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,23 @@
from typing import Callable, Dict, Generic, TypeVar

import pandas as pd
import pyarrow as pa

from beavers.engine import UTC_MAX, Dag, TimerManager
from beavers.replay import DataSink, DataSource

T = TypeVar("T")

TEST_TABLE = pa.table(
{
"timestamp": [
pd.to_datetime("2023-01-01T00:00:00Z"),
pd.to_datetime("2023-01-02T00:00:00Z"),
],
"value": [1, 2],
}
)


class GetLatest(Generic[T]):
def __init__(self, default: T):
Expand Down

0 comments on commit 120c116

Please sign in to comment.