Skip to content

Commit

Permalink
make 2.7 compatible
Browse files Browse the repository at this point in the history
  • Loading branch information
iscai-msft committed Aug 23, 2021
1 parent f7fa1e0 commit ea06e45
Show file tree
Hide file tree
Showing 4 changed files with 59 additions and 57 deletions.
40 changes: 37 additions & 3 deletions sdk/core/azure-core/azure/core/rest/_aiohttp.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
from typing import AsyncIterator
from multidict import CIMultiDict
from . import HttpRequest, AsyncHttpResponse
from ._helpers import KeysView, ValuesView
from ._helpers_py3 import iter_raw_helper, iter_bytes_helper
from ..pipeline.transport._aiohttp import AioHttpStreamDownloadGenerator

Expand All @@ -55,6 +54,41 @@ def __contains__(self, item):
def __repr__(self):
return f"dict_items({list(self.__iter__())})"

class _KeysView(collections.abc.KeysView):
def __init__(self, items):
super().__init__(items)
self._items = items

def __iter__(self):
for key, _ in self._items:
yield key

def __contains__(self, key):
for k in self.__iter__():
if key.lower() == k.lower():
return True
return False
def __repr__(self):
return f"dict_keys({list(self.__iter__())})"

class _ValuesView(collections.abc.ValuesView):
def __init__(self, items):
super().__init__(items)
self._items = items

def __iter__(self):
for _, value in self._items:
yield value

def __contains__(self, value):
for v in self.__iter__():
if value == v:
return True
return False

def __repr__(self):
return f"dict_values({list(self.__iter__())})"


class _CIMultiDict(CIMultiDict):
"""Dictionary with the support for duplicate case-insensitive keys."""
Expand All @@ -64,15 +98,15 @@ def __iter__(self):

def keys(self):
"""Return a new view of the dictionary's keys."""
return KeysView(self.items())
return _KeysView(self.items())

def items(self):
"""Return a new view of the dictionary's items."""
return _ItemsView(super().items())

def values(self):
"""Return a new view of the dictionary's values."""
return ValuesView(self.items())
return _ValuesView(self.items())

def __getitem__(self, key: str) -> str:
return ", ".join(self.getall(key, []))
Expand Down
36 changes: 0 additions & 36 deletions sdk/core/azure-core/azure/core/rest/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,39 +303,3 @@ def decode_to_text(encoding, content):
if encoding:
return content.decode(encoding)
return codecs.getincrementaldecoder("utf-8-sig")(errors="replace").decode(content)

class KeysView(collections.KeysView):
def __init__(self, items):
super().__init__(items)
self._items = items

def __iter__(self):
for key, _ in self._items:
yield key

def __contains__(self, key):
for k in self.__iter__():
if key.lower() == k.lower():
return True
return False

def __repr__(self):
return "dict_keys()".format({list(self.__iter__())})

class ValuesView(collections.ValuesView):
def __init__(self, items):
super().__init__(items)
self._items = items

def __iter__(self):
for _, value in self._items:
yield value

def __contains__(self, value):
for v in self.__iter__():
if value == v:
return True
return False

def __repr__(self):
return "dict_values({})".format(list(self.__iter__()))
14 changes: 0 additions & 14 deletions sdk/core/azure-core/azure/core/rest/_requests_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@

from ..exceptions import ResponseNotReadError, StreamConsumedError, StreamClosedError
from ._rest import _HttpResponseBase, HttpResponse
from ._helpers import KeysView, ValuesView
from ..pipeline.transport._requests_basic import StreamDownloadGenerator

class _ItemsView(collections.ItemsView):
Expand All @@ -54,19 +53,6 @@ def items(self):
"""Return a new view of the dictionary's items."""
return _ItemsView(self)

def values(self):
"""Return a new view of the dictionary's values.
Need to overwrite because default behavior is not mutable for 2.7
"""
return ValuesView(self.items())

def keys(self):
"""Return a new view of the dictionary's keys.
Need to overwrite because default behavior is not mutable for 2.7
"""
return KeysView(self.items())

if TYPE_CHECKING:
from typing import Iterator, Optional
Expand Down
26 changes: 22 additions & 4 deletions sdk/core/azure-core/tests/testserver_tests/test_rest_headers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# license information.
# -------------------------------------------------------------------------
import sys
from typing import ValuesView
import pytest

# NOTE: These tests are heavily inspired from the httpx test suite: https://github.com/encode/httpx/tree/master/tests
Expand Down Expand Up @@ -68,13 +69,18 @@ def test_headers_response_keys(get_response_headers):
# basically want to make sure this behaves like dict {"a": "123, 456", "b": "789"}
ref_dict = {"a": "123, 456", "b": "789"}
assert list(h.keys()) == list(ref_dict.keys())
assert repr(h.keys()) == "KeysView({'a': '123, 456', 'b': '789'})"
if sys.version_info < (3, 0):
assert repr(h.keys()) == "['a', 'b']"
else:
assert repr(h.keys()) == "KeysView({'a': '123, 456', 'b': '789'})"
assert "a" in h.keys()
assert "A" in h.keys()
assert "b" in h.keys()
assert "B" in h.keys()
assert set(h.keys()) == set(ref_dict.keys())

@pytest.mark.skipif(sys.version_info < (3, 0),
reason="In 2.7, .keys() are not mutable")
def test_headers_response_keys_mutability(get_response_headers):
h = get_response_headers(HttpRequest("GET", "/headers/duplicate/numbers"))
# test mutability
before_mutation_keys = h.keys()
h['c'] = '000'
Expand All @@ -85,11 +91,18 @@ def test_headers_response_values(get_response_headers):
# basically want to make sure this behaves like dict {"a": "123, 456", "b": "789"}
ref_dict = {"a": "123, 456", "b": "789"}
assert list(h.values()) == list(ref_dict.values())
assert repr(h.values()) == "ValuesView({'a': '123, 456', 'b': '789'})"
if sys.version_info < (3, 0):
assert repr(h.values()) == "['123, 456', '789']"
else:
assert repr(h.values()) == "ValuesView({'a': '123, 456', 'b': '789'})"
assert '123, 456' in h.values()
assert '789' in h.values()
assert set(h.values()) == set(ref_dict.values())

@pytest.mark.skipif(sys.version_info < (3, 0),
reason="In 2.7, .values() are not mutable")
def test_headers_response_values_mutability(get_response_headers):
h = get_response_headers(HttpRequest("GET", "/headers/duplicate/numbers"))
# test mutability
before_mutation_values = h.values()
h['c'] = '000'
Expand All @@ -109,6 +122,11 @@ def test_headers_response_items(get_response_headers):
assert ("B", '789') in h.items()
assert set(h.items()) == set(ref_dict.items())


@pytest.mark.skipif(sys.version_info < (3, 0),
reason="In 2.7, .items() are not mutable")
def test_headers_response_items_mutability(get_response_headers):
h = get_response_headers(HttpRequest("GET", "/headers/duplicate/numbers"))
# test mutability
before_mutation_items = h.items()
h['c'] = '000'
Expand Down

0 comments on commit ea06e45

Please sign in to comment.