Skip to content

Commit

Permalink
Add Account model + linting
Browse files Browse the repository at this point in the history
  • Loading branch information
thyb-zytek committed May 26, 2024
1 parent aafc200 commit 2d894a3
Show file tree
Hide file tree
Showing 40 changed files with 805 additions and 392 deletions.
5 changes: 0 additions & 5 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,6 @@ repos:
- id: check-toml
- id: end-of-file-fixer
- id: trailing-whitespace
- repo: https://github.com/psf/black
rev: 24.4.2
hooks:
- id: black
args: [ --config=./app/pyproject.toml ]
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.9.0
hooks:
Expand Down
3 changes: 1 addition & 2 deletions app/alembic/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,11 @@
import os
from logging.config import fileConfig

from alembic import context # noqa
from sqlalchemy import pool
from sqlalchemy.engine import Connection
from sqlalchemy.ext.asyncio import async_engine_from_config

from alembic import context # noqa

# this is the Alembic Config object, which provides
# access to the values within the .ini file in use.
config = context.config
Expand Down
1 change: 0 additions & 1 deletion app/alembic/versions/2ee47aeec44a_create_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@

import sqlalchemy as sa
import sqlmodel.sql.sqltypes

from alembic import op

# revision identifiers, used by Alembic.
Expand Down
8 changes: 4 additions & 4 deletions app/api/authentication.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from core.authentication import (
FirebaseToken,
GoogleAuthorizationUrl,
GoogleToken,
RefreshToken,
RefreshTokenPayload,
UserSignIn,
)
Expand All @@ -28,7 +28,7 @@ async def login_with_email(payload: UserSignIn) -> Any:
response.raise_for_status()
return response.json()
except httpx.HTTPStatusError:
raise FirebaseAuthError(detail="Invalid credentials")
raise FirebaseAuthError(message="Invalid credentials")


@router.get("/google")
Expand Down Expand Up @@ -58,7 +58,7 @@ async def auth_google(code: str, flow: GoogleOAuthFlowDep, request: Request) ->
raise FirebaseException()


