[live-data-fetcher] Distribute read load across objects in S3 (#817)
* Add logic to distribute read load across different files in S3

* lint

* add test for live data watcher

* intentionally fail test to verify

* intentionally fail test to verify

* remove logs

* verify this fails

* and it does

* change range

* change range

* change range

* comments

* comments

* use function

* move more things into helper

* fix unit test

* fix unit test

* rename unit test

* rename unit test

* import order

* add logging since a test failed

* import order

* move logger up

* oops, the delimiter should have been / not _

* add 0 and 1 to no sharding unit test

* clean up python

* clean up python

* comments

* add negative values to test

* revert retries from 10 -> 5

---------

Co-authored-by: Siddharth Manoj <[email protected]>
2 people authored and areitz committed Nov 30, 2023
1 parent f87e8d2 commit 16265f0
Showing 2 changed files with 91 additions and 9 deletions.
26 changes: 25 additions & 1 deletion baseplate/sidecars/live_data_watcher.py
@@ -4,6 +4,7 @@
import json
import logging
import os
import random
import sys
import time

@@ -115,17 +116,38 @@ def _parse_loader_type(data: bytes) -> LoaderType:
    return LoaderType(loader_type)


def _generate_sharded_file_key(num_file_shards: Optional[int], file_key: str) -> str:
    # We can't assume that every ZK node being NodeWatched by the live-data-fetcher
    # makes use of S3 prefix sharding, but we know at least one does (/experiments).
    # If num_file_shards is absent or its value is less than 2, leave the prefix
    # as the empty string "".
    sharded_file_key_prefix = ""
    if num_file_shards is not None and num_file_shards > 1:
        # If the num_file_shards key is present, we may have multiple copies of the
        # same manifest uploaded, so fetch one at random using a randomly generated
        # prefix: an integer from 1 (inclusive) to num_file_shards (exclusive).
        sharded_file_key_prefix = str(random.randrange(1, num_file_shards)) + "/"
    # Prepend the prefix (if any) to the original file key.
    return sharded_file_key_prefix + file_key
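
# Illustrative usage (a sketch; "manifest" is a hypothetical key, not from the repo):
# with num_file_shards=6 the returned key gains a random prefix "1/" through "5/",
# while None, 0, or 1 leaves the key unchanged.
#
#     _generate_sharded_file_key(6, "manifest")     # -> e.g. "3/manifest"
#     _generate_sharded_file_key(None, "manifest")  # -> "manifest"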


def _load_from_s3(data: bytes) -> bytes:
    # While many of the baseplate configurations use an ini format,
    # we've opted for json in these internal-to-znode configs because
    # we want them to be fully controlled by the writer of the znode,
    # and json is an easier format for znode authors to work with.
    loader_config = json.loads(data.decode("UTF-8"))
    try:
        num_file_shards = loader_config.get("num_file_shards")

        # We expect this key to always be present; its absence is an error.
        file_key = loader_config["file_key"]

        sharded_file_key = _generate_sharded_file_key(num_file_shards, file_key)

        region_name = loader_config["region_name"]
        s3_kwargs = {
            "Bucket": loader_config["bucket_name"],
            "Key": sharded_file_key,
            "SSECustomerKey": loader_config["sse_key"],
            "SSECustomerAlgorithm": "AES256",
        }
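        # For reference, the znode payload shape exercised by the unit tests below
        # (values are illustrative test fixtures):
        #   {"live_data_watcher_load_type": "S3", "bucket_name": "test_bucket",
        #    "file_key": "test_file_key", "sse_key": "test_decryption_key",
        #    "region_name": "us-east-1", "num_file_shards": 5}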
@@ -144,6 +166,8 @@ def _load_from_s3(data: bytes) -> bytes:
    # resource is public. In other words, this means that a given service cannot access
    # a public resource belonging to another cluster/AWS account unless the request
    # credentials are unsigned.

    # The default number of retries in legacy mode (the current mode) is 5.
    s3_client = boto3.client(
        "s3",
        config=Config(signature_version=UNSIGNED),
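For context on the retry comment above, here is a minimal sketch of pinning that behavior explicitly via botocore's Config. The retries argument is an assumption for illustration; the repo's actual client construction continues past the fold above.

from botocore import UNSIGNED
from botocore.config import Config

# A sketch: make the legacy-mode default of 5 attempts explicit.
config = Config(
    signature_version=UNSIGNED,
    retries={"max_attempts": 5, "mode": "legacy"},
)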
74 changes: 66 additions & 8 deletions tests/unit/sidecars/live_data_watcher_tests.py
@@ -1,5 +1,6 @@
import grp
import json
import logging
import os
import pwd
import tempfile
@@ -11,8 +12,14 @@

from moto import mock_s3

from baseplate.sidecars.live_data_watcher import _generate_sharded_file_key
from baseplate.sidecars.live_data_watcher import NodeWatcher

NUM_FILE_SHARDS = 6


logger = logging.getLogger(__name__)


class NodeWatcherTests(unittest.TestCase):
    mock_s3 = mock_s3()
@@ -26,13 +33,21 @@ def setUp(self):
            region_name="us-east-1",
        )
        s3_client.create_bucket(Bucket=bucket_name)
        default_file_key = "test_file_key"
        for file_shard_num in range(NUM_FILE_SHARDS):
            if file_shard_num == 0:
                # The first copy should just be the original file.
                sharded_file_key = default_file_key
            else:
                # All other copies should include the sharded prefix.
                sharded_file_key = str(file_shard_num) + "/" + default_file_key
            s3_client.put_object(
                Bucket=bucket_name,
                Key=sharded_file_key,
                Body=json.dumps(s3_data).encode(),
                SSECustomerKey="test_decryption_key",
                SSECustomerAlgorithm="AES256",
            )

    def tearDown(self):
        self.mock_s3.stop()
Expand All @@ -42,7 +57,33 @@ def run(self, result: unittest.TestResult = None) -> unittest.TestResult:
        self.output_dir = Path(loc)
        return super().run(result)

    def test_generate_sharded_file_key_no_sharding(self):
        original_file_key = "test_file_key"
        expected_sharded_file_key = "test_file_key"
        possible_no_sharding_values = [-2, -1, 0, 1, None]
        for value in possible_no_sharding_values:
            actual_sharded_file_key = _generate_sharded_file_key(value, original_file_key)
            self.assertEqual(actual_sharded_file_key, expected_sharded_file_key)

    def test_generate_sharded_file_key_sharding(self):
        original_file_key = "test_file_key"
        possible_sharded_file_keys = {
            "1/test_file_key",
            "2/test_file_key",
            "3/test_file_key",
            "4/test_file_key",
            "5/test_file_key",
        }
        for _ in range(50):
            actual_sharded_file_key = _generate_sharded_file_key(NUM_FILE_SHARDS, original_file_key)
            # If num_file_shards is provided, the generated file key MUST have a prefix.
            self.assertIn(actual_sharded_file_key, possible_sharded_file_keys)
            # Make sure we aren't generating a file key without the prefix.
            self.assertNotEqual(actual_sharded_file_key, original_file_key)
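
    # A possible follow-up check (a sketch, not part of this PR): verify the random
    # prefixes spread load roughly evenly across shards. _sketch_* is a hypothetical
    # helper name.
    def _sketch_check_shard_distribution(self):
        from collections import Counter

        counts = Counter(
            _generate_sharded_file_key(NUM_FILE_SHARDS, "test_file_key").split("/")[0]
            for _ in range(5000)
        )
        # All five prefixes should appear; each expected ~1000 times out of 5000.
        self.assertEqual(set(counts), {"1", "2", "3", "4", "5"})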

    def test_s3_load_type_on_change_no_sharding(self):
        dest = self.output_dir.joinpath("data.txt")
        inst = NodeWatcher(str(dest), os.getuid(), os.getgid(), 777)

@@ -53,6 +94,23 @@ def test_s3_load_type_on_change(self):
        self.assertEqual(dest.owner(), pwd.getpwuid(os.getuid()).pw_name)
        self.assertEqual(dest.group(), grp.getgrgid(os.getgid()).gr_name)

    def test_s3_load_type_on_change_sharding(self):
        dest = self.output_dir.joinpath("data.txt")
        inst = NodeWatcher(str(dest), os.getuid(), os.getgid(), 777)

        new_content = b'{"live_data_watcher_load_type":"S3","bucket_name":"test_bucket","file_key":"test_file_key","sse_key":"test_decryption_key","region_name":"us-east-1", "num_file_shards": 5}'
        expected_content = b'{"foo_encrypted": "bar_encrypted"}'

        # For good measure, run this 50 times. It should succeed every time.
        # setUp() uploaded the original object plus five sharded copies, and with
        # num_file_shards=5 the watcher randomly picks a prefix from 1 to 4, each of
        # which exists in the mock bucket and holds the same content.
        for _ in range(50):
            inst.on_change(new_content, None)
            self.assertEqual(expected_content, dest.read_bytes())
            self.assertEqual(dest.owner(), pwd.getpwuid(os.getuid()).pw_name)
            self.assertEqual(dest.group(), grp.getgrgid(os.getgid()).gr_name)

    def test_on_change(self):
        dest = self.output_dir.joinpath("data.txt")
        inst = NodeWatcher(str(dest), os.getuid(), os.getgid(), 777)
