
[live-data-fetcher] Distribute read load across objects in S3 #817

Merged
merged 30 commits on Nov 30, 2023
29 changes: 25 additions & 4 deletions baseplate/sidecars/live_data_watcher.py
@@ -4,6 +4,7 @@
import json
import logging
import os
import random
import sys
import time

@@ -115,17 +116,38 @@ def _parse_loader_type(data: bytes) -> LoaderType:
return LoaderType(loader_type)


def _generate_sharded_file_key(num_file_shards: Optional[int], file_key: str) -> str:
# We can't assume that every ZK Node that is being NodeWatched by the live-data-fetcher
# will make use of S3 prefix sharding, but we know at least one does (/experiments).
# If it's not present or the value is less than 2, set the prefix to empty string ""
sharded_file_key_prefix = ""
if num_file_shards is not None and num_file_shards > 1:
# If the num_file_shards key is present, we may have multiple copies of the same manifest
# uploaded, so fetch one at random using a randomly generated prefix.
# Generate a random integer in [1, num_file_shards) to use as the prefix.
sharded_file_key_prefix = str(random.randrange(1, num_file_shards)) + "/"
# Append prefix (if it exists) to our original file key.
return sharded_file_key_prefix + file_key
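For illustration, a minimal sketch of how the new key generation behaves; the file key below is a made-up example, and the import mirrors the one used in the unit tests:

from baseplate.sidecars.live_data_watcher import _generate_sharded_file_key

# No sharding configured: the key passes through unchanged.
assert _generate_sharded_file_key(None, "experiments/manifest") == "experiments/manifest"
assert _generate_sharded_file_key(1, "experiments/manifest") == "experiments/manifest"

# Sharding configured with 6 shards: the key gains a random "1/".."5/" prefix.
assert _generate_sharded_file_key(6, "experiments/manifest").split("/", 1)[0] in {"1", "2", "3", "4", "5"}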


def _load_from_s3(data: bytes) -> bytes:
# While many of the baseplate configurations use an ini format,
# we've opted for json in these internal-to-znode-configs because
# we want them to be fully controlled by the writer of the znode
# and json is an easier format for znode authors to work with.
loader_config = json.loads(data.decode("UTF-8"))
try:
num_file_shards = loader_config.get("num_file_shards")

# We expect this key to always be present; if it is missing, raising an exception is intentional.
file_key = loader_config["file_key"]

sharded_file_key = _generate_sharded_file_key(num_file_shards, file_key)

region_name = loader_config["region_name"]
s3_kwargs = {
"Bucket": loader_config["bucket_name"],
"Key": loader_config["file_key"],
"Key": sharded_file_key,
"SSECustomerKey": loader_config["sse_key"],
"SSECustomerAlgorithm": "AES256",
}
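For reference, a rough sketch of the znode JSON this loader parses; the key names come from the code above and the unit tests below, while the values here are placeholders (num_file_shards is optional, and omitting it or setting it to 1 or less disables prefix sharding):

# Hypothetical znode payload; key names match what _load_from_s3 reads, values are made up.
example_loader_config = {
    "live_data_watcher_load_type": "S3",
    "bucket_name": "example-bucket",
    "file_key": "experiments/manifest",
    "sse_key": "example-sse-key",
    "region_name": "us-east-1",
    # Optional: when greater than 1, reads are spread across "1/".."N-1/" prefixed copies.
    "num_file_shards": 5,
}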
@@ -145,16 +167,15 @@ def _load_from_s3(data: bytes) -> bytes:
# a public resource belonging to another cluster/AWS account unless the request credentials
# are unsigned.

# Access S3 with 10 max retries enabled:
# Default # of retries in legacy mode (current mode) is 5.
s3_client = boto3.client(
"s3",
config=Config(signature_version=UNSIGNED, retries={"total_max_attempts": 10}),
config=Config(signature_version=UNSIGNED),
region_name=region_name,
)
Comment on lines +170 to 175 (Contributor Author):

Per https://reddit.slack.com/archives/C067Q6N14LA/p1701373678708859?thread_ts=1701370864.667409&cid=C067Q6N14LA, we want to revert the change made in PR/813, where we increased the retry total_max_attempts from 5 to 10.

cc @areitz @jerroydmoore

else:
s3_client = boto3.client(
"s3",
config=Config(retries={"total_max_attempts": 10}),
region_name=region_name,
)

74 changes: 66 additions & 8 deletions tests/unit/sidecars/live_data_watcher_tests.py
@@ -1,5 +1,6 @@
import grp
import json
import logging
import os
import pwd
import tempfile
@@ -11,8 +12,14 @@

from moto import mock_s3

from baseplate.sidecars.live_data_watcher import _generate_sharded_file_key
@SiddharthManoj (Contributor Author) commented on Nov 30, 2023:

I know _generate_sharded_file_key has an underscore prefix, but in practice we only use it internally; I just need to call it here as part of the unit test.

from baseplate.sidecars.live_data_watcher import NodeWatcher

NUM_FILE_SHARDS = 6


logger = logging.getLogger(__name__)


class NodeWatcherTests(unittest.TestCase):
mock_s3 = mock_s3()
@@ -26,13 +33,21 @@ def setUp(self):
region_name="us-east-1",
)
s3_client.create_bucket(Bucket=bucket_name)
s3_client.put_object(
Bucket=bucket_name,
Key="test_file_key",
Body=json.dumps(s3_data).encode(),
SSECustomerKey="test_decryption_key",
SSECustomerAlgorithm="AES256",
)
default_file_key = "test_file_key"
for file_shard_num in range(NUM_FILE_SHARDS):
if file_shard_num == 0:
# The first copy should just be the original file.
sharded_file_key = default_file_key
else:
# All other copies should include the sharded prefix.
sharded_file_key = str(file_shard_num) + "/" + default_file_key
s3_client.put_object(
Bucket=bucket_name,
Key=sharded_file_key,
Body=json.dumps(s3_data).encode(),
SSECustomerKey="test_decryption_key",
SSECustomerAlgorithm="AES256",
)
Comment on lines +36 to +50 (Contributor Author):

For the unit tests, we can upload all copies of the file in setUp().

Then each unit test can optionally provide the num_file_shards key in the ZK content. If the key is missing, the test should still pass, since it should fetch the original file with no prefix.

If the key is present, it should also pass by fetching one of the prefixed files, as sketched below.
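A small sketch of the two cases, assuming the object layout created by setUp() above (key names are taken from the test; the exact prefix range follows _generate_sharded_file_key):

# Objects uploaded by setUp() with NUM_FILE_SHARDS = 6:
#   "test_file_key", "1/test_file_key", ..., "5/test_file_key"

# ZK content without "num_file_shards": always resolves to the unprefixed key.
keys_without_sharding = {"test_file_key"}

# ZK content with "num_file_shards": 5: resolves to a random prefix in [1, 5).
keys_with_sharding = {f"{i}/test_file_key" for i in range(1, 5)}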


def tearDown(self):
self.mock_s3.stop()
@@ -42,7 +57,33 @@ def run(self, result: unittest.TestResult = None) -> unittest.TestResult:
self.output_dir = Path(loc)
return super().run(result)

def test_s3_load_type_on_change(self):
def test_generate_sharded_file_key_no_sharding(self):
original_file_key = "test_file_key"
expected_sharded_file_key = "test_file_key"
possible_no_sharding_values = [-2, -1, 0, 1, None]
for values in possible_no_sharding_values:
actual_sharded_file_key = _generate_sharded_file_key(values, original_file_key)
self.assertEqual(actual_sharded_file_key, expected_sharded_file_key)

def test_generate_sharded_file_key_sharding(self):
original_file_key = "test_file_key"
possible_sharded_file_keys = set(
[
"1/test_file_key",
"2/test_file_key",
"3/test_file_key",
"4/test_file_key",
"5/test_file_key",
]
)
for i in range(50):
actual_sharded_file_key = _generate_sharded_file_key(NUM_FILE_SHARDS, original_file_key)
# If num_file_shards is provided, the generated file key MUST have a prefix.
self.assertTrue(actual_sharded_file_key in possible_sharded_file_keys)
# Make sure we aren't generating a file without the prefix.
self.assertFalse(actual_sharded_file_key == original_file_key)

def test_s3_load_type_on_change_no_sharding(self):
dest = self.output_dir.joinpath("data.txt")
inst = NodeWatcher(str(dest), os.getuid(), os.getgid(), 777)

@@ -53,6 +94,23 @@ def test_s3_load_type_on_change(self):
self.assertEqual(dest.owner(), pwd.getpwuid(os.getuid()).pw_name)
self.assertEqual(dest.group(), grp.getgrgid(os.getgid()).gr_name)

def test_s3_load_type_on_change_sharding(self):
dest = self.output_dir.joinpath("data.txt")
inst = NodeWatcher(str(dest), os.getuid(), os.getgid(), 777)

new_content = b'{"live_data_watcher_load_type":"S3","bucket_name":"test_bucket","file_key":"test_file_key","sse_key":"test_decryption_key","region_name":"us-east-1", "num_file_shards": 5}'
expected_content = b'{"foo_encrypted": "bar_encrypted"}'

# For good measure, run this 50 times. It should succeed every time.
# setUp() uploaded the original file plus prefixed copies, and with num_file_shards=5 in the
# ZK node we should fetch one of the prefixed copies randomly (and successfully),
# and all copies have the same content.
for i in range(50):
inst.on_change(new_content, None)
self.assertEqual(expected_content, dest.read_bytes())
self.assertEqual(dest.owner(), pwd.getpwuid(os.getuid()).pw_name)
self.assertEqual(dest.group(), grp.getgrgid(os.getgid()).gr_name)

def test_on_change(self):
dest = self.output_dir.joinpath("data.txt")
inst = NodeWatcher(str(dest), os.getuid(), os.getgid(), 777)