diff --git a/src/deadline/client/api/_submit_job_bundle.py b/src/deadline/client/api/_submit_job_bundle.py index dd419112..4541e8b9 100644 --- a/src/deadline/client/api/_submit_job_bundle.py +++ b/src/deadline/client/api/_submit_job_bundle.py @@ -17,7 +17,7 @@ from .. import api from ..exceptions import DeadlineOperationError, CreateJobWaiterCanceled -from ..config import get_setting, set_setting +from ..config import get_setting, set_setting, config_file from ..job_bundle import deadline_yaml_dump from ..job_bundle.loader import ( read_yaml_or_json, @@ -378,7 +378,7 @@ def _default_update_hash_progress(hashing_metadata: Dict[str, str]) -> bool: asset_groups=asset_groups, total_input_files=total_input_files, total_input_bytes=total_input_bytes, - hash_cache_dir=os.path.expanduser(os.path.join("~", ".deadline", "cache")), + hash_cache_dir=config_file.get_cache_directory(), on_preparing_to_submit=hashing_progress_callback, ) api.get_deadline_cloud_library_telemetry_client(config=config).record_hashing_summary( @@ -409,7 +409,9 @@ def _default_update_upload_progress(upload_metadata: Dict[str, str]) -> bool: upload_progress_callback = _default_update_upload_progress upload_summary, attachment_settings = asset_manager.upload_assets( - manifests, upload_progress_callback + manifests=manifests, + on_uploading_assets=upload_progress_callback, + s3_check_cache_dir=config_file.get_cache_directory(), ) api.get_deadline_cloud_library_telemetry_client(config=config).record_upload_summary( upload_summary diff --git a/src/deadline/client/config/config_file.py b/src/deadline/client/config/config_file.py index 6f486fb3..44edfed5 100644 --- a/src/deadline/client/config/config_file.py +++ b/src/deadline/client/config/config_file.py @@ -2,6 +2,7 @@ __all__ = [ "get_config_file_path", + "get_cache_directory", "read_config", "write_config", "get_setting_default", @@ -38,6 +39,7 @@ # The default directory within which to save the history of created jobs. DEFAULT_JOB_HISTORY_DIR = os.path.join("~", ".deadline", "job_history", "{aws_profile_name}") +DEFAULT_CACHE_DIR = os.path.join("~", ".deadline", "cache") _TRUE_VALUES = {"yes", "on", "true", "1"} _FALSE_VALUES = {"no", "off", "false", "0"} @@ -107,10 +109,6 @@ "telemetry.opt_out": {"default": "false"}, "telemetry.identifier": {"default": ""}, "defaults.job_attachments_file_system": {"default": "COPIED", "depend": "defaults.farm_id"}, - "settings.list_object_threshold": { - "default": "100", - "description": "If the number of files to be uploaded are bigger than this threshold, it switches to call list-objects S3 API from head-object call to check if files have already been uploaded.", - }, "settings.multipart_upload_chunk_size": { "default": "8388608", # 8 MB (Default chunk size for multipart upload) "description": "The chunk size to use when uploading files in multi-parts.", @@ -137,6 +135,13 @@ def get_config_file_path() -> Path: return Path(os.environ.get(CONFIG_FILE_PATH_ENV_VAR) or CONFIG_FILE_PATH).expanduser() +def get_cache_directory() -> str: + """ + Get the cache directory. 
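+    Currently this is the expanded form of ~/.deadline/cache for the current user.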
+ """ + return os.path.expanduser(DEFAULT_CACHE_DIR) + + def _should_read_config(config_file_path: Path) -> bool: global __config_file_path global __config_mtime diff --git a/src/deadline/client/ui/dialogs/submit_job_progress_dialog.py b/src/deadline/client/ui/dialogs/submit_job_progress_dialog.py index 145929a1..f40ff6d6 100644 --- a/src/deadline/client/ui/dialogs/submit_job_progress_dialog.py +++ b/src/deadline/client/ui/dialogs/submit_job_progress_dialog.py @@ -300,7 +300,7 @@ def _update_hash_progress(hashing_metadata: ProgressReportMetadata) -> bool: asset_groups=asset_groups, total_input_files=total_input_files, total_input_bytes=total_input_bytes, - hash_cache_dir=os.path.expanduser(os.path.join("~", ".deadline", "cache")), + hash_cache_dir=config_file.get_cache_directory(), on_preparing_to_submit=_update_hash_progress, ) @@ -334,8 +334,9 @@ def _update_upload_progress(upload_metadata: ProgressReportMetadata) -> bool: upload_summary, attachment_settings = cast( S3AssetManager, self._asset_manager ).upload_assets( - manifests, - _update_upload_progress, + manifests=manifests, + on_uploading_assets=_update_upload_progress, + s3_check_cache_dir=config_file.get_cache_directory(), ) logger.info("Finished uploading job attachments files.") diff --git a/src/deadline/job_attachments/_utils.py b/src/deadline/job_attachments/_utils.py index 612e4a3c..7ade3528 100644 --- a/src/deadline/job_attachments/_utils.py +++ b/src/deadline/job_attachments/_utils.py @@ -1,10 +1,9 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. import datetime -import os from hashlib import shake_256 from pathlib import Path -from typing import Optional, Tuple, Union +from typing import Tuple, Union import uuid @@ -15,13 +14,9 @@ "_human_readable_file_size", "_get_unique_dest_dir_name", "_get_bucket_and_object_key", - "_get_default_hash_cache_db_file_dir", "_is_relative_to", ] -CONFIG_ROOT = ".deadline" -COMPONENT_NAME = "job_attachments" - def _join_s3_paths(root: str, *args: str): return "/".join([root, *args]) @@ -81,17 +76,6 @@ def _get_bucket_and_object_key(s3_path: str) -> Tuple[str, str]: return bucket, key -def _get_default_hash_cache_db_file_dir() -> Optional[str]: - """ - Gets the expected directory for the hash cache database file based on OS environment variables. - If a directory cannot be found, defaults to the working directory. - """ - default_path = os.environ.get("HOME") - if default_path: - default_path = os.path.join(default_path, CONFIG_ROOT, COMPONENT_NAME) - return default_path - - def _is_relative_to(path1: Union[Path, str], path2: Union[Path, str]) -> bool: """ Determines if path1 is relative to path2. This function is to support diff --git a/src/deadline/job_attachments/asset_sync.py b/src/deadline/job_attachments/asset_sync.py index dbf8e619..eb8f1b06 100644 --- a/src/deadline/job_attachments/asset_sync.py +++ b/src/deadline/job_attachments/asset_sync.py @@ -41,7 +41,7 @@ ) from .fus3 import Fus3ProcessManager -from .exceptions import AssetSyncError, Fus3ExecutableMissingError +from .exceptions import AssetSyncError, Fus3ExecutableMissingError, JobAttachmentsS3ClientError from .models import ( Attachments, JobAttachmentsFileSystem, @@ -426,15 +426,31 @@ def sync_inputs( f"Virtual File System not found, falling back to {JobAttachmentsFileSystem.COPIED} for JobAttachmentsFileSystem." 
) - download_summary_statistics = download_files_from_manifests( - s3_bucket=s3_settings.s3BucketName, - manifests_by_root=merged_manifests_by_root, - cas_prefix=s3_settings.full_cas_prefix(), - fs_permission_settings=fs_permission_settings, - session=self.session, - on_downloading_files=on_downloading_files, - logger=self.logger, - ) + try: + download_summary_statistics = download_files_from_manifests( + s3_bucket=s3_settings.s3BucketName, + manifests_by_root=merged_manifests_by_root, + cas_prefix=s3_settings.full_cas_prefix(), + fs_permission_settings=fs_permission_settings, + session=self.session, + on_downloading_files=on_downloading_files, + logger=self.logger, + ) + except JobAttachmentsS3ClientError as exc: + if exc.status_code == 404: + raise JobAttachmentsS3ClientError( + action=exc.action, + status_code=exc.status_code, + bucket_name=exc.bucket_name, + key_or_prefix=exc.key_or_prefix, + message=( + "This can happen if the S3 check cache on the submitting machine is out of date. " + "Please delete the cache file from the submitting machine, usually located in the " + "home directory (~/.deadline/cache/s3_check_cache.db) and try submitting again." + ), + ) from exc + else: + raise return ( download_summary_statistics.convert_to_summary_statistics(), diff --git a/src/deadline/job_attachments/caches/__init__.py b/src/deadline/job_attachments/caches/__init__.py new file mode 100644 index 00000000..7ee77a44 --- /dev/null +++ b/src/deadline/job_attachments/caches/__init__.py @@ -0,0 +1,15 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + +from .cache_db import CacheDB, CONFIG_ROOT, COMPONENT_NAME +from .hash_cache import HashCache, HashCacheEntry +from .s3_check_cache import S3CheckCache, S3CheckCacheEntry + +__all__ = [ + "CacheDB", + "CONFIG_ROOT", + "COMPONENT_NAME", + "HashCache", + "HashCacheEntry", + "S3CheckCache", + "S3CheckCacheEntry", +] diff --git a/src/deadline/job_attachments/caches/cache_db.py b/src/deadline/job_attachments/caches/cache_db.py new file mode 100644 index 00000000..73a73aa4 --- /dev/null +++ b/src/deadline/job_attachments/caches/cache_db.py @@ -0,0 +1,96 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + +""" +Module for defining a local cache file. +""" + +import logging +import os +from abc import ABC +from threading import Lock +from typing import Optional + +from ..exceptions import JobAttachmentsError + +CONFIG_ROOT = ".deadline" +COMPONENT_NAME = "job_attachments" + +logger = logging.getLogger("Deadline") + + +class CacheDB(ABC): + """ + Abstract base class for connecting to a local SQLite cache database. + + This class is intended to always be used with a context manager to properly + close the connection to the cache database. + """ + + def __init__( + self, cache_name: str, table_name: str, create_query: str, cache_dir: Optional[str] = None + ) -> None: + if not cache_name or not table_name or not create_query: + raise JobAttachmentsError("Constructor strings for CacheDB cannot be empty.") + self.cache_name: str = cache_name + self.table_name: str = table_name + self.create_query: str = create_query + + try: + # SQLite is included in Python installers, but might not exist if building python from source. 
+            import sqlite3  # noqa
+
+            self.enabled = True
+        except ImportError:
+            logger.warning(f"SQLite was not found, {cache_name} will not be used.")
+            self.enabled = False
+            return
+
+        if cache_dir is None:
+            cache_dir = self.get_default_cache_db_file_dir()
+        if cache_dir is None:
+            raise JobAttachmentsError(
+                f"No default cache path found. Please provide a directory for {self.cache_name}."
+            )
+        os.makedirs(cache_dir, exist_ok=True)
+        self.cache_dir: str = os.path.join(cache_dir, f"{self.cache_name}.db")
+        self.db_lock = Lock()
+
+    def __enter__(self):
+        """Called when entering the context manager."""
+        if self.enabled:
+            import sqlite3
+
+            try:
+                self.db_connection: sqlite3.Connection = sqlite3.connect(
+                    self.cache_dir, check_same_thread=False
+                )
+            except sqlite3.OperationalError as oe:
+                raise JobAttachmentsError(
+                    f"Could not access cache file in {self.cache_dir}"
+                ) from oe
+
+            try:
+                self.db_connection.execute(f"SELECT * FROM {self.table_name}")
+            except Exception:
+                # DB file doesn't have our table, so we need to create it
+                logger.info(
+                    f"No cache entries for the current library version were found. Creating a new cache for {self.cache_name}"
+                )
+                self.db_connection.execute(self.create_query)
+        return self
+
+    def __exit__(self, exc_type, exc_value, exc_traceback):
+        """Called when exiting the context manager."""
+        if self.enabled:
+            self.db_connection.close()
+
+    @classmethod
+    def get_default_cache_db_file_dir(cls) -> Optional[str]:
+        """
+        Gets the expected directory for the cache database file based on OS environment variables.
+        Returns None if a home directory cannot be found.
+        """
+        default_path = os.environ.get("HOME")
+        if default_path:
+            default_path = os.path.join(default_path, CONFIG_ROOT, COMPONENT_NAME)
+        return default_path
diff --git a/src/deadline/job_attachments/caches/hash_cache.py
new file mode 100644
index 00000000..996f0118
--- /dev/null
+++ b/src/deadline/job_attachments/caches/hash_cache.py
@@ -0,0 +1,91 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+
+"""
+Module for accessing the local file hash cache.
+"""
+
+import logging
+from dataclasses import dataclass
+from typing import Any, Dict, Optional
+
+from .cache_db import CacheDB
+from ..asset_manifests.hash_algorithms import HashAlgorithm
+
+
+logger = logging.getLogger("Deadline")
+
+
+@dataclass
+class HashCacheEntry:
+    """Represents an entry in the local hash-cache database"""
+
+    file_path: str
+    hash_algorithm: HashAlgorithm
+    file_hash: str
+    last_modified_time: str
+
+    def to_dict(self) -> Dict[str, Any]:
+        return {
+            "file_path": self.file_path,
+            "hash_algorithm": self.hash_algorithm.value,
+            "file_hash": self.file_hash,
+            "last_modified_time": self.last_modified_time,
+        }
+
+
+class HashCache(CacheDB):
+    """
+    Class used to store and retrieve entries in the local file hash cache.
+
+    This class is intended to always be used with a context manager to properly
+    close the connection to the hash cache database.
+
+    This class also automatically locks when doing writes, so it can be called
+    by multiple threads.
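+
+    A minimal illustrative use (entries are keyed by file path and hash algorithm):
+
+        with HashCache() as hash_cache:
+            entry = hash_cache.get_entry("/path/to/file", HashAlgorithm.XXH128)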
+ """ + + CACHE_NAME = "hash_cache" + CACHE_DB_VERSION = 2 + + def __init__(self, cache_dir: Optional[str] = None) -> None: + table_name: str = f"hashesV{self.CACHE_DB_VERSION}" + create_query: str = f"CREATE TABLE hashesV{self.CACHE_DB_VERSION}(file_path text primary key, hash_algorithm text secondary key, file_hash text, last_modified_time timestamp)" + super().__init__( + cache_name=self.CACHE_NAME, + table_name=table_name, + create_query=create_query, + cache_dir=cache_dir, + ) + + def get_entry( + self, file_path_key: str, hash_algorithm: HashAlgorithm + ) -> Optional[HashCacheEntry]: + """ + Returns an entry from the hash cache, if it exists. + """ + if not self.enabled: + return None + + with self.db_lock, self.db_connection: + entry_vals = self.db_connection.execute( + f"SELECT * FROM {self.table_name} WHERE file_path=? AND hash_algorithm=?", + [file_path_key, hash_algorithm.value], + ).fetchone() + if entry_vals: + return HashCacheEntry( + file_path=entry_vals[0], + hash_algorithm=HashAlgorithm(entry_vals[1]), + file_hash=entry_vals[2], + last_modified_time=str(entry_vals[3]), + ) + else: + return None + + def put_entry(self, entry: HashCacheEntry) -> None: + """Inserts or replaces an entry into the hash cache database after acquiring the lock.""" + if self.enabled: + with self.db_lock, self.db_connection: + self.db_connection.execute( + f"INSERT OR REPLACE INTO {self.table_name} VALUES(:file_path, :hash_algorithm, :file_hash, :last_modified_time)", + entry.to_dict(), + ) diff --git a/src/deadline/job_attachments/caches/s3_check_cache.py b/src/deadline/job_attachments/caches/s3_check_cache.py new file mode 100644 index 00000000..fdb9b87f --- /dev/null +++ b/src/deadline/job_attachments/caches/s3_check_cache.py @@ -0,0 +1,92 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + +""" +Module for accessing the local 'last seen on S3' cache. +""" + +import logging +from dataclasses import dataclass +from datetime import datetime +from typing import Any, Dict, Optional + +from .cache_db import CacheDB + + +logger = logging.getLogger("Deadline") + + +@dataclass +class S3CheckCacheEntry: + """Represents an entry in the local s3 check cache database""" + + s3_key: str + last_seen_time: str + + def to_dict(self) -> Dict[str, Any]: + return { + "s3_key": self.s3_key, + "last_seen_time": self.last_seen_time, + } + + +class S3CheckCache(CacheDB): + """ + Maintains a cache of 'last seen on S3' entries in a local database, which + specifies which full S3 object keys exist in the content-addressed storage + in the Job Attachments S3 bucket. + + This class is intended to always be used with a context manager to properly + close the connection to the hash cache database. + + This class also automatically locks when doing writes, so it can be called + by multiple threads. + """ + + CACHE_NAME = "s3_check_cache" + CACHE_DB_VERSION = 1 + ENTRY_EXPIRY_DAYS = 30 + + def __init__(self, cache_dir: Optional[str] = None) -> None: + table_name: str = f"s3checkV{self.CACHE_DB_VERSION}" + create_query: str = f"CREATE TABLE s3checkV{self.CACHE_DB_VERSION}(s3_key text primary key, last_seen_time timestamp)" + super().__init__( + cache_name=self.CACHE_NAME, + table_name=table_name, + create_query=create_query, + cache_dir=cache_dir, + ) + + def get_entry(self, s3_key: str) -> Optional[S3CheckCacheEntry]: + """ + Checks if an entry exists in the cache, and returns it if it hasn't expired. 
+ """ + if not self.enabled: + return None + + with self.db_lock, self.db_connection: + entry_vals = self.db_connection.execute( + f"SELECT * FROM {self.table_name} WHERE s3_key=?", + [s3_key], + ).fetchone() + if entry_vals: + entry = S3CheckCacheEntry( + s3_key=entry_vals[0], + last_seen_time=str(entry_vals[1]), + ) + try: + last_seen = datetime.fromtimestamp(float(entry.last_seen_time)) + if (datetime.now() - last_seen).days < self.ENTRY_EXPIRY_DAYS: + return entry + except ValueError: + logger.warning(f"Timestamp for S3 key {s3_key} is not valid. Ignoring.") + + return None + + def put_entry(self, entry: S3CheckCacheEntry) -> None: + """Inserts or replaces an entry into the cache database.""" + if self.enabled: + with self.db_lock, self.db_connection: + self.db_connection.execute( + f"INSERT OR REPLACE INTO {self.table_name} VALUES(:s3_key, :last_seen_time)", + entry.to_dict(), + ) diff --git a/src/deadline/job_attachments/download.py b/src/deadline/job_attachments/download.py index f50364d6..557ed2a2 100644 --- a/src/deadline/job_attachments/download.py +++ b/src/deadline/job_attachments/download.py @@ -75,7 +75,13 @@ def get_manifest_from_s3( **COMMON_ERROR_GUIDANCE_FOR_S3, 403: ( "Forbidden or Access denied. Please check your AWS credentials, and ensure that " - "your AWS IAM Role or User has the 's3:GetObject' permission for this bucket." + "your AWS IAM Role or User has the 's3:GetObject' permission for this bucket. " + ) + if "kms:" not in str(exc) + else ( + "Forbidden or Access denied. Please check your AWS credentials and Job Attachments S3 bucket " + "encryption settings. If a customer-managed KMS key is set, confirm that your AWS IAM Role or " + "User has the 'kms:Decrypt' and 'kms:DescribeKey' permissions for the key used to encrypt the bucket." ), 404: "Not found. Please check your bucket name and object key, and ensure that they exist in the AWS account.", } @@ -397,10 +403,18 @@ def process_client_error(exc: ClientError, status_code: int): status_code_guidance = { **COMMON_ERROR_GUIDANCE_FOR_S3, 403: ( - "Forbidden or Access denied. Please check your AWS credentials, or ensure that " - "your AWS IAM Role or User has the 's3:GetObject' permission for this bucket." + "Forbidden or Access denied. Please check your AWS credentials, and ensure that " + "your AWS IAM Role or User has the 's3:GetObject' permission for this bucket. " + ) + if "kms:" not in str(exc) + else ( + "Forbidden or Access denied. Please check your AWS credentials and Job Attachments S3 bucket " + "encryption settings. If a customer-managed KMS key is set, confirm that your AWS IAM Role or " + "User has the 'kms:Decrypt' and 'kms:DescribeKey' permissions for the key used to encrypt the bucket." + ), + 404: ( + "Not found. Please check your bucket name and object key, and ensure that they exist in the AWS account." ), - 404: "Not found. Please check your bucket name and object key, and ensure that they exist in the AWS account.", } raise JobAttachmentsS3ClientError( action="downloading file", diff --git a/src/deadline/job_attachments/hash_cache.py b/src/deadline/job_attachments/hash_cache.py deleted file mode 100644 index fc7a1007..00000000 --- a/src/deadline/job_attachments/hash_cache.py +++ /dev/null @@ -1,117 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - -""" -Module for accessing the local file hash cache. 
-""" - -import logging -import os -from threading import Lock -from typing import Optional - -from .asset_manifests.hash_algorithms import HashAlgorithm -from .exceptions import JobAttachmentsError -from .models import HashCacheEntry -from ._utils import _get_default_hash_cache_db_file_dir - -CACHE_FILE_NAME = "hash_cache.db" -CACHE_DB_VERSION = 2 - -logger = logging.getLogger("Deadline") - - -class HashCache: - """ - Class used to store and retrieve entries in the local file hash cache. - - This class is intended to always be used with a context manager to properly - close the connection to the hash cache database. - - This class also automatically locks when doing writes, so it can be called - by multiple threads. - """ - - def __init__(self, cache_dir: Optional[str] = None) -> None: - try: - # SQLite is included in Python installers, but might not exist if building python from source. - import sqlite3 # noqa - - self.enabled = True - except ImportError: - logger.warn("SQLite was not found, the Hash Cache will not be used.") - self.enabled = False - return - - if cache_dir is None: - cache_dir = _get_default_hash_cache_db_file_dir() - if cache_dir is None: - raise JobAttachmentsError( - "No default hash cache path found. Please provide a hash cache directory." - ) - os.makedirs(cache_dir, exist_ok=True) - self.cache_dir: str = os.path.join(cache_dir, CACHE_FILE_NAME) - self.db_lock = Lock() - - def __enter__(self): - """Called when entering the context manager.""" - if self.enabled: - import sqlite3 - - try: - self.db_connection: sqlite3.Connection = sqlite3.connect( - self.cache_dir, check_same_thread=False - ) - except sqlite3.OperationalError as oe: - raise JobAttachmentsError( - f"Could not access hash cache file in {self.cache_dir}" - ) from oe - - try: - self.db_connection.execute(f"SELECT * FROM hashesV{CACHE_DB_VERSION}") - except Exception: - # DB file doesn't have our table, so we need to create it - logger.info( - "No hash cache entries for the current library version were found. Creating a new hash cache." - ) - self.db_connection.execute( - f"CREATE TABLE hashesV{CACHE_DB_VERSION}(file_path text primary key, hash_algorithm text secondary key, file_hash text, last_modified_time timestamp)" - ) - return self - - def __exit__(self, exc_type, exc_value, exc_traceback): - """Called when exiting the context manager.""" - if self.enabled: - self.db_connection.close() - - def get_entry( - self, file_path_key: str, hash_algorithm: HashAlgorithm - ) -> Optional[HashCacheEntry]: - """ - Returns an entry from the hash cache, if it exists. - """ - if not self.enabled: - return None - - with self.db_lock, self.db_connection: - entry_vals = self.db_connection.execute( - f"SELECT * FROM hashesV{CACHE_DB_VERSION} WHERE file_path=? 
AND hash_algorithm=?", - [file_path_key, hash_algorithm.value], - ).fetchone() - if entry_vals: - return HashCacheEntry( - file_path=entry_vals[0], - hash_algorithm=HashAlgorithm(entry_vals[1]), - file_hash=entry_vals[2], - last_modified_time=str(entry_vals[3]), - ) - else: - return None - - def put_entry(self, entry: HashCacheEntry) -> None: - """Inserts or replaces an entry into the hash cache database after acquiring the lock.""" - if self.enabled: - with self.db_lock, self.db_connection: - self.db_connection.execute( - f"INSERT OR REPLACE INTO hashesV{CACHE_DB_VERSION} VALUES(:file_path, :hash_algorithm, :file_hash, :last_modified_time)", - entry.to_dict(), - ) diff --git a/src/deadline/job_attachments/models.py b/src/deadline/job_attachments/models.py index ad4da4a6..e8df937c 100644 --- a/src/deadline/job_attachments/models.py +++ b/src/deadline/job_attachments/models.py @@ -98,24 +98,6 @@ def get_all_paths(self) -> list[str]: return sorted(path_list) -@dataclass -class HashCacheEntry: - """Represents an entry in the local hash-cache database""" - - file_path: str - hash_algorithm: HashAlgorithm - file_hash: str - last_modified_time: str - - def to_dict(self) -> dict[str, Any]: - return { - "file_path": self.file_path, - "hash_algorithm": self.hash_algorithm.value, - "file_hash": self.file_hash, - "last_modified_time": self.last_modified_time, - } - - @dataclass class OutputFile: """Files for output""" diff --git a/src/deadline/job_attachments/upload.py b/src/deadline/job_attachments/upload.py index 9bc4b262..bc641f92 100644 --- a/src/deadline/job_attachments/upload.py +++ b/src/deadline/job_attachments/upload.py @@ -43,14 +43,13 @@ MissingS3BucketError, MissingS3RootPrefixError, ) -from .hash_cache import HashCache +from .caches import HashCache, HashCacheEntry, S3CheckCache, S3CheckCacheEntry from .models import ( AssetRootGroup, AssetRootManifest, AssetUploadGroup, Attachments, FileSystemLocationType, - HashCacheEntry, JobAttachmentS3Settings, ManifestProperties, PathFormat, @@ -67,6 +66,11 @@ logger = logging.getLogger("deadline.job_attachments.upload") +# TODO: tune this. max_worker defaults to 5 * number of processors. We can run into issues here +# if we thread too aggressively on slower internet connections. So for now let's set it to 5, +# which would the number of threads with one processor. +NUM_UPLOAD_WORKERS: int = 5 + class S3AssetUploader: """ @@ -88,9 +92,6 @@ def __init__( # TODO: full performance analysis to determine the ideal thresholds try: - self.list_object_threshold = int( - config_file.get_setting("settings.list_object_threshold") - ) self.multipart_upload_chunk_size = int( config_file.get_setting("settings.multipart_upload_chunk_size") ) @@ -107,17 +108,12 @@ def __init__( except ValueError as ve: raise AssetSyncError( "Failed to parse configuration settings. Please ensure that the following settings in the config file are integers: " - "list_object_threshold, multipart_upload_chunk_size, multipart_upload_max_workers, small_file_threshold_multiplier" + "multipart_upload_chunk_size, multipart_upload_max_workers, small_file_threshold_multiplier" ) from ve # Confirm that the settings values are all positive. error_msg = "" - if self.list_object_threshold <= 0: - error_msg = ( - f"list_object_threshold ({self.list_object_threshold}) must be positive integer." - ) - - elif self.multipart_upload_chunk_size <= 0: + if self.multipart_upload_chunk_size <= 0: error_msg = f"multipart_upload_chunk_size ({self.multipart_upload_chunk_size}) must be positive integer." 
elif self.multipart_upload_max_workers <= 0: @@ -137,6 +133,7 @@ def upload_assets( source_root: Path, file_system_location_name: Optional[str] = None, progress_tracker: Optional[ProgressTracker] = None, + s3_check_cache_dir: Optional[str] = None, ) -> tuple[str, str]: """ Uploads assets based off of an asset manifest, uploads the asset manifest. @@ -159,6 +156,7 @@ def upload_assets( source_root, job_attachment_settings.full_cas_prefix(), progress_tracker, + s3_check_cache_dir, ) hash_alg = manifest.get_default_hash_alg() manifest_bytes = manifest.encode().encode("utf-8") @@ -192,96 +190,60 @@ def upload_input_files( source_root: Path, s3_cas_prefix: str, progress_tracker: Optional[ProgressTracker] = None, + s3_check_cache_dir: Optional[str] = None, ) -> None: """ Uploads all of the files listed in the given manifest to S3 if they don't exist in the given S3 prefix already. - Depending on the number of files to be uploaded, will either make a head-object or list-objects - S3 API call to check if files have already been uploaded. Note that head-object is cheaper - to call, but slows down significantly if needing to call many times, so the list-objects API - is called for larger file lists. - - TODO: There is a known performance bottleneck if the bucket has a large number of files, but - there isn't currently any way of knowing the size of the bucket without iterating through the - contents of a prefix. For now, we'll just head-object when we have a small number of files. + The local 'S3 check cache' is used to note if we've seen an object in S3 before so we + can save the S3 API calls. """ - files_to_upload: list[base_manifest.BaseManifestPath] = manifest.paths - check_if_in_s3 = True - - if len(files_to_upload) >= self.list_object_threshold: - # If different files have the same content (and thus the same hash), they are counted as skipped files. - file_dict: dict[str, base_manifest.BaseManifestPath] = {} - for file in files_to_upload: - # TODO: replace with uncommented line below after sufficient time after the next release - file_key = f"{file.hash}" # .{manifest.hashAlg.value}" - if file_key in file_dict and progress_tracker: - progress_tracker.increase_skipped( - 1, (source_root.joinpath(file.path)).stat().st_size - ) - else: - file_dict[file_key] = file - - to_upload_set: set[str] = self.filter_objects_to_upload( - s3_bucket, s3_cas_prefix, set(file_dict.keys()) - ) - files_to_upload = [file_dict[k] for k in to_upload_set] - check_if_in_s3 = False # Can skip the check since we just did it above - # The input files that are already in s3 are counted as skipped files. - if progress_tracker: - skipped_set = set(file_dict.keys()) - to_upload_set - files_to_skip = [file_dict[k] for k in skipped_set] - progress_tracker.increase_skipped( - len(files_to_skip), - sum((source_root.joinpath(file.path)).stat().st_size for file in files_to_skip), - ) # Split into a separate 'large file' and 'small file' queues. # Separate 'large' files from 'small' files so that we can process 'large' files serially. # This wastes less bandwidth if uploads are cancelled, as it's better to use the multi-threaded # multi-part upload for a single large file than multiple large files at the same time. (small_file_queue, large_file_queue) = self._separate_files_by_size( - files_to_upload, self.small_file_threshold + manifest.paths, self.small_file_threshold ) - # First, process the whole 'small file' queue with parallel object uploads. - # TODO: tune this. max_worker defaults to 5 * number of processors. 
We can run into issues here - # if we thread too aggressively on slower internet connections. So for now let's set it to 5, - # which would the number of threads with one processor. - with concurrent.futures.ThreadPoolExecutor(max_workers=5) as executor: - futures = { - executor.submit( - self.upload_object_to_cas, + with S3CheckCache(s3_check_cache_dir) as s3_cache: + # First, process the whole 'small file' queue with parallel object uploads. + with concurrent.futures.ThreadPoolExecutor(max_workers=NUM_UPLOAD_WORKERS) as executor: + futures = { + executor.submit( + self.upload_object_to_cas, + file, + manifest.hashAlg, + s3_bucket, + source_root, + s3_cas_prefix, + s3_cache, + progress_tracker, + ): file + for file in small_file_queue + } + # surfaces any exceptions in the thread + for future in concurrent.futures.as_completed(futures): + (is_uploaded, file_size) = future.result() + if progress_tracker and not is_uploaded: + progress_tracker.increase_skipped(1, file_size) + + # Now process the whole 'large file' queue with serial object uploads (but still parallel multi-part upload.) + for file in large_file_queue: + (is_uploaded, file_size) = self.upload_object_to_cas( file, manifest.hashAlg, s3_bucket, source_root, s3_cas_prefix, - check_if_in_s3, + s3_cache, progress_tracker, - ): file - for file in small_file_queue - } - # surfaces any exceptions in the thread - for future in concurrent.futures.as_completed(futures): - (is_uploaded, file_size) = future.result() + ) if progress_tracker and not is_uploaded: progress_tracker.increase_skipped(1, file_size) - # Now process the whole 'large file' queue with serial object uploads (but still parallel multi-part upload.) - for file in large_file_queue: - (is_uploaded, file_size) = self.upload_object_to_cas( - file, - manifest.hashAlg, - s3_bucket, - source_root, - s3_cas_prefix, - check_if_in_s3, - progress_tracker, - ) - if progress_tracker and not is_uploaded: - progress_tracker.increase_skipped(1, file_size) - # to report progress 100% at the end, and # to check if the job submission was canceled in the middle of processing the last batch of files. 
if progress_tracker: @@ -308,6 +270,9 @@ def _separate_files_by_size( large_file_queue.append(file) return (small_file_queue, large_file_queue) + def _get_current_timestamp(self) -> str: + return str(datetime.now().timestamp()) + def upload_object_to_cas( self, file: base_manifest.BaseManifestPath, @@ -315,7 +280,7 @@ def upload_object_to_cas( s3_bucket: str, source_root: Path, s3_cas_prefix: str, - check_if_in_s3: bool = True, + s3_check_cache: S3CheckCache, progress_tracker: Optional[ProgressTracker] = None, ) -> Tuple[bool, int]: """ @@ -331,19 +296,32 @@ def upload_object_to_cas( is_uploaded = False file_size = local_path.stat().st_size - if check_if_in_s3 and self.file_already_uploaded(s3_bucket, s3_upload_key): + if s3_check_cache.get_entry(s3_key=f"{s3_bucket}/{s3_upload_key}"): logger.debug( - f"skipping {local_path} because it has already been uploaded to s3://{s3_bucket}/{s3_upload_key}" + f"skipping {local_path} because {s3_bucket}/{s3_upload_key} exists in the cache" ) return (is_uploaded, file_size) - self.upload_file_to_s3( - local_path=local_path, - s3_bucket=s3_bucket, - s3_upload_key=s3_upload_key, - progress_tracker=progress_tracker, + if self.file_already_uploaded(s3_bucket, s3_upload_key): + logger.debug( + f"skipping {local_path} because it has already been uploaded to s3://{s3_bucket}/{s3_upload_key}" + ) + else: + self.upload_file_to_s3( + local_path=local_path, + s3_bucket=s3_bucket, + s3_upload_key=s3_upload_key, + progress_tracker=progress_tracker, + ) + is_uploaded = True + + s3_check_cache.put_entry( + S3CheckCacheEntry( + s3_key=f"{s3_bucket}/{s3_upload_key}", + last_seen_time=self._get_current_timestamp(), + ) ) - is_uploaded = True + return (is_uploaded, file_size) def upload_file_to_s3( @@ -436,7 +414,13 @@ def upload_file_to_s3( **COMMON_ERROR_GUIDANCE_FOR_S3, 403: ( "Forbidden or Access denied. Please check your AWS credentials, and ensure that " - "your AWS IAM Role or User has the 's3:GetObject' permission for this bucket." + "your AWS IAM Role or User has the 's3:PutObject' permission for this bucket. " + ) + if "kms:" not in str(exc) + else ( + "Forbidden or Access denied. Please check your AWS credentials and Job Attachments S3 bucket " + "encryption settings. If a customer-managed KMS key is set, confirm that your AWS IAM Role or " + "User has the 'kms:GenerateDataKey' and 'kms:DescribeKey' permissions for the key used to encrypt the bucket." ), 404: "Not found. Please check your bucket name and object key, and ensure that they exist in the AWS account.", } @@ -496,46 +480,6 @@ def _upload_part( etag = response["ETag"] return {"ETag": etag, "PartNumber": part_number} - def filter_objects_to_upload(self, bucket: str, prefix: str, upload_set: set[str]) -> set[str]: - """ - Makes a paginated list-objects request to S3 to get all objects in the given prefix. - Given the set of files to be uploaded, returns which objects do not exist in S3. - """ - try: - paginator = self._s3.get_paginator("list_objects_v2") - page_iterator = paginator.paginate( - Bucket=bucket, - Prefix=prefix, - ) - - for page in page_iterator: - contents = page.get("Contents", None) - if contents is None: - break - for content in contents: - upload_set.discard(content["Key"].split("/")[-1]) - if len(upload_set) == 0: - break - except ClientError as exc: - status_code = int(exc.response["ResponseMetadata"]["HTTPStatusCode"]) - status_code_guidance = { - **COMMON_ERROR_GUIDANCE_FOR_S3, - 403: ( - "Forbidden or Access denied. 
Please check your AWS credentials, and ensure that " - "your AWS IAM Role or User has the 's3:ListBucket' permission for this bucket." - ), - 404: "Not found. Please ensure that the bucket and key/prefix exists.", - } - raise JobAttachmentsS3ClientError( - action="listing bucket contents", - status_code=status_code, - bucket_name=bucket, - key_or_prefix=prefix, - message=f"{status_code_guidance.get(status_code, '')} {str(exc)}", - ) from exc - - return upload_set - def file_already_uploaded(self, bucket: str, key: str) -> bool: """ Check whether the file has already been uploaded by doing a head-object call. @@ -585,7 +529,13 @@ def upload_bytes_to_s3( **COMMON_ERROR_GUIDANCE_FOR_S3, 403: ( "Forbidden or Access denied. Please check your AWS credentials, and ensure that " - "your AWS IAM Role or User has the 's3:PutObject' permission for this bucket." + "your AWS IAM Role or User has the 's3:PutObject' permission for this bucket. " + ) + if "kms:" not in str(exc) + else ( + "Forbidden or Access denied. Please check your AWS credentials and Job Attachments S3 bucket " + "encryption settings. If a customer-managed KMS key is set, confirm that your AWS IAM Role or " + "User has the 'kms:GenerateDataKey' and 'kms:DescribeKey' permissions for the key used to encrypt the bucket." ), 404: "Not found. Please check your bucket name, and ensure that it exists in the AWS account.", } @@ -1061,6 +1011,7 @@ def upload_assets( self, manifests: list[AssetRootManifest], on_uploading_assets: Optional[Callable[[Any], bool]] = None, + s3_check_cache_dir: Optional[str] = None, ) -> tuple[SummaryStatistics, Attachments]: """ Uploads all the files for provided manifests and manifests themselves to S3. @@ -1110,6 +1061,7 @@ def upload_assets( source_root=Path(asset_root_manifest.root_path), file_system_location_name=asset_root_manifest.file_system_location_name, progress_tracker=progress_tracker, + s3_check_cache_dir=s3_check_cache_dir, ) manifest_properties.inputManifestPath = partial_manifest_key manifest_properties.inputManifestHash = asset_manifest_hash diff --git a/test/integ/deadline_job_attachments/test_job_attachments.py b/test/integ/deadline_job_attachments/test_job_attachments.py index 72fa8195..43a7c1cd 100644 --- a/test/integ/deadline_job_attachments/test_job_attachments.py +++ b/test/integ/deadline_job_attachments/test_job_attachments.py @@ -84,6 +84,7 @@ def __init__( self.deadline_client = self.job_attachment_resources.deadline_client self.hash_cache_dir = tmp_path_factory.mktemp("hash_cache") + self.s3_cache_dir = tmp_path_factory.mktemp("s3_check_cache") self.session = boto3.Session() self.deadline_endpoint = os.getenv( "AWS_ENDPOINT_URL_DEADLINE", @@ -156,7 +157,11 @@ def upload_input_files_assets_not_in_cas(job_attachment_test: JobAttachmentTest) hash_cache_dir=str(job_attachment_test.hash_cache_dir), on_preparing_to_submit=mock_on_preparing_to_submit, ) - asset_manager.upload_assets(manifests, on_uploading_assets=mock_on_uploading_files) + asset_manager.upload_assets( + manifests, + on_uploading_assets=mock_on_uploading_files, + s3_check_cache_dir=str(job_attachment_test.s3_cache_dir), + ) # THEN scene_ma_s3_path = ( @@ -231,7 +236,9 @@ def upload_input_files_one_asset_in_cas( ) (_, attachments) = asset_manager.upload_assets( - manifests, on_uploading_assets=mock_on_uploading_files + manifests, + on_uploading_assets=mock_on_uploading_files, + s3_check_cache_dir=str(job_attachment_test.s3_cache_dir), ) # THEN @@ -316,7 +323,9 @@ def test_upload_input_files_all_assets_in_cas( 
on_preparing_to_submit=mock_on_preparing_to_submit, ) (_, attachments) = asset_manager.upload_assets( - manifests, on_uploading_assets=mock_on_uploading_files + manifests, + on_uploading_assets=mock_on_uploading_files, + s3_check_cache_dir=str(job_attachment_test.s3_cache_dir), ) # THEN @@ -1070,7 +1079,9 @@ def upload_input_files_no_input_paths( on_preparing_to_submit=mock_on_preparing_to_submit, ) (_, attachments) = asset_manager.upload_assets( - manifests, on_uploading_assets=mock_on_uploading_files + manifests, + on_uploading_assets=mock_on_uploading_files, + s3_check_cache_dir=str(job_attachment_test.s3_cache_dir), ) # THEN @@ -1136,7 +1147,9 @@ def test_upload_input_files_no_download_paths(job_attachment_test: JobAttachment on_preparing_to_submit=mock_on_preparing_to_submit, ) (_, attachments) = asset_manager.upload_assets( - manifests, on_uploading_assets=mock_on_uploading_files + manifests, + on_uploading_assets=mock_on_uploading_files, + s3_check_cache_dir=str(job_attachment_test.s3_cache_dir), ) # THEN @@ -1248,7 +1261,11 @@ def test_upload_bucket_wrong_account(external_bucket: str, job_attachment_test: hash_cache_dir=str(job_attachment_test.hash_cache_dir), on_preparing_to_submit=mock_on_preparing_to_submit, ) - asset_manager.upload_assets(manifests, on_uploading_assets=mock_on_uploading_files) + asset_manager.upload_assets( + manifests, + on_uploading_assets=mock_on_uploading_files, + s3_check_cache_dir=str(job_attachment_test.s3_cache_dir), + ) @pytest.mark.integ diff --git a/test/unit/deadline_client/cli/test_cli_config.py b/test/unit/deadline_client/cli/test_cli_config.py index 2d56740e..168860e0 100644 --- a/test/unit/deadline_client/cli/test_cli_config.py +++ b/test/unit/deadline_client/cli/test_cli_config.py @@ -35,7 +35,7 @@ def test_cli_config_show_defaults(fresh_deadline_config): assert fresh_deadline_config in result.output # Assert the expected number of settings - assert len(settings.keys()) == 17 + assert len(settings.keys()) == 16 for setting_name in settings.keys(): assert setting_name in result.output @@ -102,7 +102,6 @@ def test_cli_config_show_modified_config(fresh_deadline_config): config.set_setting("settings.log_level", "DEBUG") config.set_setting("telemetry.opt_out", "True") config.set_setting("telemetry.identifier", "user-id-123abc-456def") - config.set_setting("settings.list_object_threshold", "200") config.set_setting("settings.multipart_upload_chunk_size", "10000000") config.set_setting("settings.multipart_upload_max_workers", "16") config.set_setting("settings.small_file_threshold_multiplier", "15") diff --git a/test/unit/deadline_job_attachments/caches/__init__.py b/test/unit/deadline_job_attachments/caches/__init__.py new file mode 100644 index 00000000..8d929cc8 --- /dev/null +++ b/test/unit/deadline_job_attachments/caches/__init__.py @@ -0,0 +1 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. diff --git a/test/unit/deadline_job_attachments/caches/test_caches.py b/test/unit/deadline_job_attachments/caches/test_caches.py new file mode 100644 index 00000000..0b3130cc --- /dev/null +++ b/test/unit/deadline_job_attachments/caches/test_caches.py @@ -0,0 +1,201 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+
+import os
+from datetime import datetime
+from sqlite3 import OperationalError
+from unittest.mock import patch
+
+import pytest
+
+import deadline
+from deadline.job_attachments.asset_manifests import HashAlgorithm
+from deadline.job_attachments.exceptions import JobAttachmentsError
+from deadline.job_attachments.caches import (
+    CacheDB,
+    HashCache,
+    HashCacheEntry,
+    S3CheckCache,
+    S3CheckCacheEntry,
+)
+
+
+class TestCacheDB:
+    """
+    Tests for the CacheDB abstract base class
+    """
+
+    def test_get_default_cache_db_file_dir_env_var_path_exists(self, tmpdir):
+        """
+        Tests that when an environment variable exists, it uses that path for the cache database directory
+        """
+        expected_path = tmpdir.join(".deadline").join("job_attachments")
+        with patch("os.environ.get", side_effect=[tmpdir]):
+            assert CacheDB.get_default_cache_db_file_dir() == expected_path
+
+    def test_init_empty_path_no_default_throws_error(self):
+        """
+        Tests that an error is raised when no cache file path is given and no default directory can be found.
+        """
+        os.environ.pop("APPDATA", None)
+        os.environ.pop("HOME", None)
+        os.environ.pop("XDG_CONFIG_HOME", None)
+
+        with pytest.raises(JobAttachmentsError):
+            CacheDB("name", "table", "query")
+
+    def test_enter_bad_cache_path_throws_error(self, tmpdir):
+        """
+        Tests that an error is raised when a bad path is provided to the CacheDB constructor
+        """
+        with pytest.raises(JobAttachmentsError) as err:
+            cdb = CacheDB("name", "table", "query", tmpdir)
+            cdb.cache_dir = "/some/bad/path"
+            with cdb:
+                assert (
+                    False
+                ), "Context manager should throw a JobAttachmentsError, this assert should not be reached"
+        assert isinstance(err.value.__cause__, OperationalError)
+
+    @pytest.mark.parametrize(
+        "cache_name, table_name, create_query",
+        [
+            pytest.param("", "table", "query"),
+            pytest.param("name", "", "query"),
+            pytest.param("name", "table", ""),
+        ],
+    )
+    def test_init_throws_error_on_empty_strings(self, cache_name, table_name, create_query):
+        """Tests that a JobAttachmentsError is raised if init args are empty"""
+        with pytest.raises(JobAttachmentsError):
+            CacheDB(cache_name, table_name, create_query)
+
+
+class TestHashCache:
+    """
+    Tests for the local Hash Cache
+    """
+
+    def test_init_empty_path(self, tmpdir):
+        """
+        Tests that when no cache file path is given, the default is used.
+ """ + with patch( + f"{deadline.__package__}.job_attachments.caches.CacheDB.get_default_cache_db_file_dir", + side_effect=[tmpdir], + ): + hc = HashCache() + assert hc.cache_dir == tmpdir.join(f"{HashCache.CACHE_NAME}.db") + + def test_get_entry_returns_valid_entry(self, tmpdir): + """ + Tests that a valid entry is returned when it exists in the cache already + """ + # GIVEN + cache_dir = tmpdir.mkdir("cache") + expected_entry = HashCacheEntry( + file_path="file", + hash_algorithm=HashAlgorithm.XXH128, + file_hash="hash", + last_modified_time="1234.5678", + ) + + # WHEN + with HashCache(cache_dir) as hc: + hc.put_entry(expected_entry) + actual_entry = hc.get_entry("file", HashAlgorithm.XXH128) + + # THEN + assert actual_entry == expected_entry + + def test_enter_sqlite_import_error(self, tmpdir): + """ + Tests that the cache doesn't throw errors when the SQLite module can't be found + """ + with patch.dict("sys.modules", {"sqlite3": None}): + new_dir = tmpdir.join("does_not_exist") + hc = HashCache(new_dir) + assert not os.path.exists(new_dir) + with hc: + assert hc.get_entry("/no/file", HashAlgorithm.XXH128) is None + hc.put_entry( + HashCacheEntry( + file_path="/no/file", + hash_algorithm=HashAlgorithm.XXH128, + file_hash="abc", + last_modified_time="1234.56", + ) + ) + assert hc.get_entry("/no/file", HashAlgorithm.XXH128) is None + + +class TestS3CheckCache: + """ + Tests for the local S3 Check Hash + """ + + def test_init_empty_path(self, tmpdir): + """ + Tests that when no cache file path is given, the default is used. + """ + with patch( + f"{deadline.__package__}.job_attachments.caches.CacheDB.get_default_cache_db_file_dir", + side_effect=[tmpdir], + ): + s3c = S3CheckCache() + assert s3c.cache_dir == tmpdir.join(f"{S3CheckCache.CACHE_NAME}.db") + + def test_get_entry_returns_valid_entry(self, tmpdir): + """ + Tests that a valid entry is returned when it exists in the cache already + """ + # GIVEN + cache_dir = tmpdir.mkdir("cache") + expected_entry = S3CheckCacheEntry( + s3_key="bucket/Data/somehash", + last_seen_time=str(datetime.now().timestamp()), + ) + + # WHEN + with S3CheckCache(cache_dir) as s3c: + s3c.put_entry(expected_entry) + actual_entry = s3c.get_entry("bucket/Data/somehash") + + # THEN + assert actual_entry == expected_entry + + def test_get_entry_returns_none_with_expired_entry(self, tmpdir): + """ + Tests that nothing is returned when an existing entry is expired + """ + # GIVEN + cache_dir = tmpdir.mkdir("cache") + expected_entry = S3CheckCacheEntry( + s3_key="bucket/Data/somehash", + last_seen_time="123.456", # a looong time ago + ) + + # WHEN + with S3CheckCache(cache_dir) as s3c: + s3c.put_entry(expected_entry) + actual_entry = s3c.get_entry("bucket/Data/somehash") + + # THEN + assert actual_entry is None + + def test_enter_sqlite_import_error(self, tmpdir): + """ + Tests that the cache doesn't throw errors when the SQLite module can't be found + """ + with patch.dict("sys.modules", {"sqlite3": None}): + new_dir = tmpdir.join("does_not_exist") + s3c = S3CheckCache(new_dir) + assert not os.path.exists(new_dir) + with s3c: + assert s3c.get_entry("bucket/Data/somehash") is None + s3c.put_entry( + S3CheckCacheEntry( + s3_key="bucket/Data/somehash", + last_seen_time=str(datetime.now().timestamp()), + ) + ) + assert s3c.get_entry("bucket/Data/somehash") is None diff --git a/test/unit/deadline_job_attachments/test_asset_sync.py b/test/unit/deadline_job_attachments/test_asset_sync.py index a5ac5828..09cf0f77 100644 --- 
a/test/unit/deadline_job_attachments/test_asset_sync.py +++ b/test/unit/deadline_job_attachments/test_asset_sync.py @@ -19,7 +19,10 @@ from deadline.job_attachments.asset_sync import AssetSync from deadline.job_attachments.os_file_permission import PosixFileSystemPermissionSettings from deadline.job_attachments.download import _progress_logger -from deadline.job_attachments.exceptions import Fus3ExecutableMissingError +from deadline.job_attachments.exceptions import ( + Fus3ExecutableMissingError, + JobAttachmentsS3ClientError, +) from deadline.job_attachments.models import ( Attachments, Job, @@ -274,6 +277,71 @@ def test_sync_inputs_successful( } ] + @pytest.mark.parametrize( + ("job_fixture_name"), + [ + ("default_job"), + ], + ) + @pytest.mark.parametrize( + ("s3_settings_fixture_name"), + [ + ("default_job_attachment_s3_settings"), + ], + ) + def test_sync_inputs_404_error( + self, + tmp_path: Path, + default_queue: Queue, + job_fixture_name: str, + s3_settings_fixture_name: str, + request: pytest.FixtureRequest, + ): + """Asserts that a specific error message is raised when getting 404 errors synching inputs""" + # GIVEN + download_exception = JobAttachmentsS3ClientError( + action="get-object", + status_code=404, + bucket_name="test bucket", + key_or_prefix="test-key.xxh128", + message="File not found", + ) + job: Job = request.getfixturevalue(job_fixture_name) + s3_settings: JobAttachmentS3Settings = request.getfixturevalue(s3_settings_fixture_name) + default_queue.jobAttachmentSettings = s3_settings + session_dir = str(tmp_path) + dest_dir = "assetroot-27bggh78dd2b568ab123" + local_root = str(Path(session_dir) / dest_dir) + assert job.attachments + + # WHEN + with patch( + f"{deadline.__package__}.job_attachments.asset_sync.get_manifest_from_s3", + side_effect=[f"{local_root}/manifest.json"], + ), patch("builtins.open", mock_open(read_data="test_manifest_file")), patch( + f"{deadline.__package__}.job_attachments.asset_sync.decode_manifest", + side_effect=["test_manifest_data"], + ), patch( + f"{deadline.__package__}.job_attachments.asset_sync._get_unique_dest_dir_name", + side_effect=[dest_dir], + ), patch( + f"{deadline.__package__}.job_attachments.asset_sync.download_files_from_manifests", + side_effect=download_exception, + ): + with pytest.raises(JobAttachmentsS3ClientError) as excinfo: + self.default_asset_sync.sync_inputs( + s3_settings, + job.attachments, + default_queue.queueId, + job.jobId, + tmp_path, + ) + + # THEN + assert "usually located in the home directory (~/.deadline/cache/s3_check_cache.db)" in str( + excinfo + ) + @pytest.mark.parametrize( ("s3_settings_fixture_name"), [ diff --git a/test/unit/deadline_job_attachments/test_hash_cache.py b/test/unit/deadline_job_attachments/test_hash_cache.py deleted file mode 100644 index f0585feb..00000000 --- a/test/unit/deadline_job_attachments/test_hash_cache.py +++ /dev/null @@ -1,96 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
- -import os -from sqlite3 import OperationalError -from unittest.mock import patch - -import pytest - -import deadline -from deadline.job_attachments.asset_manifests import HashAlgorithm -from deadline.job_attachments.exceptions import JobAttachmentsError -from deadline.job_attachments.hash_cache import CACHE_FILE_NAME, HashCache -from deadline.job_attachments.models import HashCacheEntry - - -class TestHashCache: - """ - Tests for the local Hash Cache - """ - - def test_init_empty_path(self, tmpdir): - """ - Tests that when no cache file path is given, the default is used. - """ - with patch( - f"{deadline.__package__}.job_attachments.hash_cache._get_default_hash_cache_db_file_dir", - side_effect=[tmpdir], - ): - hc = HashCache() - assert hc.cache_dir == tmpdir.join(CACHE_FILE_NAME) - - def test_init_empty_path_no_default_throws_error(self): - """ - Tests that when no cache file path is given, the default is used. - """ - os.environ.pop("APPDATA", None) - os.environ.pop("HOME", None) - os.environ.pop("XDG_CONFIG_HOME", None) - - with pytest.raises(JobAttachmentsError): - HashCache() - assert False, "Constructor should raise an error, this assert should not be reached" - - def test_enter_bad_cache_path_throws_error(self, tmpdir): - """ - Tests that an error is raised when a bad path is provided to the HashCache constructor - """ - with pytest.raises(JobAttachmentsError) as err: - hc = HashCache(tmpdir) - hc.cache_dir = "/some/bad/path" - with hc: - assert ( - False - ), "Context manager should throw an execption, this assert should not be reached" - assert isinstance(err.value.__cause__, OperationalError) - - def test_get_entry_returns_valid_entry(self, tmpdir): - """ - Tests that a valid entry is returned when it exists in the cache already - """ - # GIVEN - cache_dir = tmpdir.mkdir("cache") - expected_entry = HashCacheEntry( - file_path="file", - hash_algorithm=HashAlgorithm.XXH128, - file_hash="hash", - last_modified_time="1234.5678", - ) - - # WHEN - with HashCache(cache_dir) as hc: - hc.put_entry(expected_entry) - actual_entry = hc.get_entry("file", HashAlgorithm.XXH128) - - # THEN - assert actual_entry == expected_entry - - def test_enter_sqlite_import_error(self, tmpdir): - """ - Tests that the hash cache doesn't throw errors when the SQLite module can't be found - """ - with patch.dict("sys.modules", {"sqlite3": None}): - new_dir = tmpdir.join("does_not_exist") - hc = HashCache(new_dir) - assert not os.path.exists(new_dir) - with hc: - assert hc.get_entry("/no/file", HashAlgorithm.XXH128) is None - hc.put_entry( - HashCacheEntry( - file_path="/no/file", - hash_algorithm=HashAlgorithm.XXH128, - file_hash="abc", - last_modified_time="1234.56", - ) - ) - assert hc.get_entry("/no/file", HashAlgorithm.XXH128) is None diff --git a/test/unit/deadline_job_attachments/test_upload.py b/test/unit/deadline_job_attachments/test_upload.py index ecec946d..5a2a60d0 100644 --- a/test/unit/deadline_job_attachments/test_upload.py +++ b/test/unit/deadline_job_attachments/test_upload.py @@ -29,6 +29,7 @@ HashAlgorithm, ManifestVersion, ) +from deadline.job_attachments.caches import HashCacheEntry, S3CheckCacheEntry from deadline.job_attachments.exceptions import ( AssetSyncError, JobAttachmentsS3ClientError, @@ -41,7 +42,6 @@ FileSystemLocation, FileSystemLocationType, ManifestProperties, - HashCacheEntry, JobAttachmentS3Settings, OperatingSystemFamily, PathFormat, @@ -194,6 +194,7 @@ def test_asset_management( (upload_summary_statistics, attachments) = asset_manager.upload_assets( 
manifests=asset_root_manifests, on_uploading_assets=mock_on_uploading_assets, + s3_check_cache_dir=str(cache_dir), ) # Then @@ -361,6 +362,7 @@ def test_asset_management_windows_multi_root( (upload_summary_statistics, attachments) = asset_manager.upload_assets( manifests=asset_root_manifests, on_uploading_assets=mock_on_uploading_assets, + s3_check_cache_dir=cache_dir, ) # Then @@ -448,7 +450,7 @@ def test_asset_management_windows_multi_root( @mock_sts @pytest.mark.parametrize( - "num_additional_input_files", + "num_input_files", [ 1, 100, @@ -470,7 +472,7 @@ def test_asset_management_many_inputs( assert_expected_files_on_s3, caplog, manifest_version: ManifestVersion, - num_additional_input_files: int, + num_input_files: int, ): """ Test that the correct files get uploaded to S3 and the asset manifest @@ -486,10 +488,6 @@ def test_asset_management_many_inputs( asset_manifest_version=manifest_version, ) - num_input_files = ( - asset_manager.asset_uploader.list_object_threshold + num_additional_input_files - ) - with patch( f"{deadline.__package__}.job_attachments.upload.PathFormat.get_host_path_format", return_value=PathFormat.POSIX, @@ -540,6 +538,7 @@ def test_asset_management_many_inputs( (upload_summary_statistics, attachments) = asset_manager.upload_assets( manifests=asset_root_manifests, on_uploading_assets=mock_on_uploading_assets, + s3_check_cache_dir=cache_dir, ) # Then @@ -612,7 +611,7 @@ def test_asset_management_many_inputs( @mock_sts @pytest.mark.parametrize( - "num_additional_input_files", + "num_input_files", [ 1, 100, @@ -631,7 +630,7 @@ def test_asset_management_many_inputs_with_same_hash( farm_id, queue_id, manifest_version: ManifestVersion, - num_additional_input_files: int, + num_input_files: int, ): """ Test that the asset management can handle many input files with the same hash. 
@@ -647,10 +646,6 @@ def test_asset_management_many_inputs_with_same_hash( asset_manifest_version=manifest_version, ) - num_input_files = ( - asset_manager.asset_uploader.list_object_threshold + num_additional_input_files - ) - # Given with patch( f"{deadline.__package__}.job_attachments.upload.PathFormat.get_host_path_format", @@ -661,6 +656,9 @@ def test_asset_management_many_inputs_with_same_hash( ), patch( f"{deadline.__package__}.job_attachments.upload.hash_file", side_effect=lambda *args, **kwargs: "samehash", + ), patch( + f"{deadline.__package__}.job_attachments.upload.NUM_UPLOAD_WORKERS", + 1, # Change the number of thread workers to 1 to get consistent tests ): mock_on_preparing_to_submit = MagicMock(return_value=True) mock_on_uploading_assets = MagicMock(return_value=True) @@ -698,6 +696,7 @@ def test_asset_management_many_inputs_with_same_hash( (upload_summary_statistics, _) = asset_manager.upload_assets( manifests=asset_root_manifests, on_uploading_assets=mock_on_uploading_assets, + s3_check_cache_dir=cache_dir, ) # Then @@ -821,6 +820,7 @@ def mock_hash_file(file_path: str, hash_alg: HashAlgorithm): (upload_summary_statistics, _) = asset_manager.upload_assets( manifests=asset_root_manifests, on_uploading_assets=mock_on_uploading_assets, + s3_check_cache_dir=cache_dir, ) # Then @@ -864,7 +864,7 @@ def mock_hash_file(file_path: str, hash_alg: HashAlgorithm): @mock_sts @pytest.mark.parametrize( - "num_additional_input_files", + "num_input_files", [ 1, 100, @@ -885,7 +885,7 @@ def test_asset_management_no_outputs_large_number_of_inputs_already_uploaded( assert_expected_files_on_s3, caplog, manifest_version: ManifestVersion, - num_additional_input_files: int, + num_input_files: int, ): """ Test the input files that have already been uploaded to S3 are skipped. @@ -898,10 +898,6 @@ def test_asset_management_no_outputs_large_number_of_inputs_already_uploaded( asset_manifest_version=manifest_version, ) - num_input_files = ( - asset_manager.asset_uploader.list_object_threshold + num_additional_input_files - ) - with patch( f"{deadline.__package__}.job_attachments.upload.PathFormat.get_host_path_format", return_value=PathFormat.POSIX, @@ -968,6 +964,7 @@ def test_asset_management_no_outputs_large_number_of_inputs_already_uploaded( (upload_summary_statistics, _) = asset_manager.upload_assets( manifests=asset_root_manifests, on_uploading_assets=mock_on_uploading_assets, + s3_check_cache_dir=cache_dir, ) # Then @@ -1077,6 +1074,7 @@ def test_asset_management_no_inputs( (upload_summary_statistics, attachments) = asset_manager.upload_assets( manifests=asset_root_manifests, on_uploading_assets=mock_on_uploading_assets, + s3_check_cache_dir=cache_dir, ) # Then @@ -1265,7 +1263,6 @@ def test_asset_uploader_constructor(self, fresh_deadline_config): Test that when the asset uploader is created, the instance variables are correctly set. """ uploader = S3AssetUploader() - assert uploader.list_object_threshold == 100 assert uploader.multipart_upload_chunk_size == 8 * (1024**2) assert uploader.multipart_upload_max_workers == 10 assert uploader.small_file_threshold == 20 * 8 * (1024**2) @@ -1276,7 +1273,7 @@ def test_asset_uploader_constructor_with_non_integer_config_settings( """ Tests that when the asset uploader is created with non-integer config settings, an AssetSyncError is raised. 
""" - config.set_setting("settings.list_object_threshold", "!@#$") + config.set_setting("settings.multipart_upload_chunk_size", "!@#$") with pytest.raises(AssetSyncError) as err: _ = S3AssetUploader() assert isinstance(err.value.__cause__, ValueError) @@ -1285,11 +1282,6 @@ def test_asset_uploader_constructor_with_non_integer_config_settings( @pytest.mark.parametrize( "setting_name, invalid_value", [ - pytest.param( - "list_object_threshold", - "0", - id="Invalid list_object_threshold value: 0", - ), pytest.param( "multipart_upload_chunk_size", "-100", @@ -1354,39 +1346,6 @@ def test_file_already_uploaded_bucket_in_different_account(self): "and your AWS IAM Role or User has the 's3:ListBucket' permission for this bucket." ) in str(err.value) - @mock_sts - def test_filter_objects_to_upload_bucket_in_different_account(self): - """ - Test that the appropriate error is raised when checking if a file has already been uploaded, but the bucket - is in an account that is different from the uploader's account. - """ - s3 = boto3.client("s3") - stubber = Stubber(s3) - stubber.add_client_error( - "list_objects_v2", - service_error_code="AccessDenied", - service_message="Access Denied", - http_status_code=403, - ) - - uploader = S3AssetUploader() - - uploader._s3 = s3 - - with stubber: - with pytest.raises(JobAttachmentsS3ClientError) as err: - uploader.filter_objects_to_upload( - self.job_attachment_s3_settings.s3BucketName, "test_prefix", {"test_key"} - ) - assert isinstance(err.value.__cause__, ClientError) - assert ( - err.value.__cause__.response["ResponseMetadata"]["HTTPStatusCode"] == 403 # type: ignore[attr-defined] - ) - assert ( - "Error listing bucket contents in bucket 'test-bucket', Target key or prefix: 'test_prefix', " - "HTTP Status Code: 403, Forbidden or Access denied. " - ) in str(err.value) - @mock_sts def test_upload_bytes_to_s3_bucket_in_different_account(self): """ @@ -1512,6 +1471,45 @@ def test_upload_file_to_s3_bucket_in_different_account(self, tmp_path: Path): ) in str(err.value) assert (f"(Failed to upload {str(file)})") in str(err.value) + @mock_sts + def test_upload_file_to_s3_bucket_has_kms_permissions_error(self, tmp_path: Path): + """ + Test that the appropriate error is raised when uploading files, but the bucket + is encrypted with a KMS key and the user doesn't have access to the key. + """ + s3 = boto3.client("s3") + stubber = Stubber(s3) + + # This is the error that's surfaced when a bucket is in a different account than expected. 
+        stubber.add_client_error(
+            "create_multipart_upload",
+            service_error_code="AccessDenied",
+            service_message="An error occurred (AccessDenied) when calling the PutObject operation: User: arn:aws:sts:::assumed-role/ is not authorized to perform: kms:GenerateDataKey on resource: arn:aws:kms:us-west-2::key/ because no identity-based policy allows the kms:GenerateDataKey action",
+            http_status_code=403,
+        )
+
+        uploader = S3AssetUploader()
+
+        uploader._s3 = s3
+
+        file = tmp_path / "test_file"
+        file.write_text("")
+
+        with stubber:
+            with pytest.raises(JobAttachmentsS3ClientError) as err:
+                uploader.upload_file_to_s3(
+                    file, self.job_attachment_s3_settings.s3BucketName, "test_key"
+                )
+        assert isinstance(err.value.__cause__, ClientError)
+        assert (
+            err.value.__cause__.response["ResponseMetadata"]["HTTPStatusCode"] == 403  # type: ignore[attr-defined]
+        )
+        assert (
+            "If a customer-managed KMS key is set, confirm that your AWS IAM Role or "
+            "User has the 'kms:GenerateDataKey' and 'kms:DescribeKey' permissions for the key used to encrypt the bucket."
+        ) in str(err.value)
+        assert (f"(Failed to upload {str(file)})") in str(err.value)
+
     @mock_sts
     def test_upload_file_to_s3_empty_file(self, tmp_path: Path):
         """
@@ -1678,6 +1676,7 @@ def test_asset_management_input_not_exists(self, farm_id, queue_id, tmpdir, capl
             (upload_summary_statistics, _) = asset_manager.upload_assets(
                 manifests=asset_root_manifests,
                 on_uploading_assets=mock_on_uploading_assets,
+                s3_check_cache_dir=cache_dir,
             )
 
         # Then
@@ -1801,6 +1800,7 @@ def test_manage_assets_with_symlinks(
             (upload_summary_statistics, attachments) = asset_manager.upload_assets(
                 manifests=asset_root_manifests,
                 on_uploading_assets=mock_on_uploading_assets,
+                s3_check_cache_dir=str(cache_dir),
             )
 
         # THEN
@@ -2333,6 +2333,116 @@ def test_separate_files_by_size(
         )
         assert actual_queues == expected_queues
 
+    @pytest.mark.parametrize(
+        "manifest_version",
+        [
+            ManifestVersion.v2023_03_03,
+        ],
+    )
+    def test_upload_object_to_cas_skips_upload_with_cache(
+        self, tmpdir, farm_id, queue_id, manifest_version, default_job_attachment_s3_settings
+    ):
+        """
+        Tests that objects are not uploaded to S3 if there is a corresponding entry in the S3CheckCache
+        """
+        # Given
+        asset_root = tmpdir.mkdir("test-root")
+        test_file = asset_root.join("test-file.txt")
+        test_file.write("stuff")
+        asset_manager = S3AssetManager(
+            farm_id=farm_id,
+            queue_id=queue_id,
+            job_attachment_settings=self.job_attachment_s3_settings,
+            asset_manifest_version=manifest_version,
+        )
+        s3_key = f"{default_job_attachment_s3_settings.s3BucketName}/prefix/test-hash"
+        test_entry = S3CheckCacheEntry(s3_key, "123.45")
+        s3_cache = MagicMock()
+        s3_cache.get_entry.return_value = test_entry
+
+        # When
+        with patch.object(
+            asset_manager.asset_uploader,
+            "_get_current_timestamp",
+            side_effect=["345.67"],
+        ):
+            (is_uploaded, file_size) = asset_manager.asset_uploader.upload_object_to_cas(
+                file=BaseManifestPath(path="test-file.txt", hash="test-hash", size=5, mtime=1),
+                hash_algorithm=HashAlgorithm.XXH128,
+                s3_bucket=default_job_attachment_s3_settings.s3BucketName,
+                source_root=Path(asset_root),
+                s3_cas_prefix="prefix",
+                s3_check_cache=s3_cache,
+            )
+
+        # Then
+        assert not is_uploaded
+        assert file_size == 5
+        s3_cache.put_entry.assert_not_called()
+
+    @pytest.mark.parametrize(
+        "manifest_version",
+        [
+            ManifestVersion.v2023_03_03,
+        ],
+    )
+    def test_upload_object_to_cas_adds_cache_entry(
+        self,
+        tmpdir,
+        farm_id,
+        queue_id,
+        manifest_version,
+        default_job_attachment_s3_settings,
+        assert_expected_files_on_s3,
+    ):
+        """
+        Tests that when an object is added to the CAS, an S3 cache entry is added.
+        """
+        # Given
+        asset_root = tmpdir.mkdir("test-root")
+        test_file = asset_root.join("test-file.txt")
+        test_file.write("stuff")
+        asset_manager = S3AssetManager(
+            farm_id=farm_id,
+            queue_id=queue_id,
+            job_attachment_settings=self.job_attachment_s3_settings,
+            asset_manifest_version=manifest_version,
+        )
+        s3_key = f"{default_job_attachment_s3_settings.s3BucketName}/prefix/test-hash"
+        s3_cache = MagicMock()
+        s3_cache.get_entry.return_value = None
+        expected_new_entry = S3CheckCacheEntry(s3_key, "345.67")
+
+        # When
+        with patch.object(
+            asset_manager.asset_uploader,
+            "_get_current_timestamp",
+            side_effect=["345.67"],
+        ):
+            (is_uploaded, file_size) = asset_manager.asset_uploader.upload_object_to_cas(
+                file=BaseManifestPath(path="test-file.txt", hash="test-hash", size=5, mtime=1),
+                hash_algorithm=HashAlgorithm.XXH128,
+                s3_bucket=default_job_attachment_s3_settings.s3BucketName,
+                source_root=Path(asset_root),
+                s3_cas_prefix="prefix",
+                s3_check_cache=s3_cache,
+            )
+
+        # Then
+        assert is_uploaded
+        assert file_size == 5
+        s3_cache.put_entry.assert_called_once_with(expected_new_entry)
+
+        s3 = boto3.Session(region_name="us-west-2").resource(
+            "s3"
+        )  # pylint: disable=invalid-name
+        bucket = s3.Bucket(self.job_attachment_s3_settings.s3BucketName)
+
+        assert_expected_files_on_s3(
+            bucket,
+            expected_files={"prefix/test-hash"},
+        )
+
 
 def assert_progress_report_last_callback(
     num_input_files: int,
diff --git a/test/unit/deadline_job_attachments/test_utils.py b/test/unit/deadline_job_attachments/test_utils.py
index b86e22f5..b33fb338 100644
--- a/test/unit/deadline_job_attachments/test_utils.py
+++ b/test/unit/deadline_job_attachments/test_utils.py
@@ -2,25 +2,15 @@
 
 from pathlib import Path
 import sys
-from unittest.mock import patch
 
 import pytest
 
 from deadline.job_attachments._utils import (
-    _get_default_hash_cache_db_file_dir,
     _is_relative_to,
 )
 
 
 class TestUtils:
-    def test_get_default_hash_cache_db_file_dir_env_var_path_exists(self, tmpdir):
-        """
-        Tests that when an environment variable exists, it uses that path for the hash cache
-        """
-        expected_path = tmpdir.join(".deadline").join("job_attachments")
-        with patch("os.environ.get", side_effect=[tmpdir]):
-            assert _get_default_hash_cache_db_file_dir() == expected_path
-
     @pytest.mark.skipif(
         sys.platform == "win32",
         reason="This test is for paths in POSIX path format and will be skipped on Windows.",