Skip to content

Commit

Permalink
Start fixing mypy errors and drop flask1 support
Browse files Browse the repository at this point in the history
* Fix mypy error

This should allow us to add a CI check for type checking going forward

* Drop flask1 support

Flask is no longer maintaining the 1.x.x branch, and has in fact broken
it with pallets/flask#4456. Since just
installing flask 1.1.4 is no longer sufficient to run our specs and
insure compatibility, I've opted to follow suite and only support flask
2 going forward.

If flask does decide to do further maintenance on 1.1.5 I will consider
reverting this PR and doing whatever updates we need to in order to
insure that they type checking isn't broken.

* Fix type checking imports for python < 3.8
  • Loading branch information
vimalloc authored Feb 18, 2022
1 parent d7ef40e commit 848b4a9
Show file tree
Hide file tree
Showing 11 changed files with 85 additions and 55 deletions.
2 changes: 0 additions & 2 deletions .github/workflows/unit_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,5 +22,3 @@ jobs:
run: pip install tox
- name: Run Tox
run: tox -e py
- name: Run Tox Flask 1
run: tox -e flask1
19 changes: 14 additions & 5 deletions flask_jwt_extended/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,17 @@
from datetime import timezone
from typing import Iterable
from typing import List
from typing import Optional
from typing import Sequence
from typing import Type
from typing import Union

from flask import current_app
from flask.json import JSONEncoder
from jwt.algorithms import requires_cryptography

from flask_jwt_extended.typing import ExpiresDelta


class _Config(object):
"""
Expand All @@ -34,7 +39,7 @@ def decode_key(self) -> str:
return self._public_key if self.is_asymmetric else self._secret_key

@property
def token_location(self) -> Iterable[str]:
def token_location(self) -> Sequence[str]:
locations = current_app.config["JWT_TOKEN_LOCATION"]
if isinstance(locations, str):
locations = (locations,)
Expand Down Expand Up @@ -177,12 +182,14 @@ def refresh_csrf_field_name(self) -> str:
return current_app.config["JWT_REFRESH_CSRF_FIELD_NAME"]

@property
def access_expires(self) -> datetime:
def access_expires(self) -> ExpiresDelta:
delta = current_app.config["JWT_ACCESS_TOKEN_EXPIRES"]
if type(delta) is int:
delta = timedelta(seconds=delta)
if delta is not False:
try:
# Basically runtime typechecking. Probably a better way to do
# this with proper type checking
delta + datetime.now(timezone.utc)
except TypeError as e:
err = (
Expand All @@ -192,11 +199,13 @@ def access_expires(self) -> datetime:
return delta

@property
def refresh_expires(self) -> datetime:
def refresh_expires(self) -> ExpiresDelta:
delta = current_app.config["JWT_REFRESH_TOKEN_EXPIRES"]
if type(delta) is int:
delta = timedelta(seconds=delta)
if delta is not False:
# Basically runtime typechecking. Probably a better way to do
# this with proper type checking
try:
delta + datetime.now(timezone.utc)
except TypeError as e:
Expand Down Expand Up @@ -255,7 +264,7 @@ def _private_key(self) -> str:
return key

@property
def cookie_max_age(self) -> int:
def cookie_max_age(self) -> Optional[int]:
# Returns the appropiate value for max_age for flask set_cookies. If
# session cookie is true, return None, otherwise return a number of
# seconds 1 year in the future
Expand All @@ -274,7 +283,7 @@ def error_msg_key(self) -> str:
return current_app.config["JWT_ERROR_MESSAGE_KEY"]

@property
def json_encoder(self) -> JSONEncoder:
def json_encoder(self) -> Type[JSONEncoder]:
return current_app.json_encoder

@property
Expand Down
22 changes: 14 additions & 8 deletions flask_jwt_extended/default_callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from typing import Any

from flask import jsonify
from flask import Response
from flask.typing import ResponseReturnValue

from flask_jwt_extended.config import config

Expand Down Expand Up @@ -55,15 +55,15 @@ def default_user_identity_callback(userdata: Any) -> Any:

def default_expired_token_callback(
_expired_jwt_header: dict, _expired_jwt_data: dict
) -> Response:
) -> ResponseReturnValue:
"""
By default, if an expired token attempts to access a protected endpoint,
we return a generic error message with a 401 status
"""
return jsonify({config.error_msg_key: "Token has expired"}), HTTPStatus.UNAUTHORIZED


def default_invalid_token_callback(error_string: str) -> Response:
def default_invalid_token_callback(error_string: str) -> ResponseReturnValue:
"""
By default, if an invalid token attempts to access a protected endpoint, we
return the error string for why it is not valid with a 422 status code
Expand All @@ -76,7 +76,7 @@ def default_invalid_token_callback(error_string: str) -> Response:
)


