Skip to content

Commit

Permalink
pydantic v1 & v2 compatibility
Browse files Browse the repository at this point in the history
  • Loading branch information
olirice committed Aug 16, 2023
1 parent c2ed950 commit 45456c8
Show file tree
Hide file tree
Showing 11 changed files with 132 additions and 59 deletions.
5 changes: 3 additions & 2 deletions gotrue/_async/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from ..constants import COOKIE_OPTIONS, DEFAULT_HEADERS, GOTRUE_URL, STORAGE_KEY
from ..exceptions import APIError
from ..helpers import model_dump, model_validate
from ..types import (
AuthChangeEvent,
CookieOptions,
Expand Down Expand Up @@ -560,7 +561,7 @@ async def _recover_common(self) -> Optional[Tuple[Session, int, int]]:
and session_raw
and isinstance(session_raw, dict)
):
session = Session.model_validate(session_raw)
session = model_validate(Session, session_raw)
expires_at = int(expires_at_raw)
time_now = round(time())
return session, expires_at, time_now
Expand Down Expand Up @@ -628,7 +629,7 @@ async def _save_session(self, *, session: Session) -> None:
await self._persist_session(session=session)

async def _persist_session(self, *, session: Session) -> None:
data = {"session": session.model_dump(), "expires_at": session.expires_at}
data = {"session": model_dump(session), "expires_at": session.expires_at}
await self.local_storage.set_item(STORAGE_KEY, dumps(data, default=str))

async def _remove_session(self) -> None:
Expand Down
9 changes: 5 additions & 4 deletions gotrue/_async/gotrue_admin_api.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from __future__ import annotations

from functools import partial
from typing import Dict, List, Union

