Add no-overwrite option to s3 cp and s3 mv commands
Illia Batozskyi committed Apr 14, 2021
1 parent 4172ada commit 2bda3e5
Showing 5 changed files with 310 additions and 51 deletions.
5 changes: 5 additions & 0 deletions .changes/next-release/feature-s3-1341.json
@@ -0,0 +1,5 @@
{
"type": "feature",
"category": "``s3``",
"description": "Add ``--no-overwrite`` option to ``cp`` and ``mv`` commands"
}
143 changes: 110 additions & 33 deletions awscli/customizations/s3/filters.py
@@ -13,45 +13,17 @@
import logging
import fnmatch
import os
import sys
from concurrent.futures import ThreadPoolExecutor
from functools import reduce

from botocore.exceptions import ClientError
from awscli.customizations.s3.utils import split_s3_bucket_key

from awscli.customizations.utils import uni_print

LOG = logging.getLogger(__name__)


def create_filter(parameters):
"""Given the CLI parameters dict, create a Filter object."""
# We need to evaluate all the filters based on the source
# directory.
if parameters['filters']:
cli_filters = parameters['filters']
real_filters = []
for filter_type, filter_pattern in cli_filters:
real_filters.append((filter_type.lstrip('-'),
filter_pattern))
source_location = parameters['src']
if source_location.startswith('s3://'):
# This gives us (bucket, keyname) and we want
# the bucket to be the root dir.
src_rootdir = _get_s3_root(source_location,
parameters['dir_op'])
else:
src_rootdir = _get_local_root(parameters['src'], parameters['dir_op'])

destination_location = parameters['dest']
if destination_location.startswith('s3://'):
dst_rootdir = _get_s3_root(parameters['dest'],
parameters['dir_op'])
else:
dst_rootdir = _get_local_root(parameters['dest'],
parameters['dir_op'])

return Filter(real_filters, src_rootdir, dst_rootdir)
else:
return Filter({}, None, None)


def _get_s3_root(source_location, dir_op):
# Obtain the bucket and the key.
bucket, key = split_s3_bucket_key(source_location)
@@ -73,6 +45,64 @@ def _get_local_root(source_location, dir_op):
return rootdir


class FilterRunner:
def __init__(self, parameters, client, out_file=None):
self._parameters = parameters
self._client = client
self._out_file = out_file
if self._out_file is None:
self._out_file = sys.stdout
self.filters = self._create_filters()

def call(self, file_infos):
return reduce(
lambda files, filtering: filtering.call(files),
self.filters, file_infos)

def _create_filters(self):
return [self._create_filter(), self._create_no_overwrite_filter()]

def _create_filter(self):
"""Given the CLI parameters dict, create a Filter object."""
# We need to evaluate all the filters based on the source
# directory.
parameters = self._parameters
if parameters.get('filters'):
cli_filters = parameters['filters']
real_filters = []
for filter_type, filter_pattern in cli_filters:
real_filters.append((filter_type.lstrip('-'),
filter_pattern))
source_location = parameters['src']
if source_location.startswith('s3://'):
# This gives us (bucket, keyname) and we want
# the bucket to be the root dir.
src_rootdir = _get_s3_root(source_location,
parameters['dir_op'])
else:
src_rootdir = _get_local_root(parameters['src'],
parameters['dir_op'])

destination_location = parameters['dest']
if destination_location.startswith('s3://'):
dst_rootdir = _get_s3_root(parameters['dest'],
parameters['dir_op'])
else:
dst_rootdir = _get_local_root(parameters['dest'],
parameters['dir_op'])

return Filter(real_filters, src_rootdir, dst_rootdir)
else:
return Filter({}, None, None)

def _create_no_overwrite_filter(self):
if self._parameters.get('no_overwrite', False):
return NoOverwriteFilter(
self._client, self._parameters, self._out_file)
else:
return Filter({}, None, None)
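
# A minimal sketch (not part of this commit) of the chaining in
# FilterRunner.call(): functools.reduce feeds each filter's call() output into
# the next, so file infos pass through the include/exclude Filter first and
# then the NoOverwriteFilter. DropDotfiles and DropExisting are hypothetical
# stand-ins that only illustrate the reduce pattern.
from functools import reduce

class DropDotfiles:
    """Hypothetical stand-in for the include/exclude Filter."""
    def call(self, file_infos):
        for name in file_infos:
            if not name.startswith('.'):
                yield name

class DropExisting:
    """Hypothetical stand-in for NoOverwriteFilter."""
    def __init__(self, existing):
        self._existing = existing

    def call(self, file_infos):
        for name in file_infos:
            if name not in self._existing:
                yield name

# Each filter lazily consumes the previous filter's output, as in FilterRunner.
chained = reduce(lambda files, f: f.call(files),
                 [DropDotfiles(), DropExisting(existing={'foo.txt'})],
                 ['.git', 'foo.txt', 'bar.py'])
print(list(chained))  # ['bar.py']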


class Filter(object):
"""
This is a universal exclude/include filter.
@@ -151,3 +181,50 @@ def _match_pattern(self, pattern, file_info):
LOG.debug("%s did not match %s filter: %s",
file_path, pattern_type, path_pattern)
return file_status


class NoOverwriteFilter:
FILE_EXIST_FORMAT = 'Object/file %s already exists.\n'
FILE_SKIPPED_FORMAT = (
'Object %s skipped because of such head-object response "%s".\n'
)

def __init__(self, client, cli_params, out_file=None):
self._client = client
self._object_checker = self._get_object_checker(cli_params['dest'])
self._out_file = out_file
if self._out_file is None:
self._out_file = sys.stdout

def call(self, file_infos):
with ThreadPoolExecutor(max_workers=7) as executor:
yield from filter(
bool, executor.map(self._object_checker, file_infos))

def _get_object_checker(self, dest):
if dest.startswith('s3://'):
return self._check_s3_object_exists
return self._check_local_object_exists

def _check_s3_object_exists(self, fileinfo):
bucket, key = split_s3_bucket_key(fileinfo.dest)
try:
self._client.head_object(Bucket=bucket, Key=key)
self._print_to_out_file(self.FILE_EXIST_FORMAT % fileinfo.dest)
return False
except ClientError as e:
if e.response['Error']['Code'] == '404':
return fileinfo
self._print_to_out_file(
self.FILE_SKIPPED_FORMAT % (fileinfo.dest, e)
)
return False

def _check_local_object_exists(self, fileinfo):
if os.path.exists(fileinfo.dest):
self._print_to_out_file(self.FILE_EXIST_FORMAT % fileinfo.dest)
return False
return fileinfo

def _print_to_out_file(self, statement):
uni_print(statement, self._out_file)
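
A rough sketch, not the committed awscli code, of the skip logic that
NoOverwriteFilter.call() applies to S3 destinations: each destination is
checked with head_object, existing objects map to False, and filter(bool, ...)
drops them. StubClient, surviving_dests, and the example keys below are
hypothetical; only botocore's ClientError is a real dependency.

from concurrent.futures import ThreadPoolExecutor
from types import SimpleNamespace

from botocore.exceptions import ClientError

class StubClient:
    """Fake S3 client: head_object raises a 404 ClientError for missing keys."""
    def __init__(self, existing_keys):
        self._existing = existing_keys

    def head_object(self, Bucket, Key):
        if Key in self._existing:
            return {}
        raise ClientError(
            {'Error': {'Code': '404', 'Message': 'Not Found'}}, 'HeadObject')

def surviving_dests(client, file_infos):
    def check(info):
        bucket, _, key = info.dest[len('s3://'):].partition('/')
        try:
            client.head_object(Bucket=bucket, Key=key)
            return False          # destination exists -> skip it
        except ClientError as e:
            if e.response['Error']['Code'] == '404':
                return info       # destination missing -> keep it
            return False          # any other error -> skip, as in the commit
    with ThreadPoolExecutor(max_workers=7) as executor:
        yield from filter(bool, executor.map(check, file_infos))

infos = [SimpleNamespace(dest='s3://bucket/foo.txt'),
         SimpleNamespace(dest='s3://bucket/bar.py')]
print([i.dest for i in surviving_dests(StubClient({'foo.txt'}), infos)])
# ['s3://bucket/bar.py']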
29 changes: 20 additions & 9 deletions awscli/customizations/s3/subcommands.py
@@ -26,7 +26,7 @@
from awscli.customizations.s3.fileformat import FileFormat
from awscli.customizations.s3.filegenerator import FileGenerator
from awscli.customizations.s3.fileinfo import FileInfo
from awscli.customizations.s3.filters import create_filter
from awscli.customizations.s3.filters import FilterRunner
from awscli.customizations.s3.s3handler import S3TransferHandlerFactory
from awscli.customizations.s3.utils import find_bucket_key, AppendFilter, \
find_dest_path_comp_key, human_readable_size, \
@@ -456,6 +456,14 @@
)
}


NO_OVERWRITE = {
'name': 'no-overwrite', 'action': 'store_true',
'help_text': ('Skip object if object with the same name '
'already exists in the destination')
}


TRANSFER_ARGS = [DRYRUN, QUIET, INCLUDE, EXCLUDE, ACL,
FOLLOW_SYMLINKS, NO_FOLLOW_SYMLINKS, NO_GUESS_MIME_TYPE,
SSE, SSE_C, SSE_C_KEY, SSE_KMS_KEY_ID, SSE_C_COPY_SOURCE,
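
# A plain-argparse analogy (not the CLI's actual arg-table machinery) for how
# the NO_OVERWRITE entry above behaves: 'store_true' means the flag takes no
# value and defaults to False, and '--no-overwrite' surfaces as the parameter
# name 'no_overwrite' that FilterRunner later reads.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--no-overwrite', action='store_true',
                    help='Skip object if object with the same name '
                         'already exists in the destination')

print(parser.parse_args([]).no_overwrite)                  # False
print(parser.parse_args(['--no-overwrite']).no_overwrite)  # True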
@@ -757,7 +765,7 @@ class CpCommand(S3TransferCommand):
ARG_TABLE = [{'name': 'paths', 'nargs': 2, 'positional_arg': True,
'synopsis': USAGE}] + TRANSFER_ARGS + \
[METADATA, COPY_PROPS, METADATA_DIRECTIVE, EXPECTED_SIZE,
RECURSIVE]
RECURSIVE, NO_OVERWRITE]


class MvCommand(S3TransferCommand):
@@ -768,7 +776,8 @@ class MvCommand(S3TransferCommand):
"or <S3Uri> <S3Uri>"
ARG_TABLE = [{'name': 'paths', 'nargs': 2, 'positional_arg': True,
'synopsis': USAGE}] + TRANSFER_ARGS +\
[METADATA, COPY_PROPS, METADATA_DIRECTIVE, RECURSIVE]
[METADATA, COPY_PROPS, METADATA_DIRECTIVE,
RECURSIVE, NO_OVERWRITE]


class RmCommand(S3TransferCommand):
@@ -937,7 +946,8 @@ def create_instructions(self):
"""
if self.needs_filegenerator():
self.instructions.append('file_generator')
if self.parameters.get('filters'):
if self.parameters.get('filters') \
or self.parameters.get('no_overwrite'):
self.instructions.append('filters')
if self.cmd == 'sync':
self.instructions.append('comparator')
@@ -1056,34 +1066,35 @@ def run(self):
sync_strategies = self.choose_sync_strategies()

command_dict = {}
filter_runner = FilterRunner(self.parameters, self._client)
if self.cmd == 'sync':
command_dict = {'setup': [files, rev_files],
'file_generator': [file_generator,
rev_generator],
'filters': [create_filter(self.parameters),
create_filter(self.parameters)],
'filters': [filter_runner, filter_runner],
'comparator': [Comparator(**sync_strategies)],
'file_info_builder': [file_info_builder],
's3_handler': [s3_transfer_handler]}
elif self.cmd == 'cp' and self.parameters['is_stream']:
command_dict = {'setup': [stream_file_info],
'filters': [filter_runner],
's3_handler': [s3_transfer_handler]}
elif self.cmd == 'cp':
command_dict = {'setup': [files],
'file_generator': [file_generator],
'filters': [create_filter(self.parameters)],
'filters': [filter_runner],
'file_info_builder': [file_info_builder],
's3_handler': [s3_transfer_handler]}
elif self.cmd == 'rm':
command_dict = {'setup': [files],
'file_generator': [file_generator],
'filters': [create_filter(self.parameters)],
'filters': [filter_runner],
'file_info_builder': [file_info_builder],
's3_handler': [s3_transfer_handler]}
elif self.cmd == 'mv':
command_dict = {'setup': [files],
'file_generator': [file_generator],
'filters': [create_filter(self.parameters)],
'filters': [filter_runner],
'file_info_builder': [file_info_builder],
's3_handler': [s3_transfer_handler]}

73 changes: 73 additions & 0 deletions tests/integration/customizations/s3/test_plugin.py
@@ -408,6 +408,9 @@ def assert_no_errors(self, p):
assert 'client error' not in p.stderr
assert 'server error' not in p.stderr

def fail(self, msg=None):
raise Exception(msg)


class TestMoveCommand(BaseS3IntegrationTest):
def test_mv_local_to_s3(self, files, s3_utils, shared_bucket):
@@ -1855,6 +1858,76 @@ def test_filter_no_resync(self, files, s3_utils, shared_bucket):
self.assert_no_files_would_be_uploaded(p)


class TestNoOverwriteFilter(BaseS3IntegrationTest):
def assert_skipped(self, p, files):
for filename in files:
assert '%s already exists' % filename in p.stdout

def assert_not_skipped(self, p, files):
for filename in files:
assert '%s already exists' % filename not in p.stdout

def test_not_overwrite_s3_objects(self, files, s3_utils, shared_bucket):
files.create_file('foo.txt', 'contents')
p = aws("s3 cp %s s3://%s/ --recursive" %
(files.rootdir, shared_bucket))
assert s3_utils.key_exists(shared_bucket, key_name='foo.txt')
self.assert_no_errors(p)
files.create_file('bar.py', 'contents')
cwd = os.getcwd()
try:
os.chdir(files.rootdir)
p = aws("s3 cp . s3://%s --no-overwrite --recursive" %
shared_bucket)
finally:
os.chdir(cwd)
self.assert_skipped(p, ['foo.txt'])
self.assert_not_skipped(p, ['bar.py'])
self.assert_no_errors(p)
assert s3_utils.key_exists(shared_bucket, key_name='bar.py')

def test_not_overwrite_local_objects(self, files, s3_utils, shared_bucket):
files.create_file('foo.txt', 'contents')
second = files.create_file('bar.py', 'contents')
p = aws("s3 cp %s s3://%s/ --recursive" %
(files.rootdir, shared_bucket))
assert s3_utils.key_exists(shared_bucket, key_name='foo.txt')
assert s3_utils.key_exists(shared_bucket, key_name='bar.py')
self.assert_no_errors(p)
os.remove(second)
cwd = os.getcwd()
try:
os.chdir(files.rootdir)
p = aws("s3 cp s3://%s . --no-overwrite --recursive" %
shared_bucket)
finally:
os.chdir(cwd)
self.assert_skipped(p, ['foo.txt'])
self.assert_not_skipped(p, ['bar.py'])
self.assert_no_errors(p)

def test_combine_not_overwrite_and_exclude_filter(
self, files, s3_utils, shared_bucket):
files.create_file('foo.txt', 'contents')
files.create_file('bar.py', 'contents')
s3_utils.put_object(shared_bucket, key_name='temp/test')
p = aws("s3 cp %s s3://%s/ --recursive" %
(files.rootdir, shared_bucket))
self.assert_no_errors(p)
assert s3_utils.key_exists(shared_bucket, key_name='foo.txt')
assert s3_utils.key_exists(shared_bucket, key_name='bar.py')
cwd = os.getcwd()
try:
os.chdir(files.rootdir)
p = aws("s3 cp s3://%s . --no-overwrite "
"--recursive --exclude test" % shared_bucket)
finally:
os.chdir(cwd)
self.assert_no_errors(p)
self.assert_skipped(p, ['foo.txt', 'bar.py'])
assert 'copy:' not in p.stdout


class TestFileWithSpaces(BaseS3IntegrationTest):
def test_upload_download_file_with_spaces(self, files, shared_bucket):
filename = files.create_file('with space.txt', 'contents')