def default_unauthorized_callback(error_string: str) -> Response:
def default_unauthorized_callback(error_string: str) -> ResponseReturnValue:
"""
By default, if a protected endpoint is accessed without a JWT, we return
the error string indicating why this is unauthorized, with a 401 status code
Expand All @@ -86,7 +86,9 @@ def default_unauthorized_callback(error_string: str) -> Response:
return jsonify({config.error_msg_key: error_string}), HTTPStatus.UNAUTHORIZED


def default_needs_fresh_token_callback(jwt_header: dict, jwt_data: dict) -> Response:
def default_needs_fresh_token_callback(
jwt_header: dict, jwt_data: dict
) -> ResponseReturnValue:
"""
By default, if a non-fresh jwt is used to access a ```fresh_jwt_required```
endpoint, we return a general error message with a 401 status code
Expand All @@ -97,7 +99,9 @@ def default_needs_fresh_token_callback(jwt_header: dict, jwt_data: dict) -> Resp
)


def default_revoked_token_callback(jwt_header: dict, jwt_data: dict) -> Response:
def default_revoked_token_callback(
jwt_header: dict, jwt_data: dict
) -> ResponseReturnValue:
"""
By default, if a revoked token is used to access a protected endpoint, we
return a general error message with a 401 status code
Expand All @@ -108,7 +112,9 @@ def default_revoked_token_callback(jwt_header: dict, jwt_data: dict) -> Response
)


def default_user_lookup_error_callback(_jwt_header: dict, jwt_data: dict) -> Response:
def default_user_lookup_error_callback(
_jwt_header: dict, jwt_data: dict
) -> ResponseReturnValue:
"""
By default, if a user_lookup callback is defined and the callback
function returns None, we return a general error message with a 401
Expand All @@ -128,7 +134,7 @@ def default_token_verification_callback(_jwt_header: dict, _jwt_data: dict) -> b