from ..helpers import parse_link_response, parse_user_response
from ..helpers import model_validate, parse_link_response, parse_user_response
from ..http_clients import AsyncClient
from ..types import (
AdminUserAttributes,
Expand Down Expand Up @@ -109,7 +110,7 @@ async def list_users(self) -> List[User]:
return await self._request(
"GET",
"admin/users",
xform=lambda data: [User.model_validate(user) for user in data["users"]]
xform=lambda data: [model_validate(User, user) for user in data["users"]]
if "users" in data
else [],
)
Expand Down Expand Up @@ -161,7 +162,7 @@ async def _list_factors(
return await self._request(
"GET",
f"admin/users/{params.get('user_id')}/factors",
xform=AuthMFAAdminListFactorsResponse.model_validate,
xform=partial(model_validate, AuthMFAAdminListFactorsResponse),
)

async def _delete_factor(
Expand All @@ -171,5 +172,5 @@ async def _delete_factor(
return await self._request(
"DELETE",
f"admin/users/{params.get('user_id')}/factors/{params.get('factor_id')}",
xform=AuthMFAAdminDeleteFactorResponse.model_validate,
xform=partial(model_validate, AuthMFAAdminDeleteFactorResponse),
)
4 changes: 2 additions & 2 deletions gotrue/_async/gotrue_base_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from pydantic import BaseModel
from typing_extensions import Literal, Self

from ..helpers import handle_exception
from ..helpers import handle_exception, model_dump
from ..http_clients import AsyncClient

T = TypeVar("T")
Expand Down Expand Up @@ -108,7 +108,7 @@ async def _request(
url,
headers=headers,
params=query,
json=body.model_dump() if isinstance(body, BaseModel) else body,
json=model_dump(body) if isinstance(body, BaseModel) else body,
)
response.raise_for_status()
result = response if no_resolve_json else response.json()
Expand Down
24 changes: 16 additions & 8 deletions gotrue/_async/gotrue_client.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

from functools import partial
from json import loads
from time import time
from typing import Callable, Dict, List, Tuple, Union
Expand All @@ -20,7 +21,14 @@
AuthRetryableError,
AuthSessionMissingError,
)
from ..helpers import decode_jwt_payload, parse_auth_response, parse_user_response
from ..helpers import (
decode_jwt_payload,
model_dump,
model_dump_json,
model_validate,
parse_auth_response,
parse_user_response,
)
from ..http_clients import AsyncClient
from ..timer import Timer
from ..types import (
Expand Down Expand Up @@ -531,7 +539,7 @@ async def _enroll(self, params: MFAEnrollParams) -> AuthMFAEnrollResponse:
"factors",
body=params,
jwt=session.access_token,
xform=AuthMFAEnrollResponse.model_validate,
xform=partial(model_validate, AuthMFAEnrollResponse),
)
if response.totp.qr_code:
response.totp.qr_code = f"data:image/svg+xml;utf-8,{response.totp.qr_code}"
Expand All @@ -545,7 +553,7 @@ async def _challenge(self, params: MFAChallengeParams) -> AuthMFAChallengeRespon
"POST",
f"factors/{params.get('factor_id')}/challenge",
jwt=session.access_token,
xform=AuthMFAChallengeResponse.model_validate,
xform=partial(model_validate, AuthMFAChallengeResponse),
)

async def _challenge_and_verify(
Expand Down Expand Up @@ -574,9 +582,9 @@ async def _verify(self, params: MFAVerifyParams) -> AuthMFAVerifyResponse:
f"factors/{params.get('factor_id')}/verify",
body=params,
jwt=session.access_token,
xform=AuthMFAVerifyResponse.model_validate,
xform=partial(model_validate, AuthMFAVerifyResponse),
)
session = Session.model_validate(response.model_dump())
session = model_validate(Session, model_dump(response))
await self._save_session(session)
self._notify_all_subscribers("MFA_CHALLENGE_VERIFIED", session)
return response
Expand All @@ -589,7 +597,7 @@ async def _unenroll(self, params: MFAUnenrollParams) -> AuthMFAUnenrollResponse:
"DELETE",
f"factors/{params.get('factor_id')}",
jwt=session.access_token,
xform=AuthMFAUnenrollResponse.model_validate,
xform=partial(AuthMFAUnenrollResponse, model_validate),
)

async def _list_factors(self) -> AuthMFAListFactorsResponse:
Expand Down Expand Up @@ -751,7 +759,7 @@ async def _save_session(self, session: Session) -> None:
value = (expire_in - refresh_duration_before_expires) * 1000
await self._start_auto_refresh_token(value)
if self._persist_session and session.expires_at:
await self._storage.set_item(self._storage_key, session.model_dump_json())
await self._storage.set_item(self._storage_key, model_dump_json(session))

async def _start_auto_refresh_token(self, value: float) -> None:
if self._refresh_token_timer:
Expand Down Expand Up @@ -808,7 +816,7 @@ def _get_valid_session(
except ValueError:
return None
try:
return Session.model_validate(data)
return model_validate(Session, data)
except Exception:
return None

Expand Down
5 changes: 3 additions & 2 deletions gotrue/_sync/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from ..constants import COOKIE_OPTIONS, DEFAULT_HEADERS, GOTRUE_URL, STORAGE_KEY
from ..exceptions import APIError
from ..helpers import model_dump, model_validate
from ..types import (
AuthChangeEvent,
CookieOptions,
Expand Down Expand Up @@ -556,7 +557,7 @@ def _recover_common(self) -> Optional[Tuple[Session, int, int]]:
and session_raw
and isinstance(session_raw, dict)
):
session = Session.model_validate(session_raw)
session = model_validate(Session, session_raw)
expires_at = int(expires_at_raw)
time_now = round(time())
return session, expires_at, time_now
Expand Down Expand Up @@ -620,7 +621,7 @@ def _save_session(self, *, session: Session) -> None:
self._persist_session(session=session)

def _persist_session(self, *, session: Session) -> None:
data = {"session": session.model_dump(), "expires_at": session.expires_at}
data = {"session": model_dump(session), "expires_at": session.expires_at}
self.local_storage.set_item(STORAGE_KEY, dumps(data, default=str))

def _remove_session(self) -> None:
Expand Down
9 changes: 5 additions & 4 deletions gotrue/_sync/gotrue_admin_api.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from __future__ import annotations

from functools import partial
from typing import Dict, List, Union

from ..helpers import parse_link_response, parse_user_response
from ..helpers import model_validate, parse_link_response, parse_user_response
from ..http_clients import SyncClient
from ..types import (
AdminUserAttributes,
Expand Down Expand Up @@ -109,7 +110,7 @@ def list_users(self) -> List[User]:
return self._request(
"GET",
"admin/users",
xform=lambda data: [User.model_validate(user) for user in data["users"]]
xform=lambda data: [model_validate(User, user) for user in data["users"]]
if "users" in data
else [],
)
Expand Down Expand Up @@ -161,7 +162,7 @@ def _list_factors(
return self._request(
"GET",
f"admin/users/{params.get('user_id')}/factors",
xform=AuthMFAAdminListFactorsResponse.model_validate,
xform=partial(model_validate, AuthMFAAdminListFactorsResponse),
)

def _delete_factor(
Expand All @@ -171,5 +172,5 @@ def _delete_factor(
return self._request(
"DELETE",
f"admin/users/{params.get('user_id')}/factors/{params.get('factor_id')}",
xform=AuthMFAAdminDeleteFactorResponse.model_validate,
xform=partial(model_validate, AuthMFAAdminDeleteFactorResponse),
)
4 changes: 2 additions & 2 deletions gotrue/_sync/gotrue_base_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from pydantic import BaseModel
from typing_extensions import Literal, Self

from ..helpers import handle_exception
from ..helpers import handle_exception, model_dump
from ..http_clients import SyncClient

T = TypeVar("T")
Expand Down Expand Up @@ -108,7 +108,7 @@ def _request(
url,
headers=headers,
params=query,
json=body.model_dump() if isinstance(body, BaseModel) else body,
json=model_dump(body) if isinstance(body, BaseModel) else body,
)
response.raise_for_status()
result = response if no_resolve_json else response.json()
Expand Down
21 changes: 14 additions & 7 deletions gotrue/_sync/gotrue_client.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

from functools import partial
from json import loads
from time import time
from typing import Callable, Dict, List, Tuple, Union
Expand All @@ -20,7 +21,13 @@
AuthRetryableError,
AuthSessionMissingError,
)
from ..helpers import decode_jwt_payload, parse_auth_response, parse_user_response
from ..helpers import (
decode_jwt_payload,
model_dump,
model_validate,
parse_auth_response,
parse_user_response,
)
from ..http_clients import SyncClient
from ..timer import Timer
from ..types import (
Expand Down Expand Up @@ -529,7 +536,7 @@ def _enroll(self, params: MFAEnrollParams) -> AuthMFAEnrollResponse:
"factors",
body=params,
jwt=session.access_token,
xform=AuthMFAEnrollResponse.model_validate,
xform=partial(model_validate, AuthMFAEnrollResponse),
)
if response.totp.qr_code:
response.totp.qr_code = f"data:image/svg+xml;utf-8,{response.totp.qr_code}"
Expand All @@ -543,7 +550,7 @@ def _challenge(self, params: MFAChallengeParams) -> AuthMFAChallengeResponse:
"POST",
f"factors/{params.get('factor_id')}/challenge",
jwt=session.access_token,
xform=AuthMFAChallengeResponse.model_validate,
xform=partial(model_validate, AuthMFAChallengeResponse),
)

def _challenge_and_verify(
Expand Down Expand Up @@ -572,9 +579,9 @@ def _verify(self, params: MFAVerifyParams) -> AuthMFAVerifyResponse:
f"factors/{params.get('factor_id')}/verify",
body=params,
jwt=session.access_token,
xform=AuthMFAVerifyResponse.model_validate,
xform=partial(model_validate, AuthMFAVerifyResponse),
)
session = Session.model_validate(response.model_dump())
session = model_validate(Session, model_dump(response))
self._save_session(session)
self._notify_all_subscribers("MFA_CHALLENGE_VERIFIED", session)
return response
Expand All @@ -587,7 +594,7 @@ def _unenroll(self, params: MFAUnenrollParams) -> AuthMFAUnenrollResponse:
"DELETE",
f"factors/{params.get('factor_id')}",
jwt=session.access_token,
xform=AuthMFAUnenrollResponse.model_validate,
xform=partial(model_validate, AuthMFAUnenrollResponse),
)

def _list_factors(self) -> AuthMFAListFactorsResponse:
Expand Down Expand Up @@ -806,7 +813,7 @@ def _get_valid_session(
except ValueError:
return None
try:
return Session.model_validate(data)
return model_validate(Session, data)
except Exception:
return None

Expand Down
46 changes: 40 additions & 6 deletions gotrue/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@

from base64 import b64decode
from json import loads
from typing import Any, Union, cast
from typing import Any, Dict, Type, TypeVar, Union, cast

from httpx import HTTPStatusError
from pydantic import BaseModel

from .errors import AuthApiError, AuthError, AuthRetryableError, AuthUnknownError
from .types import (
Expand All @@ -16,6 +17,39 @@
UserResponse,
)

TBaseModel = TypeVar("TBaseModel", bound=BaseModel)


def model_validate(model: Type[TBaseModel], contents) -> TBaseModel:
"""Compatibility layer between pydantic 1 and 2 for parsing an instance
of a BaseModel from varied"""
try:
# pydantic > 2
return model.model_validate(contents)
except AttributeError:
# pydantic < 2
return model.parse_obj(contents)


def model_dump(model: BaseModel) -> Dict[str, Any]:
"""Compatibility layer between pydantic 1 and 2 for dumping a model's contents as a dict"""
try:
# pydantic > 2
return model.model_dump()
except AttributeError:
# pydantic < 2
return model.dict()


def model_dump_json(model: BaseModel) -> str:
"""Compatibility layer between pydantic 1 and 2 for dumping a model's contents as json"""
try:
# pydantic > 2
return model.model_dump_json()
except AttributeError:
# pydantic < 2
return model.json()


def parse_auth_response(data: Any) -> AuthResponse:
session: Union[Session, None] = None
Expand All @@ -27,9 +61,9 @@ def parse_auth_response(data: Any) -> AuthResponse:
and data["refresh_token"]
and data["expires_in"]
):
session = Session.model_validate(data)
session = model_validate(Session, data)
user_data = data.get("user", data)
user = User.model_validate(user_data) if user_data else None
user = model_validate(User, user_data) if user_data else None
return AuthResponse(session=session, user=user)


Expand All @@ -41,16 +75,16 @@ def parse_link_response(data: Any) -> GenerateLinkResponse:
redirect_to=data.get("redirect_to"),
verification_type=data.get("verification_type"),
)
user = User.model_validate(
{k: v for k, v in data.items() if k not in properties.model_dump()}
user = model_validate(
User, {k: v for k, v in data.items() if k not in model_dump(properties)}
)
return GenerateLinkResponse(properties=properties, user=user)


def parse_user_response(data: Any) -> UserResponse:
if "user" not in data:
data = {"user": data}
return UserResponse.model_validate(data)
return model_validate(UserResponse, data)


def get_error_message(error: Any) -> str:
Expand Down
Loading

0 comments on commit 45456c8

Please sign in to comment.