Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix Issue #708: Save files with string content #911

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ __pycache__
.vscode/
.pytest_cache/
venv/
.venv/

dist/
docs/_build
Expand Down
6 changes: 6 additions & 0 deletions storages/backends/s3boto3.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from storages.base import BaseStorage
from storages.compress import CompressedFileMixin
from storages.compress import CompressStorageMixin
from storages.utils import ReadBytesWrapper
from storages.utils import check_location
from storages.utils import clean_name
from storages.utils import get_available_overwrite_name
Expand Down Expand Up @@ -432,6 +433,11 @@ def _save(self, name, content):

if is_seekable(content):
content.seek(0, os.SEEK_SET)

# wrap content so read() always returns bytes. This is required for passing it
# to obj.upload_fileobj() or self._compress_content()
content = ReadBytesWrapper(content)
Copy link
Contributor Author

@LincolnPuzey LincolnPuzey Jul 27, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the core part of the fix. This approach has the benefit of not making any extra read() or seek() calls to the file object.


if (self.gzip and
params['ContentType'] in self.gzip_content_types and
'ContentEncoding' not in params):
Expand Down
33 changes: 33 additions & 0 deletions storages/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from django.conf import settings
from django.core.exceptions import ImproperlyConfigured
from django.core.exceptions import SuspiciousFileOperation
from django.core.files.utils import FileProxyMixin
from django.utils.encoding import force_bytes


Expand Down Expand Up @@ -125,3 +126,35 @@ def get_available_overwrite_name(name, max_length):

def is_seekable(file_object):
return not hasattr(file_object, 'seekable') or file_object.seekable()


class ReadBytesWrapper(FileProxyMixin):
"""
A wrapper for a file-like object, that makes read() always returns bytes.
"""
def __init__(self, file, encoding=None):
"""
:param file: The file-like object to wrap.
:param encoding: Specify the encoding to use when file.read() returns strings.
If not provided will default to file.encoding, of if that's not available,
to utf-8.
"""
self.file = file
self._encoding = (
encoding
or getattr(file, "encoding", None)
or "utf-8"
)

def read(self, *args, **kwargs):
content = self.file.read(*args, **kwargs)

if not isinstance(content, bytes):
content = content.encode(self._encoding)
return content

def close(self):
self.file.close()

def readable(self):
return True
3 changes: 3 additions & 0 deletions tests/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,6 @@
SECRET_KEY = 'hailthesunshine'

USE_TZ = True

# the following test settings are required for moto to work.
AWS_STORAGE_BUCKET_NAME = "test-bucket"
1 change: 1 addition & 0 deletions tests/test_files/windows-1252-encoded.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
™€‰
128 changes: 109 additions & 19 deletions tests/test_s3boto3.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,17 @@
from unittest import skipIf
from urllib.parse import urlparse

import boto3
import boto3.s3.transfer
from botocore.exceptions import ClientError
from django.conf import settings
from django.core.exceptions import ImproperlyConfigured
from django.core.files.base import ContentFile
from django.core.files.base import File
from django.test import TestCase
from django.test import override_settings
from django.utils.timezone import is_aware
from moto import mock_s3

from storages.backends import s3boto3
from tests.utils import NonSeekableContentFile
Expand All @@ -32,11 +35,11 @@ def setUp(self):
self.storage._connections.connection = mock.MagicMock()

def test_s3_session(self):
settings.AWS_S3_SESSION_PROFILE = "test_profile"
with mock.patch('boto3.Session') as mock_session:
storage = s3boto3.S3Boto3Storage()
_ = storage.connection
mock_session.assert_called_once_with(profile_name="test_profile")
with override_settings(AWS_S3_SESSION_PROFILE="test_profile"):
with mock.patch('boto3.Session') as mock_session:
storage = s3boto3.S3Boto3Storage()
_ = storage.connection
mock_session.assert_called_once_with(profile_name="test_profile")

def test_pickle_with_bucket(self):
"""
Expand Down Expand Up @@ -94,7 +97,7 @@ def test_storage_save(self):

obj = self.storage.bucket.Object.return_value
obj.upload_fileobj.assert_called_with(
content,
mock.ANY,
ExtraArgs={
'ContentType': 'text/plain',
},
Expand All @@ -112,7 +115,7 @@ def test_storage_save_non_seekable(self):

obj = self.storage.bucket.Object.return_value
obj.upload_fileobj.assert_called_with(
content,
mock.ANY,
ExtraArgs={
'ContentType': 'text/plain',
},
Expand All @@ -131,7 +134,7 @@ def test_storage_save_with_default_acl(self):

obj = self.storage.bucket.Object.return_value
obj.upload_fileobj.assert_called_with(
content,
mock.ANY,
ExtraArgs={
'ContentType': 'text/plain',
'ACL': 'private',
Expand All @@ -152,7 +155,7 @@ def test_storage_object_parameters_not_overwritten_by_default(self):

obj = self.storage.bucket.Object.return_value
obj.upload_fileobj.assert_called_with(
content,
mock.ANY,
ExtraArgs={
'ContentType': 'text/plain',
'ACL': 'private',
Expand All @@ -172,7 +175,7 @@ def test_content_type(self):

obj = self.storage.bucket.Object.return_value
obj.upload_fileobj.assert_called_with(
content,
mock.ANY,
ExtraArgs={
'ContentType': 'image/jpeg',
},
Expand All @@ -187,8 +190,8 @@ def test_storage_save_gzipped(self):
content = ContentFile("I am gzip'd")
self.storage.save(name, content)
obj = self.storage.bucket.Object.return_value
obj.upload_fileobj.assert_called_with(
content,
obj.upload_fileobj.assert_called_once_with(
mock.ANY,
ExtraArgs={
'ContentType': 'application/octet-stream',
'ContentEncoding': 'gzip',
Expand All @@ -208,7 +211,7 @@ def get_object_parameters(name):

obj = self.storage.bucket.Object.return_value
obj.upload_fileobj.assert_called_with(
content,
mock.ANY,
ExtraArgs={
"ContentType": "application/gzip",
},
Expand All @@ -223,8 +226,8 @@ def test_storage_save_gzipped_non_seekable(self):
content = NonSeekableContentFile("I am gzip'd")
self.storage.save(name, content)
obj = self.storage.bucket.Object.return_value
obj.upload_fileobj.assert_called_with(
content,
obj.upload_fileobj.assert_called_once_with(
mock.ANY,
ExtraArgs={
'ContentType': 'application/octet-stream',
'ContentEncoding': 'gzip',
Expand Down Expand Up @@ -287,7 +290,7 @@ def test_compress_content_len(self):
Test that file returned by _compress_content() is readable.
"""
self.storage.gzip = True
content = ContentFile("I should be gzip'd")
content = ContentFile(b"I should be gzip'd")
content = self.storage._compress_content(content)
self.assertTrue(len(content.read()) > 0)