def default_token_verification_failed_callback(
_jwt_header: dict, _jwt_data: dict
) -> Response:
) -> ResponseReturnValue:
"""
By default, if the user claims verification failed, we return a generic
error message with a 400 status code
Expand Down
4 changes: 3 additions & 1 deletion flask_jwt_extended/internal_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,9 @@ def has_user_lookup() -> bool:

def user_lookup(*args, **kwargs) -> Any:
jwt_manager = get_jwt_manager()
return jwt_manager._user_lookup_callback(*args, **kwargs)
return jwt_manager._user_lookup_callback and jwt_manager._user_lookup_callback(
*args, **kwargs
)


def verify_token_type(decoded_token: dict, refresh: bool) -> None:
Expand Down
13 changes: 9 additions & 4 deletions flask_jwt_extended/jwt_manager.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import datetime
from typing import Any
from typing import Callable
from typing import Optional

import jwt
from flask import Flask
Expand Down Expand Up @@ -40,6 +41,7 @@
from flask_jwt_extended.exceptions import WrongTokenError
from flask_jwt_extended.tokens import _decode_jwt
from flask_jwt_extended.tokens import _encode_jwt
from flask_jwt_extended.typing import ExpiresDelta


class JWTManager(object):
Expand Down Expand Up @@ -75,7 +77,7 @@ def __init__(self, app: Flask = None) -> None:
self._unauthorized_callback = default_unauthorized_callback
self._user_claims_callback = default_additional_claims_callback
self._user_identity_callback = default_user_identity_callback
self._user_lookup_callback = None
self._user_lookup_callback: Optional[Callable] = None
self._user_lookup_error_callback = default_user_lookup_error_callback
self._token_verification_failed_callback = (
default_token_verification_failed_callback
Expand Down Expand Up @@ -478,7 +480,7 @@ def _encode_jwt_from_config(
token_type: str,
claims=None,
fresh: bool = False,
expires_delta: datetime.timedelta = None,
expires_delta: ExpiresDelta = None,
headers=None,
) -> str:
header_overrides = self._jwt_additional_header_callback(identity)
Expand Down Expand Up @@ -538,6 +540,9 @@ def _decode_jwt_from_config(
try:
return _decode_jwt(**kwargs, allow_expired=allow_expired)
except ExpiredSignatureError as e:
e.jwt_header = unverified_headers
e.jwt_data = _decode_jwt(**kwargs, allow_expired=True)
# TODO: If we ever do another breaking change, don't raise this pyjwt
# error directly, instead raise a custom error of ours from this
# error.
e.jwt_header = unverified_headers # type: ignore
e.jwt_data = _decode_jwt(**kwargs, allow_expired=True) # type: ignore
raise
11 changes: 7 additions & 4 deletions flask_jwt_extended/tokens.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,27 +5,30 @@
from hmac import compare_digest
from typing import Any
from typing import Iterable
from typing import List
from typing import Type
from typing import Union

import jwt
from flask.json import JSONEncoder

from flask_jwt_extended.exceptions import CSRFError
from flask_jwt_extended.exceptions import JWTDecodeError
from flask_jwt_extended.typing import ExpiresDelta


def _encode_jwt(
algorithm: str,
audience: Union[str, Iterable[str]],
claim_overrides: dict,
csrf: bool,
expires_delta: timedelta,
expires_delta: ExpiresDelta,
fresh: bool,
header_overrides: dict,
identity: Any,
identity_claim_key: str,
issuer: str,
json_encoder: JSONEncoder,
json_encoder: Type[JSONEncoder],
secret: str,
token_type: str,
nbf: bool,
Expand Down Expand Up @@ -65,13 +68,13 @@ def _encode_jwt(
token_data,
secret,
algorithm,
json_encoder=json_encoder,
json_encoder=json_encoder, # type: ignore
headers=header_overrides,
)


def _decode_jwt(
algorithms: Iterable,
algorithms: List,
allow_expired: bool,
audience: Union[str, Iterable[str]],
csrf_value: str,
Expand Down
10 changes: 10 additions & 0 deletions flask_jwt_extended/typing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
import sys
from typing import Any
from typing import Union

if sys.version_info >= (3, 8):
from typing import Literal # pragma: no cover
else:
from typing_extensions import Literal # pragma: no cover

ExpiresDelta = Union[Literal[False], Any]
10 changes: 5 additions & 5 deletions flask_jwt_extended/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import datetime
from typing import Any
from typing import Optional

import jwt
from flask import _request_ctx_stack
Expand Down Expand Up @@ -61,7 +62,7 @@ def get_jwt_identity() -> Any:
return get_jwt().get(config.identity_claim_key, None)


def get_jwt_request_location() -> str:
def get_jwt_request_location() -> Optional[str]:
"""
In a protected endpoint, this will return the "location" at which the JWT
that is accessing the endpoint was found--e.g., "cookies", "query-string",
Expand All @@ -72,8 +73,7 @@ def get_jwt_request_location() -> str:
The location of the JWT in the current request; e.g., "cookies",
"query-string", "headers", or "json"
"""
location = getattr(_request_ctx_stack.top, "jwt_location", None)
return location
return getattr(_request_ctx_stack.top, "jwt_location", None)


def get_current_user() -> Any:
Expand Down Expand Up @@ -240,15 +240,15 @@ def get_unverified_jwt_headers(encoded_token: str) -> dict:
return jwt.get_unverified_header(encoded_token)


def get_jti(encoded_token: str) -> str:
def get_jti(encoded_token: str) -> Optional[str]:
"""
Returns the JTI (unique identifier) of an encoded JWT
:param encoded_token:
The encoded JWT to get the JTI from.
:return:
The JTI (unique identifier) of a JWT.
The JTI (unique identifier) of a JWT, if it is present.
"""
return decode_token(encoded_token).get("jti")

Expand Down
Loading

0 comments on commit 848b4a9

Please sign in to comment.