Skip to content

Commit

Permalink
Add CORS support to buckets.
Browse files Browse the repository at this point in the history
  • Loading branch information
tseaver committed Nov 4, 2014
1 parent 8b32e34 commit 8289de4
Show file tree
Hide file tree
Showing 2 changed files with 141 additions and 0 deletions.
53 changes: 53 additions & 0 deletions gcloud/storage/bucket.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ class Bucket(_MetadataMixin):
CUSTOM_METADATA_FIELDS = {
'acl': 'get_acl',
'defaultObjectAcl': 'get_default_object_acl',
'cors': 'get_cors',
}
"""Mapping of field name -> accessor for fields w/ custom accessors."""

Expand Down Expand Up @@ -441,6 +442,58 @@ def make_public(self, recursive=False, future=False):
key.get_acl().all().grant_read()
key.save_acl()

def get_cors(self):
"""Retrieve CORS policies configured for this bucket.
See: http://www.w3.org/TR/cors/ and
https://cloud.google.com/storage/docs/json_api/v1/buckets
:rtype: list(dict)
:returns: A sequence of mappings describing each CORS policy.
Keys include 'max_age', 'methods', 'origins', and
'headers'.
"""
if not self.has_metadata('cors'):
self.reload_metadata()
result = []
for entry in self.metadata.get('cors', ()):
entry = entry.copy()
result.append(entry)
if 'maxAgeSeconds' in entry:
entry['max_age'] = entry.pop('maxAgeSeconds')
if 'method' in entry:
entry['methods'] = entry.pop('method')
if 'origin' in entry:
entry['origins'] = entry.pop('origin')
if 'responseHeader' in entry:
entry['headers'] = entry.pop('responseHeader')
return result

def update_cors(self, entries):
"""Update CORS policies configured for this bucket.
See: http://www.w3.org/TR/cors/ and
https://cloud.google.com/storage/docs/json_api/v1/buckets
:type entries: list(dict)
:param entries: A sequence of mappings describing each CORS policy.
Keys include 'max_age', 'methods', 'origins', and
'headers'.
"""
to_patch = []
for entry in entries:
entry = entry.copy()
to_patch.append(entry)
if 'max_age' in entry:
entry['maxAgeSeconds'] = entry.pop('max_age')
if 'methods' in entry:
entry['method'] = entry.pop('methods')
if 'origins' in entry:
entry['origin'] = entry.pop('origins')
if 'headers' in entry:
entry['responseHeader'] = entry.pop('headers')
self.patch_metadata({'cors': to_patch})


class BucketIterator(Iterator):
"""An iterator listing all buckets.
Expand Down
88 changes: 88 additions & 0 deletions gcloud/storage/test_bucket.py
Original file line number Diff line number Diff line change
Expand Up @@ -489,6 +489,23 @@ def test_get_metadata_none_set_defaultObjectAcl_miss_clear_default(self):
kw = connection._requested
self.assertEqual(len(kw), 0)

def test_get_metadata_cors_no_default(self):
NAME = 'name'
connection = _Connection()
bucket = self._makeOne(connection, NAME)
self.assertRaises(KeyError, bucket.get_metadata, 'cors')
kw = connection._requested
self.assertEqual(len(kw), 0)

def test_get_metadata_none_set_cors_w_default(self):
NAME = 'name'
connection = _Connection()
bucket = self._makeOne(connection, NAME)
default = object()
self.assertRaises(KeyError, bucket.get_metadata, 'cors', default)
kw = connection._requested
self.assertEqual(len(kw), 0)

def test_get_metadata_miss(self):
NAME = 'name'
before = {'bar': 'Bar'}
Expand Down Expand Up @@ -713,6 +730,77 @@ def get_items_from_response(self, response):
self.assertEqual(kw[1]['path'], '/b/%s/o' % NAME)
self.assertEqual(kw[1]['query_params'], None)

def test_get_cors_eager(self):
NAME = 'name'
CORS_ENTRY = {
'maxAgeSeconds': 1234,
'method': ['OPTIONS', 'GET'],
'origin': ['127.0.0.1'],
'responseHeader': ['Content-Type'],
}
before = {'cors': [CORS_ENTRY, {}]}
connection = _Connection()
bucket = self._makeOne(connection, NAME, before)
entries = bucket.get_cors()
self.assertEqual(len(entries), 2)
self.assertEqual(entries[0]['max_age'], CORS_ENTRY['maxAgeSeconds'])
self.assertEqual(entries[0]['methods'], CORS_ENTRY['method'])
self.assertEqual(entries[0]['origins'], CORS_ENTRY['origin'])
self.assertEqual(entries[0]['headers'], CORS_ENTRY['responseHeader'])
self.assertEqual(entries[1], {})
kw = connection._requested
self.assertEqual(len(kw), 0)

def test_get_cors_lazy(self):
NAME = 'name'
CORS_ENTRY = {
'maxAgeSeconds': 1234,
'method': ['OPTIONS', 'GET'],
'origin': ['127.0.0.1'],
'responseHeader': ['Content-Type'],
}
after = {'cors': [CORS_ENTRY]}
connection = _Connection(after)
bucket = self._makeOne(connection, NAME)
entries = bucket.get_cors()
self.assertEqual(len(entries), 1)
self.assertEqual(entries[0]['max_age'], CORS_ENTRY['maxAgeSeconds'])
self.assertEqual(entries[0]['methods'], CORS_ENTRY['method'])
self.assertEqual(entries[0]['origins'], CORS_ENTRY['origin'])
self.assertEqual(entries[0]['headers'], CORS_ENTRY['responseHeader'])
kw = connection._requested
self.assertEqual(len(kw), 1)
self.assertEqual(kw[0]['method'], 'GET')
self.assertEqual(kw[0]['path'], '/b/%s' % NAME)
self.assertEqual(kw[0]['query_params'], {'projection': 'noAcl'})

def test_update_cors(self):
NAME = 'name'
CORS_ENTRY = {
'maxAgeSeconds': 1234,
'method': ['OPTIONS', 'GET'],
'origin': ['127.0.0.1'],
'responseHeader': ['Content-Type'],
}
MAPPED = {
'max_age': 1234,
'methods': ['OPTIONS', 'GET'],
'origins': ['127.0.0.1'],
'headers': ['Content-Type'],
}
after = {'cors': [CORS_ENTRY, {}]}
connection = _Connection(after)
bucket = self._makeOne(connection, NAME)
bucket.update_cors([MAPPED, {}])
kw = connection._requested
self.assertEqual(len(kw), 1)
self.assertEqual(kw[0]['method'], 'PATCH')
self.assertEqual(kw[0]['path'], '/b/%s' % NAME)
self.assertEqual(kw[0]['data'], after)
self.assertEqual(kw[0]['query_params'], {'projection': 'full'})
entries = bucket.get_cors()
self.assertEqual(entries, [MAPPED, {}])


class TestBucketIterator(unittest2.TestCase):

Expand Down

0 comments on commit 8289de4

Please sign in to comment.