@router.post("/refresh", response_model=GoogleToken, response_model_by_alias=False)
@router.post("/refresh", response_model=RefreshToken, response_model_by_alias=False)
async def refresh_token(payload: RefreshTokenPayload) -> Any:
try:
response = httpx.post(
Expand All @@ -68,4 +68,4 @@ async def refresh_token(payload: RefreshTokenPayload) -> Any:
response.raise_for_status()
return response.json()
except httpx.HTTPStatusError:
raise FirebaseAuthError(detail="Invalid refresh token")
raise FirebaseAuthError(message="Invalid refresh token")
4 changes: 2 additions & 2 deletions app/api/main.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import logging

from fastapi import APIRouter
from sqlalchemy import text
from sqlmodel import select

from core.dependencies import SessionDep

Expand All @@ -14,7 +14,7 @@
async def healthcheck(session: SessionDep) -> str:
"""Check if server is up and DB is reachable."""
try:
await session.exec(text("SELECT 1"))
await session.exec(select(1))
except Exception as e:
logger.error(e)
return "KO"
Expand Down
63 changes: 63 additions & 0 deletions app/api/v1/accounts.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
import logging
from collections.abc import Sequence

from fastapi import APIRouter
from starlette.status import HTTP_201_CREATED, HTTP_204_NO_CONTENT

from core.dependencies import FirebaseUserDep, SessionDep
from core.exceptions import NotFoundException
from crud.account import create_account, find_accounts
from crud.user import get_or_create_user
from models import Account, AccountCreate, AccountSerializer, AccountUpdate

router = APIRouter()
logger = logging.getLogger("budgly")


@router.get("/", response_model=list[AccountSerializer])
async def list_accounts(
user: FirebaseUserDep, session: SessionDep
) -> Sequence[Account]:
results = await find_accounts(session, user.uid)
return results.all()


@router.post("/", response_model=AccountSerializer, status_code=HTTP_201_CREATED)
async def create(
payload: AccountCreate, user: FirebaseUserDep, session: SessionDep
) -> Account:
user = await get_or_create_user(session, user)
account = await create_account(session, payload, user)
return account


@router.patch("/{account_id}", response_model=AccountSerializer)
async def update(
account_id: int, payload: AccountUpdate, user: FirebaseUserDep, session: SessionDep
) -> Account:
results = await find_accounts(session, user.uid, account_id)
account = results.one_or_none()
if account is None:
raise NotFoundException(f"Account(id: {account_id}) not found")

for key, value in payload.model_dump(exclude_unset=True).items():
setattr(account, key, value)

session.add(account)
await session.commit()
await session.refresh(account)

return account


@router.delete("/{account_id}", status_code=HTTP_204_NO_CONTENT)
async def delete_accounts(
account_id: int, user: FirebaseUserDep, session: SessionDep
) -> None:
results = await find_accounts(session, user.uid, account_id)
account = results.one_or_none()
if account is None:
raise NotFoundException(f"Account(id: {account_id}) not found")

await session.delete(account)
await session.commit()
2 changes: 1 addition & 1 deletion app/api/v1/users.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,4 @@

@router.get("/me", response_model=User)
async def get_user(firebase_user: FirebaseUserDep, session: SessionDep) -> User:
return await get_or_create_user(firebase_user, session)
return await get_or_create_user(session, firebase_user)
8 changes: 4 additions & 4 deletions app/core/authentication.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
import firebase_admin # type: ignore
from fastapi import Depends, Request
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from firebase_admin import auth # type: ignore
from google_auth_oauthlib.flow import Flow # type: ignore # type: ignore
from firebase_admin import auth
from google_auth_oauthlib.flow import Flow # type: ignore
from pydantic import BaseModel, EmailStr, Field, HttpUrl
from starlette.datastructures import URL

Expand All @@ -30,7 +30,7 @@ class FirebaseToken(BaseModel):
name: str = Field(..., alias="displayName")


class GoogleToken(BaseModel):
class RefreshToken(BaseModel):
id_token: str
refresh_token: str

Expand Down Expand Up @@ -83,4 +83,4 @@ def get_firebase_user(token: TokenDep) -> User:
raise InvalidToken()
except Exception as e:
logger.error(e)
raise ServerError(detail=str(e))
raise ServerError(message=str(e))
9 changes: 4 additions & 5 deletions app/core/db.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from collections.abc import Generator
from collections.abc import AsyncGenerator

from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
from sqlalchemy.ext.asyncio import create_async_engine
from sqlmodel import Field, SQLModel
from sqlmodel.ext.asyncio.session import AsyncSession

Expand All @@ -9,9 +9,8 @@
engine = create_async_engine(str(settings.SQLALCHEMY_DATABASE_URI), future=True)


async def get_session() -> Generator[AsyncSession, None, None]:
async_session = async_sessionmaker(engine, expire_on_commit=False)
async with async_session() as session:
async def get_session() -> AsyncGenerator[AsyncSession, None]:
async with AsyncSession(engine) as session:
yield session


Expand Down
4 changes: 2 additions & 2 deletions app/core/exceptions/authentication.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@


class FirebaseAuthError(HTTPException):
def __init__(self, detail: str) -> None:
def __init__(self, message: str) -> None:
super().__init__(
status_code=401, detail={"code": "INVALID_CREDENTIALS", "message": detail}
status_code=401, detail={"code": "INVALID_CREDENTIALS", "message": message}
)


Expand Down
4 changes: 2 additions & 2 deletions app/core/exceptions/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@


class ServerError(HTTPException):
def __init__(self, detail: str) -> None:
def __init__(self, message: str) -> None:
super().__init__(
status_code=500, detail={"code": "SERVER_ERROR", "message": detail}
status_code=500, detail={"code": "SERVER_ERROR", "message": message}
)
8 changes: 4 additions & 4 deletions app/core/exceptions/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,14 @@


class NotFoundException(HTTPException):
def __init__(self, detail: str) -> None:
def __init__(self, message: str) -> None:
super().__init__(
status_code=401, detail={"code": "NOT_FOUND", "message": detail}
status_code=404, detail={"code": "NOT_FOUND", "message": message}
)


class ValidationException(HTTPException):
def __init__(self, detail: str) -> None:
def __init__(self, message: str) -> None:
super().__init__(
status_code=400, detail={"code": "VALIDATION_ERROR", "message": detail}
status_code=400, detail={"code": "VALIDATION_ERROR", "message": message}
)
3 changes: 2 additions & 1 deletion app/core/logging.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import logging

from rich.logging import Console, RichHandler
from rich.console import Console
from rich.logging import RichHandler

from core.config import settings

Expand Down
42 changes: 42 additions & 0 deletions app/crud/account.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import logging

from sqlalchemy import ScalarResult
from sqlalchemy.exc import IntegrityError
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession

from core.exceptions import ValidationException
from models import Account, AccountCreate, User

logger = logging.getLogger("budgly")


async def find_accounts(
session: AsyncSession, user_uid: str, account_id: int | None = None
) -> ScalarResult[Account]:
query = select(Account).where(Account.users.any(User.uid == user_uid)) # type: ignore
if account_id is not None:
query = query.where(Account.id == account_id)
result = await session.exec(query)
return result


async def create_account(
session: AsyncSession, payload: AccountCreate, user: User
) -> Account:
try:
account = Account(**payload.model_dump(mode="json"), creator=user)
account.users.append(user)
session.add(account)
await session.commit()
await session.refresh(account)
return account
except IntegrityError as e:
logger.error(e)
await session.rollback()
if "UniqueViolationError" in str(e):
raise ValidationException(f"Account name ({payload.name}) already exists")
raise ValidationException("Failed to create account")
except Exception as e:
logger.error(e)
raise ValidationException("Failed to validate account model")
4 changes: 2 additions & 2 deletions app/crud/user.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from sqlmodel import Session
from sqlmodel.ext.asyncio.session import AsyncSession

from models import User


async def get_or_create_user(user_in: User, session: Session) -> User:
async def get_or_create_user(session: AsyncSession, user_in: User) -> User:
user = await session.get(User, user_in.uid)
if user is None:
user = user_in
Expand Down
11 changes: 3 additions & 8 deletions app/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,8 @@
from sqlmodel import SQLModel # noqa

from models.account import (
Account,
AccountCreate,
AccountUpdate,
AccountSerializer,
) # noqa
from models.account import AccountSerializer # noqa
from models.account import Account, AccountCreate, AccountUpdate
from models.extra import UserAccountLink # noqa
from models.user import User # noqa
from sqlmodel import SQLModel # noqa

__all__ = [
"SQLModel",
Expand Down
73 changes: 73 additions & 0 deletions app/models/account.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
import logging
import random

from pydantic import BaseModel, HttpUrl, field_serializer
from sqlmodel import Field, Relationship, String, UniqueConstraint

from core.db import BaseTable
from models.extra import UserAccountLink
from models.user import User

logger = logging.getLogger("budgly")


def random_color_hex_code() -> str:
def hex() -> int:
return random.randint(0, 255)

return f"#{hex():02X}{hex():02X}{hex():02X}"


class AccountCreate(BaseModel):
name: str
image: HttpUrl | None = None
color: str | None = None

@field_serializer("color")
def serialize_color(self, value: str | None) -> str | None:
if value is None and self.image is None:
return random_color_hex_code()

return value


class AccountUpdate(BaseModel):
name: str | None = None
image: HttpUrl | None = None
color: str | None = None

@field_serializer("color")
def serialize_color(self, value: str | None) -> str | None:
if value is None and "image" in self.model_fields_set and self.image is None:
return random_color_hex_code()

return value


class Account(BaseTable, table=True):
name: str
image: HttpUrl | None = Field(sa_type=String)
color: str | None
creator_id: str = Field(foreign_key="user.uid")

creator: User = Relationship(back_populates="created_accounts")
users: list[User] = Relationship(
back_populates="accounts",
link_model=UserAccountLink,
sa_relationship_kwargs={"lazy": "selectin"},
)

__table_args__ = (
UniqueConstraint(
"creator_id",
"name",
name="unique_name_by_creator",
),
)


class AccountSerializer(BaseTable):
name: str
image: HttpUrl | None = None
color: str | None = None
creator: User
Loading

0 comments on commit 2d894a3

Please sign in to comment.