diff --git a/README.md b/README.md index 3566837..0b00bba 100644 --- a/README.md +++ b/README.md @@ -7,6 +7,8 @@ for Python projects, e.g. when you need a high-load queue consumer or high-load Supports Python versions 3.8, 3.9, 3.10, 3.11, 3.12. +Supported and tested Amazon-like SQS providers: Amazon, VK Cloud. + ---- ## Why aiosqs? diff --git a/aiosqs/__init__.py b/aiosqs/__init__.py index e9794ca..90a1780 100644 --- a/aiosqs/__init__.py +++ b/aiosqs/__init__.py @@ -10,4 +10,4 @@ SendMessageResponse, ) -VERSION = "1.0.4" +VERSION = "1.0.5" diff --git a/aiosqs/client.py b/aiosqs/client.py index e4a62ba..953ef69 100644 --- a/aiosqs/client.py +++ b/aiosqs/client.py @@ -6,7 +6,7 @@ import datetime import urllib.parse from logging import getLogger -from typing import Dict, Optional, List, Union +from typing import Dict, Optional, List, Union, Callable, NamedTuple import aiohttp @@ -18,6 +18,11 @@ default_logger = getLogger(__name__) +class SignedRequest(NamedTuple): + headers: Dict + querystring: str + + class SQSClient: algorithm = "AWS4-HMAC-SHA256" default_timeout_sec = 10 @@ -31,6 +36,7 @@ def __init__( timeout_sec: Optional[int] = None, logger: Optional[LoggerType] = None, verify_ssl: Optional[bool] = None, + quote_via: Optional[Callable] = None, ): self.service_name = "sqs" self.region_name = region_name @@ -46,6 +52,11 @@ def __init__( self.timeout = aiohttp.ClientTimeout(total=timeout_sec or self.default_timeout_sec) self.session = aiohttp.ClientSession(timeout=self.timeout) + # It's possible to have differen quoting logic for different SQS providers. + # By default Amazon SQS uses `urllib.parse.quote`, so no extra customizations are required. + # Related issue: https://github.com/d3QUone/aiosqs/issues/13 + self.quote_via = quote_via or urllib.parse.quote + async def close(self): await self.session.close() # https://docs.aiohttp.org/en/stable/client_advanced.html#graceful-shutdown @@ -57,7 +68,7 @@ async def __aenter__(self): async def __aexit__(self, exc_type, exc_val, exc_tb): await self.close() - def get_headers(self, params: Dict): + def build_signed_request(self, params: Dict) -> SignedRequest: # Create a date for headers and the credential string t = datetime.datetime.utcnow() amz_date = t.strftime("%Y%m%dT%H%M%SZ") @@ -69,7 +80,7 @@ def get_headers(self, params: Dict): # Create the canonical query string. Important notes: # - Query string values must be URL-encoded (space=%20). # - The parameters must be sorted by name. - canonical_querystring = urllib.parse.urlencode(list(sorted(params.items()))) + canonical_querystring = urllib.parse.urlencode(query=list(sorted(params.items())), quote_via=self.quote_via) # Create the canonical headers and signed headers. canonical_headers = f"host:{self.host}" + "\n" + f"x-amz-date:{amz_date}" + "\n" @@ -116,21 +127,25 @@ def get_headers(self, params: Dict): # The request can include any headers, but MUST include "host", "x-amz-date", # and (for this scenario) "Authorization". "host" and "x-amz-date" must # be included in the canonical_headers and signed_headers. Order here is not significant. - return { + headers = { "x-amz-date": amz_date, "Authorization": authorization_header, "content-type": "application/x-www-form-urlencoded", } + return SignedRequest( + headers=headers, + querystring=canonical_querystring, + ) async def request(self, params: Dict) -> Union[Dict, List, None]: params["Version"] = "2012-11-05" - headers = self.get_headers(params=params) + signed_request = self.build_signed_request(params=params) + url = f"{self.endpoint_url}?{signed_request.querystring}" try: response = await self.session.get( - url=self.endpoint_url, - headers=headers, - params=params, + url=url, + headers=signed_request.headers, verify_ssl=self.verify_ssl, ) except Exception as e: diff --git a/aiosqs/tests/test_client.py b/aiosqs/tests/test_client.py index 62dc9be..b77bf8a 100644 --- a/aiosqs/tests/test_client.py +++ b/aiosqs/tests/test_client.py @@ -1,8 +1,10 @@ import unittest import re import logging +import urllib.parse import ddt +from freezegun import freeze_time from aioresponses import aioresponses from aiosqs.exceptions import SQSErrorResponse @@ -11,12 +13,10 @@ @ddt.ddt(testNameFormat=ddt.TestNameFormat.INDEX_ONLY) -class ClientTestCase(unittest.IsolatedAsyncioTestCase): +class DefaultClientTestCase(unittest.IsolatedAsyncioTestCase): async def asyncSetUp(self): - await super().asyncSetUp() - - logger = logging.getLogger(__name__) - logger.setLevel(logging.CRITICAL) + self.logger = logging.getLogger(__name__) + self.logger.setLevel(logging.CRITICAL) self.client = SQSClient( aws_access_key_id="access_key_id", @@ -24,12 +24,36 @@ async def asyncSetUp(self): region_name="us-west-2", host="mocked_amazon_host.com", timeout_sec=0, - logger=logger, + logger=self.logger, ) async def asyncTearDown(self): await self.client.close() + async def test_signature_with_quote_via(self): + params = { + "Action": "SendMessage", + "DelaySeconds": 0, + "MessageBody": "a b c d", + "QueueUrl": "http://host.com/internal/tests", + "Version": "2012-11-05", + } + with freeze_time("2022-03-07T11:30:00.0000"): + signed_request = self.client.build_signed_request(params=params) + + self.assertEqual( + signed_request.headers, + { + "x-amz-date": "20220307T113000Z", + "Authorization": "AWS4-HMAC-SHA256 Credential=access_key_id/20220307/us-west-2/sqs/aws4_request, SignedHeaders=host;x-amz-date, Signature=7d7ae7f85d3175f61e5256ed560c7b284491f767b9c352d1231f92ec04043d8e", + "content-type": "application/x-www-form-urlencoded", + }, + ) + self.assertEqual( + signed_request.querystring, + "Action=SendMessage&DelaySeconds=0&MessageBody=a%20%20%20%20%20b%20%20%20%20c%20%20%20%20%20d&QueueUrl=http%3A%2F%2Fhost.com%2Finternal%2Ftests&Version=2012-11-05", + ) + @aioresponses() async def test_is_context_manager(self, mock): mock.get( @@ -67,3 +91,44 @@ async def test_invalid_auth_keys(self, fixture_name: str, error_message: str, mo self.assertEqual(exception.error.type, "Sender") self.assertEqual(exception.error.code, "InvalidClientTokenId") self.assertEqual(exception.error.message, error_message) + + +class VKClientTestCase(DefaultClientTestCase): + async def asyncSetUp(self): + self.logger = logging.getLogger(__name__) + self.logger.setLevel(logging.CRITICAL) + + self.client = SQSClient( + aws_access_key_id="access_key_id", + aws_secret_access_key="secret_access_key", + region_name="us-west-2", + host="mocked_amazon_host.com", + timeout_sec=0, + logger=self.logger, + quote_via=urllib.parse.quote_plus, + ) + + async def test_signature_with_quote_via(self): + params = { + "Action": "SendMessage", + "DelaySeconds": 0, + "MessageBody": "a b c d", + "QueueUrl": "http://host.com/internal/tests", + "Version": "2012-11-05", + } + + with freeze_time("2022-03-07T11:30:00.0000"): + signed_request = self.client.build_signed_request(params=params) + + self.assertEqual( + signed_request.headers, + { + "x-amz-date": "20220307T113000Z", + "Authorization": "AWS4-HMAC-SHA256 Credential=access_key_id/20220307/us-west-2/sqs/aws4_request, SignedHeaders=host;x-amz-date, Signature=0c36e0d3f62bd7ecb7e78ffe09fbd1224b7f850f3b4f13c7fc82e516fc7f2c57", + "content-type": "application/x-www-form-urlencoded", + }, + ) + self.assertEqual( + signed_request.querystring, + "Action=SendMessage&DelaySeconds=0&MessageBody=a+++++b++++c+++++d&QueueUrl=http%3A%2F%2Fhost.com%2Finternal%2Ftests&Version=2012-11-05", + ) diff --git a/e2e/test_e2e.py b/e2e/test_e2e.py index 55c02d3..eb58666 100644 --- a/e2e/test_e2e.py +++ b/e2e/test_e2e.py @@ -1,5 +1,6 @@ import unittest import logging +from urllib.parse import quote_plus from dotenv import dotenv_values @@ -29,6 +30,7 @@ async def asyncSetUp(self): host=self.host, verify_ssl=False, logger=logger, + quote_via=quote_plus, ) async def asyncTearDown(self): diff --git a/pyproject.toml b/pyproject.toml index b38b78e..0fdd085 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,7 +44,7 @@ Source = "https://github.com/d3QUone/aiosqs" [tool.poetry] name = "aiosqs" -version = "1.0.4" +version = "1.0.5" description = "Python asynchronous and lightweight SQS client." authors = ["Vladimir Kasatkin "] license = "MIT"