Commit
Fix data-corruption issue with s3boto and s3boto3 multipart uploads (jschneier#504)

* fix for files whose size exceeds the buffer size

fixes jschneier#160

* fix(s3boto3): spool buffer file to the end of all uploaded parts after
  each part upload.

  Fixes jschneier#364, a similar issue to jschneier#160 for s3boto3.
  Inspired by vinayinvicible's fix for jschneier#160.

* Fix style issues flagged by flake8:
  * `At least two spaces before inline comment (E261)`
  * `Imports are incorrectly sorted.`

* Fix s3boto3 test incompatibility with Python 3.4
jnm authored and nitely committed Jul 30, 2018
1 parent ceca661 commit e25a0ce
Showing 4 changed files with 164 additions and 6 deletions.
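
The corruption mechanism is easy to see in isolation: before this commit, `_flush_write_buffer` seeked to position 0 before every part upload, so each part re-sent the buffer from the beginning rather than only the bytes written since the previous part. Below is a minimal sketch of the old and fixed behaviors, using `io.BytesIO` in place of the backend's buffer file; the function and variable names are illustrative, not from the codebase.

```python
import io

def parts_old(buffer_size, chunks):
    # Old behavior: rewind to byte 0 before each part upload.
    buf, parts = io.BytesIO(), []
    for chunk in chunks:
        buf.write(chunk)
        if buf.tell() >= buffer_size:
            buf.seek(0)
            parts.append(buf.read())  # re-reads everything from byte 0
    return parts

def parts_fixed(buffer_size, chunks):
    # Fixed behavior: seek to the end of the last uploaded part.
    buf, parts, last_part_pos = io.BytesIO(), [], 0
    for chunk in chunks:
        buf.write(chunk)
        if buf.tell() - last_part_pos >= buffer_size:
            buf.seek(last_part_pos)
            parts.append(buf.read())  # only the bytes since the last part
            last_part_pos = buf.tell()
    return parts

chunks = [b'a' * 4, b'b' * 4, b'c' * 4]
# The old scheme duplicates earlier bytes in every later part...
assert b''.join(parts_old(4, chunks)) == b'aaaa' + b'aaaabbbb' + b'aaaabbbbcccc'
# ...while the fix uploads each byte exactly once.
assert b''.join(parts_fixed(4, chunks)) == b'aaaabbbbcccc'
```

The backends below keep this invariant with `_last_part_pos` and `_file_part_size = _buffer_file_size - _last_part_pos`, and they restore the buffer's position after each flush so subsequent writes keep appending at the end.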
15 changes: 12 additions & 3 deletions storages/backends/s3boto.py
@@ -74,6 +74,8 @@ def __init__(self, name, mode, storage, buffer_size=None):
         if buffer_size is not None:
             self.buffer_size = buffer_size
         self._write_counter = 0
+        # file position of the latest part file uploaded
+        self._last_part_pos = 0
 
     @property
     def size(self):
@@ -123,10 +125,14 @@ def write(self, content, *args, **kwargs):
                 reduced_redundancy=self._storage.reduced_redundancy,
                 encrypt_key=self._storage.encryption,
             )
-        if self.buffer_size <= self._buffer_file_size:
+        if self.buffer_size <= self._file_part_size:
             self._flush_write_buffer()
         return super(S3BotoStorageFile, self).write(force_bytes(content), *args, **kwargs)
 
+    @property
+    def _file_part_size(self):
+        return self._buffer_file_size - self._last_part_pos
+
     @property
     def _buffer_file_size(self):
         pos = self.file.tell()
@@ -136,12 +142,15 @@ def _buffer_file_size(self):
         return length
 
     def _flush_write_buffer(self):
-        if self._buffer_file_size:
+        if self._file_part_size:
             self._write_counter += 1
-            self.file.seek(0)
+            pos = self.file.tell()
+            self.file.seek(self._last_part_pos)
             headers = self._storage.headers.copy()
             self._multipart.upload_part_from_file(
                 self.file, self._write_counter, headers=headers)
+            self.file.seek(pos)
+            self._last_part_pos = self._buffer_file_size
 
     def close(self):
         if self._is_dirty:
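
For context, this is roughly the boto (v2) multipart sequence the backend drives, sketched end to end; the bucket and key names are placeholders and credentials are assumed to come from the environment. `upload_part_from_file` reads from the current position of the file object it is handed, which is why the seek bookkeeping above matters.

```python
import io

from boto.s3.connection import S3Connection

conn = S3Connection()  # credentials picked up from the environment
bucket = conn.get_bucket('my-bucket')  # placeholder bucket name
multipart = bucket.initiate_multipart_upload('big-file.bin')

part_num = 0
with open('big-file.bin', 'rb') as fp:
    while True:
        chunk = fp.read(5 * 1024 * 1024)  # parts must be >= 5 MB, except the last
        if not chunk:
            break
        part_num += 1
        # boto uploads from the current position of the file object
        multipart.upload_part_from_file(io.BytesIO(chunk), part_num, size=len(chunk))
multipart.complete_upload()
```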
13 changes: 11 additions & 2 deletions storages/backends/s3boto3.py
@@ -77,6 +77,8 @@ def __init__(self, name, mode, storage, buffer_size=None):
         if buffer_size is not None:
             self.buffer_size = buffer_size
         self._write_counter = 0
+        # file position of the latest part file
+        self._last_part_pos = 0
 
     @property
     def size(self):
@@ -121,10 +123,14 @@ def write(self, content):
             if self._storage.encryption:
                 parameters['ServerSideEncryption'] = 'AES256'
             self._multipart = self.obj.initiate_multipart_upload(**parameters)
-        if self.buffer_size <= self._buffer_file_size:
+        if self.buffer_size <= self._file_part_size:
             self._flush_write_buffer()
         return super(S3Boto3StorageFile, self).write(force_bytes(content))
 
+    @property
+    def _file_part_size(self):
+        return self._buffer_file_size - self._last_part_pos
+
     @property
     def _buffer_file_size(self):
         pos = self.file.tell()
@@ -139,9 +145,12 @@ def _flush_write_buffer(self):
         """
         if self._buffer_file_size:
             self._write_counter += 1
-            self.file.seek(0)
+            pos = self.file.tell()
+            self.file.seek(self._last_part_pos)
             part = self._multipart.Part(self._write_counter)
             part.upload(Body=self.file.read())
+            self.file.seek(pos)
+            self._last_part_pos = self._buffer_file_size
 
     def close(self):
         if self._is_dirty:
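
The s3boto3 backend drives the same protocol through the boto3 resource API: `initiate_multipart_upload`, then `Part(n).upload(Body=...)` per part, then `complete(...)`. A condensed sketch with placeholder bucket and key names; note the real backend collects part numbers and ETags from `multipart.parts.all()` rather than tracking them by hand, which is what the tests below mock out.

```python
import boto3

s3 = boto3.resource('s3')
obj = s3.Object('my-bucket', 'big-file.bin')  # placeholder bucket/key
multipart = obj.initiate_multipart_upload(ACL='private', ContentType='text/plain')

parts = []
for part_number, body in enumerate([b'0' * 5 * 1024 * 1024, b'0'], start=1):
    part = multipart.Part(part_number)
    response = part.upload(Body=body)  # each part except the last must be >= 5 MB
    parts.append({'ETag': response['ETag'], 'PartNumber': part_number})

multipart.complete(MultipartUpload={'Parts': parts})
```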
37 changes: 37 additions & 0 deletions tests/test_s3boto.py
@@ -4,6 +4,7 @@
 import mock
 
 import datetime
+import os
 
 from boto.exception import S3ResponseError
 from boto.s3.key import Key
@@ -306,3 +307,39 @@ def test_get_modified_time(self, getkey):
         self.assertEqual(modtime,
                          tz.make_naive(tz.make_aware(
                              datetime.datetime.strptime(utcnow, ISO8601), tz.utc)))
+
+    def test_file_greater_than_5MB(self):
+        name = 'test_storage_save.txt'
+        content = ContentFile('0' * 10 * 1024 * 1024)
+
+        # Set the encryption flag used for multipart uploads
+        self.storage.encryption = True
+        # Set the ACL header used when creating/writing data.
+        self.storage.bucket.connection.provider.acl_header = 'x-amz-acl'
+        # Set the mocked key's bucket
+        self.storage.bucket.get_key.return_value.bucket = self.storage.bucket
+        # Set the name of the mock object
+        self.storage.bucket.get_key.return_value.name = name
+
+        def get_upload_file_size(fp):
+            pos = fp.tell()
+            fp.seek(0, os.SEEK_END)
+            length = fp.tell() - pos
+            fp.seek(pos)
+            return length
+
+        def upload_part_from_file(fp, part_num, *args, **kwargs):
+            if len(file_part_size) != part_num:
+                file_part_size.append(get_upload_file_size(fp))
+
+        file_part_size = []
+        f = self.storage.open(name, 'w')
+
+        # initiate the multipart upload
+        f.write('')
+        f._multipart.upload_part_from_file = upload_part_from_file
+        for chunk in content.chunks():
+            f.write(chunk)
+        f.close()
+
+        assert content.size == sum(file_part_size)
105 changes: 104 additions & 1 deletion tests/test_s3boto3.py
@@ -258,7 +258,7 @@ def test_storage_exists_false(self):
     def test_storage_exists_doesnt_create_bucket(self):
         with mock.patch.object(self.storage, '_get_or_create_bucket') as method:
             self.storage.exists('file.txt')
-            method.assert_not_called()
+            self.assertFalse(method.called)
 
     def test_storage_delete(self):
         self.storage.delete("path/to/file.txt")
@@ -406,3 +406,106 @@ def thread_storage_connection():
 
         # Connection for each thread needs to be unique
         self.assertIsNot(connections[0], connections[1])
+
+    def test_file_greater_than_5mb(self):
+        """
+        test writing a large file in a single part so that the buffer is flushed
+        only on close
+        """
+        name = 'test_storage_save.txt'
+        content = '0' * 10 * 1024 * 1024
+
+        # set the encryption flag used for multipart uploads
+        self.storage.encryption = True
+        self.storage.reduced_redundancy = True
+        self.storage.default_acl = 'public-read'
+
+        f = self.storage.open(name, 'w')
+        self.storage.bucket.Object.assert_called_with(name)
+        obj = self.storage.bucket.Object.return_value
+        # set the name of the mock object
+        obj.key = name
+        multipart = obj.initiate_multipart_upload.return_value
+        part = multipart.Part.return_value
+        multipart.parts.all.return_value = [mock.MagicMock(e_tag='123', part_number=1)]
+
+        with mock.patch.object(f, '_flush_write_buffer') as method:
+            f.write(content)
+            self.assertFalse(method.called)  # buffer not flushed on write
+
+        assert f._file_part_size == len(content)
+        obj.initiate_multipart_upload.assert_called_with(
+            ACL='public-read',
+            ContentType='text/plain',
+            ServerSideEncryption='AES256',
+            StorageClass='REDUCED_REDUNDANCY'
+        )
+
+        with mock.patch.object(f, '_flush_write_buffer', wraps=f._flush_write_buffer) as method:
+            f.close()
+            method.assert_called_with()  # buffer flushed on close
+            multipart.Part.assert_called_with(1)
+            part.upload.assert_called_with(Body=content.encode('utf-8'))
+            multipart.complete.assert_called_once_with(
+                MultipartUpload={'Parts': [{'ETag': '123', 'PartNumber': 1}]})
+
+    def test_file_write_after_exceeding_5mb(self):
+        """
+        test writing a large file in two parts so that the buffer is flushed
+        on write and on close
+        """
+        name = 'test_storage_save.txt'
+        content1 = '0' * 5 * 1024 * 1024
+        content2 = '0'
+
+        # set the encryption flag used for multipart uploads
+        self.storage.encryption = True
+        self.storage.reduced_redundancy = True
+        self.storage.default_acl = 'public-read'
+
+        f = self.storage.open(name, 'w')
+        self.storage.bucket.Object.assert_called_with(name)
+        obj = self.storage.bucket.Object.return_value
+        # set the name of the mock object
+        obj.key = name
+        multipart = obj.initiate_multipart_upload.return_value
+        part = multipart.Part.return_value
+        multipart.parts.all.return_value = [
+            mock.MagicMock(e_tag='123', part_number=1),
+            mock.MagicMock(e_tag='456', part_number=2)
+        ]
+
+        with mock.patch.object(f, '_flush_write_buffer', wraps=f._flush_write_buffer) as method:
+            f.write(content1)
+            self.assertFalse(method.called)  # buffer doesn't get flushed on the first write
+            assert f._file_part_size == len(content1)  # file part size is the size of what's written
+            assert f._last_part_pos == 0  # no parts added, so last part stays at 0
+            f.write(content2)
+            method.assert_called_with()  # second write flushes buffer
+            multipart.Part.assert_called_with(1)  # first part created
+            part.upload.assert_called_with(Body=content1.encode('utf-8'))  # first part is uploaded
+            assert f._last_part_pos == len(content1)  # buffer spools to end of content1
+            assert f._buffer_file_size == len(content1) + len(content2)  # _buffer_file_size is total written
+            assert f._file_part_size == len(content2)  # new part is size of content2
+
+        obj.initiate_multipart_upload.assert_called_with(
+            ACL='public-read',
+            ContentType='text/plain',
+            ServerSideEncryption='AES256',
+            StorageClass='REDUCED_REDUNDANCY'
+        )
+        # save the internal file before closing
+        f.close()
+        multipart.Part.assert_called_with(2)
+        part.upload.assert_called_with(Body=content2.encode('utf-8'))
+        multipart.complete.assert_called_once_with(
+            MultipartUpload={'Parts': [
+                {
+                    'ETag': '123',
+                    'PartNumber': 1
+                },
+                {
+                    'ETag': '456',
+                    'PartNumber': 2
+                }
+            ]})
