Skip to content

Commit

Permalink
fix: Dynamodb deduplicate batch write request by partition keys (#2515)
Browse files Browse the repository at this point in the history
Signed-off-by: Miguel Trejo <[email protected]>
  • Loading branch information
TremaMiguel authored Apr 9, 2022
1 parent 6bf8df0 commit 70d4a13
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 18 deletions.
42 changes: 27 additions & 15 deletions sdk/python/feast/infra/online_stores/dynamodb.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,21 +191,7 @@ def online_write_batch(
table_instance = dynamodb_resource.Table(
_get_table_name(online_config, config, table)
)
with table_instance.batch_writer() as batch:
for entity_key, features, timestamp, created_ts in data:
entity_id = compute_entity_id(entity_key)
batch.put_item(
Item={
"entity_id": entity_id, # PartitionKey
"event_ts": str(utils.make_tzaware(timestamp)),
"values": {
k: v.SerializeToString()
for k, v in features.items() # Serialized Features
},
}
)
if progress:
progress(1)
self._write_batch_non_duplicates(table_instance, data, progress)

@log_exceptions_and_usage(online_store="dynamodb")
def online_read(
Expand Down Expand Up @@ -299,6 +285,32 @@ def _sort_dynamodb_response(self, responses: list, order: list):
_, table_responses_ordered = zip(*table_responses_ordered)
return table_responses_ordered

@log_exceptions_and_usage(online_store="dynamodb")
def _write_batch_non_duplicates(
    self,
    table_instance,
    data: List[
        Tuple[EntityKeyProto, Dict[str, ValueProto], datetime, Optional[datetime]]
    ],
    progress: Optional[Callable[[int], Any]],
):
    """Batch-write feature rows, de-duplicating requests on ``entity_id``.

    Args:
        table_instance: Table object exposing ``batch_writer``; items are put
            through the writer it returns.
        data: Rows of ``(entity_key, features, timestamp, created_ts)``.
            ``created_ts`` is unpacked but not written.
        progress: Optional callback; when truthy it is called with ``1``
            after each row is handed to the batch writer.
    """
    # Key list handed to the batch writer so buffered requests that share an
    # ``entity_id`` are de-duplicated instead of being sent twice.
    dedup_keys = ["entity_id"]
    with table_instance.batch_writer(overwrite_by_pkeys=dedup_keys) as writer:
        for row in data:
            entity_key, feature_map, event_time, _created_ts = row
            item = {
                "entity_id": compute_entity_id(entity_key),  # PartitionKey
                "event_ts": str(utils.make_tzaware(event_time)),
                # Feature values are stored in their serialized form.
                "values": {
                    name: value.SerializeToString()
                    for name, value in feature_map.items()
                },
            }
            writer.put_item(Item=item)
            if progress:
                progress(1)


def _initialize_dynamodb_client(region: str, endpoint_url: Optional[str] = None):
    """Return a boto3 ``dynamodb`` client for ``region``.

    Args:
        region: Passed to boto3 as ``region_name``.
        endpoint_url: Optional endpoint override, forwarded unchanged
            (``None`` by default).
    """
    client_options = {"region_name": region, "endpoint_url": endpoint_url}
    return boto3.client("dynamodb", **client_options)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from copy import deepcopy
from dataclasses import dataclass

import boto3
import pytest
from moto import mock_dynamodb2

Expand Down Expand Up @@ -162,7 +164,7 @@ def test_online_read(repo_config, n_samples):
data = _create_n_customer_test_samples(n=n_samples)
_insert_data_test_table(data, PROJECT, f"{TABLE_NAME}_{n_samples}", REGION)

entity_keys, features = zip(*data)
entity_keys, features, *rest = zip(*data)
dynamodb_store = DynamoDBOnlineStore()
returned_items = dynamodb_store.online_read(
config=repo_config,
Expand All @@ -171,3 +173,24 @@ def test_online_read(repo_config, n_samples):
)
assert len(returned_items) == len(data)
assert [item[1] for item in returned_items] == list(features)


@mock_dynamodb2
def test_write_batch_non_duplicates(repo_config):
    """Check that writing every sample twice stores each entity only once."""
    table_name = f"{TABLE_NAME}_batch_non_duplicates"
    _create_test_table(PROJECT, table_name, REGION)

    samples = _create_n_customer_test_samples()
    # Every sample appears twice in the write request.
    doubled_samples = samples + deepcopy(samples)

    table = boto3.resource("dynamodb", region_name=REGION).Table(
        f"{PROJECT}.{table_name}"
    )
    store = DynamoDBOnlineStore()
    store._write_batch_non_duplicates(table, doubled_samples, progress=None)

    # Ask the scan for more items than the number of unique samples written.
    scan_result = table.scan(Limit=20)
    stored_items = scan_result.get("Items")
    assert stored_items is not None
    assert len(stored_items) == len(samples)
6 changes: 4 additions & 2 deletions sdk/python/tests/utils/online_store_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ def _create_n_customer_test_samples(n=10):
"name": ValueProto(string_val="John"),
"age": ValueProto(int64_val=3),
},
datetime.utcnow(),
None,
)
for i in range(n)
]
Expand All @@ -42,13 +44,13 @@ def _delete_test_table(project, tbl_name, region):
def _insert_data_test_table(data, project, tbl_name, region):
    """Insert test rows into the DynamoDB table ``{project}.{tbl_name}``.

    Args:
        data: Iterable of ``(entity_key, features, timestamp, created_ts)``
            rows. ``created_ts`` is unpacked but not stored.
        project: Project name, used as the table-name prefix.
        tbl_name: Table-name suffix.
        region: Region name passed to ``boto3.resource``.
    """
    dynamodb_resource = boto3.resource("dynamodb", region_name=region)
    table_instance = dynamodb_resource.Table(f"{project}.{tbl_name}")
    for entity_key, features, timestamp, _created_ts in data:
        entity_id = compute_entity_id(entity_key)
        # A fresh batch writer is opened for every row, as in the original.
        with table_instance.batch_writer() as batch:
            batch.put_item(
                Item={
                    "entity_id": entity_id,
                    # Use the row's own timestamp rather than the current time.
                    "event_ts": str(utils.make_tzaware(timestamp)),
                    "values": {
                        k: v.SerializeToString() for k, v in features.items()
                    },
                }
            )

0 comments on commit 70d4a13

Please sign in to comment.