diff --git a/baseplate/sidecars/live_data_watcher.py b/baseplate/sidecars/live_data_watcher.py
index 781f3dfeb..acbefb444 100644
--- a/baseplate/sidecars/live_data_watcher.py
+++ b/baseplate/sidecars/live_data_watcher.py
@@ -4,6 +4,7 @@
 import json
 import logging
 import os
+import random
 import sys
 import time
 
@@ -115,6 +116,20 @@ def _parse_loader_type(data: bytes) -> LoaderType:
     return LoaderType(loader_type)
 
 
+def _generate_sharded_file_key(num_file_shards: Optional[int], file_key: str) -> str:
+    # We can't assume that every ZK node being NodeWatched by the live-data-fetcher
+    # makes use of S3 prefix sharding, but we know at least one does (/experiments).
+    # If num_file_shards is absent or less than 2, leave the prefix as the empty string "".
+    sharded_file_key_prefix = ""
+    if num_file_shards is not None and num_file_shards > 1:
+        # If the num_file_shards key is present, multiple copies of the same manifest
+        # may have been uploaded, so fetch one randomly via a randomly generated prefix.
+        # randrange(1, num_file_shards) returns an integer in [1, num_file_shards).
+        sharded_file_key_prefix = str(random.randrange(1, num_file_shards)) + "/"
+    # Prepend the prefix (if any) to the original file key.
+    return sharded_file_key_prefix + file_key
+
+
 def _load_from_s3(data: bytes) -> bytes:
     # While many of the baseplate configurations use an ini format,
     # we've opted for json in these internal-to-znode-configs because
@@ -122,10 +137,17 @@
     # and json is an easier format for znode authors to work with.
     loader_config = json.loads(data.decode("UTF-8"))
     try:
+        num_file_shards = loader_config.get("num_file_shards")
+
+        # We expect this key to always be present; a missing key should raise.
+        file_key = loader_config["file_key"]
+
+        sharded_file_key = _generate_sharded_file_key(num_file_shards, file_key)
+
         region_name = loader_config["region_name"]
         s3_kwargs = {
             "Bucket": loader_config["bucket_name"],
-            "Key": loader_config["file_key"],
+            "Key": sharded_file_key,
             "SSECustomerKey": loader_config["sse_key"],
             "SSECustomerAlgorithm": "AES256",
         }
@@ -144,6 +166,8 @@
     # resource is public. In other words, this means that a given service cannot access
     # a public resource belonging to another cluster/AWS account unless the request credentials
     # are unsigned.
+
+    # The default number of retries in legacy mode (the current mode) is 5.
     s3_client = boto3.client(
         "s3",
         config=Config(signature_version=UNSIGNED),
diff --git a/tests/unit/sidecars/live_data_watcher_tests.py b/tests/unit/sidecars/live_data_watcher_tests.py
index 0a7916a32..98bf9cc9e 100644
--- a/tests/unit/sidecars/live_data_watcher_tests.py
+++ b/tests/unit/sidecars/live_data_watcher_tests.py
@@ -1,5 +1,6 @@
 import grp
 import json
+import logging
 import os
 import pwd
 import tempfile
@@ -11,8 +12,14 @@
 from moto import mock_s3
 
+from baseplate.sidecars.live_data_watcher import _generate_sharded_file_key
 from baseplate.sidecars.live_data_watcher import NodeWatcher
 
+NUM_FILE_SHARDS = 6
+
+
+logger = logging.getLogger(__name__)
+
 
 class NodeWatcherTests(unittest.TestCase):
     mock_s3 = mock_s3()
@@ -26,13 +33,21 @@ def setUp(self):
             region_name="us-east-1",
         )
         s3_client.create_bucket(Bucket=bucket_name)
-        s3_client.put_object(
-            Bucket=bucket_name,
-            Key="test_file_key",
-            Body=json.dumps(s3_data).encode(),
-            SSECustomerKey="test_decryption_key",
-            SSECustomerAlgorithm="AES256",
-        )
+        default_file_key = "test_file_key"
+        for file_shard_num in range(NUM_FILE_SHARDS):
+            if file_shard_num == 0:
+                # The first copy is just the original, unprefixed file.
+                sharded_file_key = default_file_key
+            else:
+                # All other copies carry the shard-number prefix.
+                sharded_file_key = str(file_shard_num) + "/" + default_file_key
+            s3_client.put_object(
+                Bucket=bucket_name,
+                Key=sharded_file_key,
+                Body=json.dumps(s3_data).encode(),
+                SSECustomerKey="test_decryption_key",
+                SSECustomerAlgorithm="AES256",
+            )
 
     def tearDown(self):
         self.mock_s3.stop()
@@ -42,7 +57,33 @@ def run(self, result: unittest.TestResult = None) -> unittest.TestResult:
             self.output_dir = Path(loc)
             return super().run(result)
 
-    def test_s3_load_type_on_change(self):
+    def test_generate_sharded_file_key_no_sharding(self):
+        original_file_key = "test_file_key"
+        expected_sharded_file_key = "test_file_key"
+        possible_no_sharding_values = [-2, -1, 0, 1, None]
+        for value in possible_no_sharding_values:
+            actual_sharded_file_key = _generate_sharded_file_key(value, original_file_key)
+            self.assertEqual(actual_sharded_file_key, expected_sharded_file_key)
+
+    def test_generate_sharded_file_key_sharding(self):
+        original_file_key = "test_file_key"
+        possible_sharded_file_keys = set(
+            [
+                "1/test_file_key",
+                "2/test_file_key",
+                "3/test_file_key",
+                "4/test_file_key",
+                "5/test_file_key",
+            ]
+        )
+        for _ in range(50):
+            actual_sharded_file_key = _generate_sharded_file_key(NUM_FILE_SHARDS, original_file_key)
+            # If num_file_shards is provided, the generated file key MUST have a prefix.
+            self.assertIn(actual_sharded_file_key, possible_sharded_file_keys)
+            # Make sure we aren't generating a key without the prefix.
+            self.assertNotEqual(actual_sharded_file_key, original_file_key)
+
+    def test_s3_load_type_on_change_no_sharding(self):
         dest = self.output_dir.joinpath("data.txt")
         inst = NodeWatcher(str(dest), os.getuid(), os.getgid(), 777)
 
@@ -53,6 +94,23 @@
         self.assertEqual(dest.owner(), pwd.getpwuid(os.getuid()).pw_name)
         self.assertEqual(dest.group(), grp.getgrgid(os.getgid()).gr_name)
 
+    def test_s3_load_type_on_change_sharding(self):
+        dest = self.output_dir.joinpath("data.txt")
+        inst = NodeWatcher(str(dest), os.getuid(), os.getgid(), 777)
+
+        new_content = b'{"live_data_watcher_load_type":"S3","bucket_name":"test_bucket","file_key":"test_file_key","sse_key":"test_decryption_key","region_name":"us-east-1", "num_file_shards": 5}'
+        expected_content = b'{"foo_encrypted": "bar_encrypted"}'
+
+        # For good measure, run this 50 times. It should succeed every time.
+        # setUp() uploaded the original file plus five prefixed copies, and with
+        # num_file_shards=5 in the ZK node the watcher randomly fetches one of the
+        # 1/ through 4/ prefixed copies; all of them have the same content.
+        for _ in range(50):
+            inst.on_change(new_content, None)
+            self.assertEqual(expected_content, dest.read_bytes())
+            self.assertEqual(dest.owner(), pwd.getpwuid(os.getuid()).pw_name)
+            self.assertEqual(dest.group(), grp.getgrgid(os.getgid()).gr_name)
+
     def test_on_change(self):
         dest = self.output_dir.joinpath("data.txt")
         inst = NodeWatcher(str(dest), os.getuid(), os.getgid(), 777)
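
A quick standalone illustration of the key-selection behavior introduced above (not part of the diff; the config values are made up). Because the prefix is drawn from `randrange(1, num_file_shards)`, i.e. the half-open range [1, num_file_shards), neither the unprefixed key nor a `num_file_shards/` prefix is ever selected:

```python
import random

# Hypothetical znode payload fragment; only "num_file_shards" is new in this change.
loader_config = {
    "file_key": "manifest.json",  # illustrative name
    "num_file_shards": 5,
}

# Mirror of _generate_sharded_file_key's core logic: with num_file_shards=5,
# only 1/manifest.json through 4/manifest.json can ever be fetched; the
# unprefixed key and 5/manifest.json are never selected.
observed = {
    str(random.randrange(1, loader_config["num_file_shards"])) + "/" + loader_config["file_key"]
    for _ in range(10_000)
}
print(sorted(observed))
# ['1/manifest.json', '2/manifest.json', '3/manifest.json', '4/manifest.json']
```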
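On the write side, the scheme assumes an uploader publishes the unprefixed original plus the prefixed copies, the same layout the test setUp() builds. A minimal sketch of such an uploader, assuming the same boto3 put_object parameters the watcher reads with (the helper name and its arguments are hypothetical, not an existing API):

```python
import boto3


def upload_sharded_copies(
    bucket: str, file_key: str, body: bytes, sse_key: str, num_file_shards: int
) -> None:
    """Hypothetical helper mirroring the test setUp(): writes the unprefixed
    original plus copies 1/<file_key> .. <num_file_shards - 1>/<file_key>,
    matching the prefixes _generate_sharded_file_key can produce."""
    s3_client = boto3.client("s3")
    for shard_num in range(num_file_shards):
        # Shard 0 is the original, unprefixed key; the rest get a numeric prefix.
        key = file_key if shard_num == 0 else f"{shard_num}/{file_key}"
        s3_client.put_object(
            Bucket=bucket,
            Key=key,
            Body=body,
            SSECustomerKey=sse_key,
            SSECustomerAlgorithm="AES256",
        )
```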