Expand Down Expand Up @@ -569,7 +572,7 @@ def test_storage_listdir_base(self):
self.storage._connections.connection.meta.client.get_paginator.return_value = paginator

dirs, files = self.storage.listdir('')
paginator.paginate.assert_called_with(Bucket=None, Delimiter='/', Prefix='')
paginator.paginate.assert_called_with(Bucket=settings.AWS_STORAGE_BUCKET_NAME, Delimiter='/', Prefix='')

self.assertEqual(dirs, ['some', 'other'])
self.assertEqual(files, ['2.txt', '4.txt'])
Expand All @@ -594,7 +597,7 @@ def test_storage_listdir_subdir(self):
self.storage._connections.connection.meta.client.get_paginator.return_value = paginator

dirs, files = self.storage.listdir('some/')
paginator.paginate.assert_called_with(Bucket=None, Delimiter='/', Prefix='some/')
paginator.paginate.assert_called_with(Bucket=settings.AWS_STORAGE_BUCKET_NAME, Delimiter='/', Prefix='some/')

self.assertEqual(dirs, ['path'])
self.assertEqual(files, ['2.txt'])
Expand All @@ -615,7 +618,7 @@ def test_storage_listdir_empty(self):
self.storage._connections.connection.meta.client.get_paginator.return_value = paginator

dirs, files = self.storage.listdir('dir/')
paginator.paginate.assert_called_with(Bucket=None, Delimiter='/', Prefix='dir/')
paginator.paginate.assert_called_with(Bucket=settings.AWS_STORAGE_BUCKET_NAME, Delimiter='/', Prefix='dir/')

self.assertEqual(dirs, [])
self.assertEqual(files, [])
Expand Down Expand Up @@ -865,3 +868,90 @@ def test_closed(self):
with self.subTest("is True after close"):
f.close()
self.assertTrue(f.closed)


@mock_s3
class S3Boto3StorageTestsWithMoto(TestCase):
"""
Using mock_s3 as a class decorator automatically decorates methods,
but NOT classmethods or staticmethods.
"""

def setUp(cls):
super().setUp()

cls.storage = s3boto3.S3Boto3Storage()
cls.bucket = cls.storage.connection.Bucket(settings.AWS_STORAGE_BUCKET_NAME)
cls.bucket.create()

def test_save_bytes_file(self):
self.storage.save("bytes_file.txt", File(io.BytesIO(b"foo1")))

self.assertEqual(
b"foo1",
self.bucket.Object("bytes_file.txt").get()['Body'].read(),
)

def test_save_string_file(self):
self.storage.save("string_file.txt", File(io.StringIO("foo2")))

self.assertEqual(
b"foo2",
self.bucket.Object("string_file.txt").get()['Body'].read(),
)

def test_save_bytes_content_file(self):
self.storage.save("bytes_content.txt", ContentFile(b"foo3"))

self.assertEqual(
b"foo3",
self.bucket.Object("bytes_content.txt").get()['Body'].read(),
)

def test_save_string_content_file(self):
self.storage.save("string_content.txt", ContentFile("foo4"))

self.assertEqual(
b"foo4",
self.bucket.Object("string_content.txt").get()['Body'].read(),
)

def test_content_type_guess(self):
"""
Test saving a file where the ContentType is guessed from the filename.
"""
name = 'test_image.jpg'
content = ContentFile(b'data')
content.content_type = None
self.storage.save(name, content)

s3_object_fetched = self.bucket.Object(name).get()
self.assertEqual(b"data", s3_object_fetched['Body'].read())
self.assertEqual(s3_object_fetched["ContentType"], "image/jpeg")

def test_content_type_attribute(self):
"""
Test saving a file with a custom content type attribute.
"""
content = ContentFile(b'data')
content.content_type = "test/foo"
self.storage.save("test_file", content)

s3_object_fetched = self.bucket.Object("test_file").get()
self.assertEqual(b"data", s3_object_fetched['Body'].read())
self.assertEqual(s3_object_fetched["ContentType"], "test/foo")

def test_content_type_not_detectable(self):
"""
Test saving a file with no detectable content type.
"""
content = ContentFile(b'data')
content.content_type = None
self.storage.save("test_file", content)

s3_object_fetched = self.bucket.Object("test_file").get()
self.assertEqual(b"data", s3_object_fetched['Body'].read())
self.assertEqual(
s3_object_fetched["ContentType"],
s3boto3.S3Boto3Storage.default_content_type,
)
Loading
Loading