Skip to content

Commit

Permalink
feat: add TwitterSSO to support Twitter (X) login (#139)
Browse files Browse the repository at this point in the history
* feat: add TwitterSSO to support Twitter (X) login
  • Loading branch information
tomasvotava authored Mar 17, 2024
1 parent 2857ab7 commit fd23647
Show file tree
Hide file tree
Showing 21 changed files with 347 additions and 51 deletions.
2 changes: 1 addition & 1 deletion docs/generate_reference.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from typing import TYPE_CHECKING

if TYPE_CHECKING:
import mkdocs.config.defaults
import mkdocs.config.defaults # pragma: no cover


SKIPPED_MODULES = ("fastapi_sso.sso", "fastapi_sso")
Expand Down
38 changes: 38 additions & 0 deletions examples/twitter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
"""Twitter (X) Login Example
"""

import os
import uvicorn
from fastapi import FastAPI, Request
from fastapi_sso.sso.twitter import TwitterSSO

CLIENT_ID = os.environ["CLIENT_ID"]
CLIENT_SECRET = os.environ["CLIENT_SECRET"]

app = FastAPI()

sso = TwitterSSO(
client_id=CLIENT_ID,
client_secret=CLIENT_SECRET,
redirect_uri="http://127.0.0.1:5000/auth/callback",
allow_insecure_http=True,
)


@app.get("/auth/login")
async def auth_init():
"""Initialize auth and redirect"""
with sso:
return await sso.get_login_redirect()


@app.get("/auth/callback")
async def auth_callback(request: Request):
"""Verify login"""
with sso:
user = await sso.verify_and_process(request)
return user


if __name__ == "__main__":
uvicorn.run(app="examples.twitter:app", host="127.0.0.1", port=5000)
1 change: 1 addition & 0 deletions fastapi_sso/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,4 @@
from .sso.naver import NaverSSO
from .sso.notion import NotionSSO
from .sso.spotify import SpotifySSO
from .sso.twitter import TwitterSSO
25 changes: 25 additions & 0 deletions fastapi_sso/pkce.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
"""PKCE-related helper functions"""

import base64
import hashlib
import os
from typing import Tuple


def get_code_verifier(length: int = 96) -> str:
"""Get code verifier for PKCE challenge"""
length = max(43, min(length, 128))
bytes_length = int(length * 3 / 4)
return base64.urlsafe_b64encode(os.urandom(bytes_length)).decode("utf-8").replace("=", "")[:length]


def get_pkce_challenge_pair(verifier_length: int = 96) -> Tuple[str, str]:
"""Get tuple of (verifier, challenge) for PKCE challenge."""
code_verifier = get_code_verifier(verifier_length)
code_challenge = (
base64.urlsafe_b64encode(hashlib.sha256(code_verifier.encode("utf-8")).digest())
.decode("utf-8")
.replace("=", "")
)

return (code_verifier, code_challenge)
62 changes: 59 additions & 3 deletions fastapi_sso/sso/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@
from starlette.requests import Request
from starlette.responses import RedirectResponse

from fastapi_sso.pkce import get_pkce_challenge_pair
from fastapi_sso.state import generate_random_state

if sys.version_info >= (3, 8):
from typing import TypedDict
else:
Expand Down Expand Up @@ -63,6 +66,10 @@ class SSOBase:
redirect_uri: Optional[Union[pydantic.AnyHttpUrl, str]] = NotImplemented
scope: List[str] = NotImplemented
additional_headers: Optional[Dict[str, Any]] = None
uses_pkce: bool = False
requires_state: bool = False

_pkce_challenge_length: int = 96

def __init__(
self,
Expand All @@ -79,6 +86,7 @@ def __init__(
self.redirect_uri: Optional[Union[pydantic.AnyHttpUrl, str]] = redirect_uri
self.allow_insecure_http: bool = allow_insecure_http
self._oauth_client: Optional[WebApplicationClient] = None
self._generated_state: Optional[str] = None

if self.allow_insecure_http:
os.environ["OAUTHLIB_INSECURE_TRANSPORT"] = "1"
Expand All @@ -96,6 +104,9 @@ def __init__(
self._refresh_token: Optional[str] = None
self._id_token: Optional[str] = None
self._state: Optional[str] = None
self._pkce_code_challenge: Optional[str] = None
self._pkce_code_verifier: Optional[str] = None
self._pkce_challenge_method = "S256"

@property
def state(self) -> Optional[str]:
Expand Down Expand Up @@ -236,8 +247,26 @@ async def get_login_url(
redirect_uri = redirect_uri or self.redirect_uri
if redirect_uri is None:
raise ValueError("redirect_uri must be provided, either at construction or request time")
if self.uses_pkce and not all((self._pkce_code_verifier, self._pkce_code_challenge)):
warnings.warn(
f"{self.__class__.__name__!r} uses PKCE and no code was generated yet. "
"Use SSO class as a context manager to get rid of this warning and possible errors."
)
if self.requires_state and not state:
if self._generated_state is None:
warnings.warn(
f"{self.__class__.__name__!r} requires state in the request but none was provided nor "
"generated automatically. Use SSO as a context manager. The login process will most probably fail."
)
state = self._generated_state
request_uri = self.oauth_client.prepare_request_uri(
await self.authorization_endpoint, redirect_uri=redirect_uri, state=state, scope=self.scope, **params
await self.authorization_endpoint,
redirect_uri=redirect_uri,
state=state,
scope=self.scope,
code_challenge=self._pkce_code_challenge,
code_challenge_method=self._pkce_challenge_method,
**params,
)
return request_uri

Expand All @@ -259,8 +288,12 @@ async def get_login_redirect(
Returns:
RedirectResponse: A Starlette response directing to the login page of the OAuth SSO provider.
"""
if self.requires_state and not state:
state = self._generated_state
login_uri = await self.get_login_url(redirect_uri=redirect_uri, params=params, state=state)
response = RedirectResponse(login_uri, 303)
if self.uses_pkce:
response.set_cookie("pkce_code_verifier", str(self._pkce_code_verifier))
return response

async def verify_and_process(
Expand Down Expand Up @@ -291,14 +324,31 @@ async def verify_and_process(
if code is None:
raise SSOLoginError(400, "'code' parameter was not found in callback request")
self._state = request.query_params.get("state")
pkce_code_verifier: Optional[str] = None
if self.uses_pkce:
pkce_code_verifier = request.cookies.get("pkce_code_verifier")
if pkce_code_verifier is None:
warnings.warn(
"PKCE code verifier was not found in the request Cookie. This will probably lead to a login error."
)
return await self.process_login(
code, request, params=params, additional_headers=headers, redirect_uri=redirect_uri
code,
request,
params=params,
additional_headers=headers,
redirect_uri=redirect_uri,
pkce_code_verifier=pkce_code_verifier,
)

def __enter__(self) -> "SSOBase":
self._oauth_client = None
self._refresh_token = None
self._id_token = None
self._state = None
if self.requires_state:
self._generated_state = generate_random_state()
if self.uses_pkce:
self._pkce_code_verifier, self._pkce_code_challenge = get_pkce_challenge_pair(self._pkce_challenge_length)
return self

def __exit__(
Expand All @@ -321,6 +371,7 @@ async def process_login(
params: Optional[Dict[str, Any]] = None,
additional_headers: Optional[Dict[str, Any]] = None,
redirect_uri: Optional[str] = None,
pkce_code_verifier: Optional[str] = None,
) -> Optional[OpenID]:
"""
Processes login from the callback endpoint to verify the user and request user info endpoint.
Expand All @@ -332,6 +383,7 @@ async def process_login(
params (Optional[Dict[str, Any]]): Additional query parameters to pass to the provider.
additional_headers (Optional[Dict[str, Any]]): Additional headers to be added to all requests.
redirect_uri (Optional[str]): Overrides the `redirect_uri` specified on this instance.
pkce_code_verifier (Optional[str]): A PKCE code verifier sent to the server to verify the login request.
Raises:
ReusedOauthClientWarning: If the SSO object is reused, which is not safe and caused security issues.
Expand Down Expand Up @@ -379,8 +431,12 @@ async def process_login(
headers.update(additional_headers)

auth = httpx.BasicAuth(self.client_id, self.client_secret)

if pkce_code_verifier:
params.update({"code_verifier": pkce_code_verifier})

async with httpx.AsyncClient() as session:
response = await session.post(token_url, headers=headers, content=body, auth=auth)
response = await session.post(token_url, headers=headers, content=body, auth=auth, params=params)
content = response.json()
self._refresh_token = content.get("refresh_token")
self._id_token = content.get("id_token")
Expand Down
2 changes: 1 addition & 1 deletion fastapi_sso/sso/facebook.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from fastapi_sso.sso.base import DiscoveryDocument, OpenID, SSOBase

if TYPE_CHECKING:
import httpx
import httpx # pragma: no cover


class FacebookSSO(SSOBase):
Expand Down
2 changes: 1 addition & 1 deletion fastapi_sso/sso/fitbit.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from fastapi_sso.sso.base import DiscoveryDocument, OpenID, SSOBase, SSOLoginError

if TYPE_CHECKING:
import httpx
import httpx # pragma: no cover


class FitbitSSO(SSOBase):
Expand Down
2 changes: 1 addition & 1 deletion fastapi_sso/sso/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from fastapi_sso.sso.base import DiscoveryDocument, OpenID, SSOBase

if TYPE_CHECKING:
import httpx
import httpx # pragma: no cover

logger = logging.getLogger(__name__)

Expand Down
2 changes: 1 addition & 1 deletion fastapi_sso/sso/github.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from fastapi_sso.sso.base import DiscoveryDocument, OpenID, SSOBase

if TYPE_CHECKING:
import httpx
import httpx # pragma: no cover


class GithubSSO(SSOBase):
Expand Down
2 changes: 1 addition & 1 deletion fastapi_sso/sso/gitlab.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from fastapi_sso.sso.base import DiscoveryDocument, OpenID, SSOBase

if TYPE_CHECKING:
import httpx
import httpx # pragma: no cover


class GitlabSSO(SSOBase):
Expand Down
2 changes: 1 addition & 1 deletion fastapi_sso/sso/kakao.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from fastapi_sso.sso.base import DiscoveryDocument, OpenID, SSOBase

if TYPE_CHECKING:
import httpx
import httpx # pragma: no cover


class KakaoSSO(SSOBase):
Expand Down
2 changes: 1 addition & 1 deletion fastapi_sso/sso/line.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from fastapi_sso.sso.base import DiscoveryDocument, OpenID, SSOBase

if TYPE_CHECKING:
import httpx
import httpx # pragma: no cover


class LineSSO(SSOBase):
Expand Down
2 changes: 1 addition & 1 deletion fastapi_sso/sso/linkedin.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from fastapi_sso.sso.base import DiscoveryDocument, OpenID, SSOBase

if TYPE_CHECKING:
import httpx
import httpx # pragma: no cover


class LinkedInSSO(SSOBase):
Expand Down
2 changes: 1 addition & 1 deletion fastapi_sso/sso/microsoft.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from fastapi_sso.sso.base import DiscoveryDocument, OpenID, SSOBase

if TYPE_CHECKING:
import httpx
import httpx # pragma: no cover


class MicrosoftSSO(SSOBase):
Expand Down
2 changes: 1 addition & 1 deletion fastapi_sso/sso/naver.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from fastapi_sso.sso.base import DiscoveryDocument, OpenID, SSOBase

if TYPE_CHECKING:
import httpx
import httpx # pragma: no cover


class NaverSSO(SSOBase):
Expand Down
2 changes: 1 addition & 1 deletion fastapi_sso/sso/spotify.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from fastapi_sso.sso.base import DiscoveryDocument, OpenID, SSOBase

if TYPE_CHECKING:
import httpx
import httpx # pragma: no cover


class SpotifySSO(SSOBase):
Expand Down
35 changes: 35 additions & 0 deletions fastapi_sso/sso/twitter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
"""Twitter (X) SSO Oauth Helper class"""

from typing import TYPE_CHECKING, Optional

from fastapi_sso.sso.base import DiscoveryDocument, OpenID, SSOBase

if TYPE_CHECKING:
import httpx # pragma: no cover


class TwitterSSO(SSOBase):
"""Class providing login via Twitter SSO"""

provider = "twitter"
scope = ["users.read", "tweet.read"]
uses_pkce = True
requires_state = True

async def get_discovery_document(self) -> DiscoveryDocument:
return {
"authorization_endpoint": "https://twitter.com/i/oauth2/authorize",
"token_endpoint": "https://api.twitter.com/2/oauth2/token",
"userinfo_endpoint": "https://api.twitter.com/2/users/me",
}

async def openid_from_response(self, response: dict, session: Optional["httpx.AsyncClient"] = None) -> OpenID:
first_name, *last_name_parts = response["data"].get("name", "").split(" ")
last_name = " ".join(last_name_parts) if last_name_parts else None
return OpenID(
id=str(response["data"]["id"]),
display_name=response["data"]["username"],
first_name=first_name,
last_name=last_name,
provider=self.provider,
)
10 changes: 10 additions & 0 deletions fastapi_sso/state.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
"""Helper functions to generate state param"""

import base64
import os


def generate_random_state(length: int = 64) -> str:
"""Generate a url-safe string to use as a state"""
bytes_length = int(length * 3 / 4)
return base64.urlsafe_b64encode(os.urandom(bytes_length)).decode("utf-8")
Loading

0 comments on commit fd23647

Please sign in to comment.