Skip to content

Commit

Permalink
refactor: add sqlalchemy ulid type to all relevant models
Browse files Browse the repository at this point in the history
  • Loading branch information
Panaetius committed Aug 20, 2024
1 parent e33c2f8 commit 90b6fc8
Show file tree
Hide file tree
Showing 32 changed files with 245 additions and 146 deletions.
7 changes: 4 additions & 3 deletions bases/renku_data_services/background_jobs/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
SubjectFilter,
WriteRelationshipsRequest,
)
from ulid import ULID

from renku_data_services.authz.authz import Authz, ResourceType, _AuthzConverter, _Relation
from renku_data_services.authz.models import Scope
Expand Down Expand Up @@ -117,7 +118,7 @@ async def fix_mismatched_project_namespace_ids(config: SyncConfig) -> None:
relation=rel.relationship.relation,
subject=SubjectReference(
object=ObjectReference(
object_type=ResourceType.group.value, object_id=correct_group_id
object_type=ResourceType.group.value, object_id=str(correct_group_id)
)
),
),
Expand Down Expand Up @@ -169,7 +170,7 @@ async def migrate_groups_make_all_public(config: SyncConfig) -> None:
all_users = SubjectReference(object=_AuthzConverter.all_users())
all_anon_users = SubjectReference(object=_AuthzConverter.anonymous_users())
for group_id in groups_to_process:
group_res = _AuthzConverter.group(group_id)
group_res = _AuthzConverter.group(ULID.from_str(group_id))
all_users_are_viewers = Relationship(
resource=group_res,
relation=_Relation.public_viewer.value,
Expand Down Expand Up @@ -228,7 +229,7 @@ async def migrate_user_namespaces_make_all_public(config: SyncConfig) -> None:
all_users = SubjectReference(object=_AuthzConverter.all_users())
all_anon_users = SubjectReference(object=_AuthzConverter.anonymous_users())
for ns_id in namespaces_to_process:
namespace_res = _AuthzConverter.user_namespace(ns_id)
namespace_res = _AuthzConverter.user_namespace(ULID.from_str(ns_id))
all_users_are_viewers = Relationship(
resource=namespace_res,
relation=_Relation.public_viewer.value,
Expand Down
100 changes: 58 additions & 42 deletions components/renku_data_services/authz/authz.py

Large diffs are not rendered by default.

4 changes: 3 additions & 1 deletion components/renku_data_services/authz/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
from dataclasses import dataclass
from enum import Enum

from ulid import ULID

from renku_data_services.errors import errors
from renku_data_services.namespace.apispec import GroupRole

Expand Down Expand Up @@ -56,7 +58,7 @@ class Member:

role: Role
user_id: str
resource_id: str
resource_id: str | ULID


class Change(Enum):
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Base models for API specifications."""

from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, field_validator
from ulid import ULID


class BaseAPISpec(BaseModel):
Expand All @@ -14,6 +15,12 @@ class Config:
# this rust crate does not support lookahead regex syntax but we need it in this component
regex_engine = "python-re"

@field_validator("id", mode="before", check_fields=False)
@classmethod
def serialize_id(cls, id: str | ULID) -> str:
"""Custom serializer that can handle ULIDs."""
return str(id)


class AuthorizeParams(BaseAPISpec):
"""The schema for the query parameters used in the authorize request."""
Expand Down
13 changes: 7 additions & 6 deletions components/renku_data_services/connected_services/blueprints.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from sanic.log import logger
from sanic.response import JSONResponse
from sanic_ext import validate
from ulid import ULID

import renku_data_services.base_models as base_models
from renku_data_services.base_api.auth import authenticate, only_admins, only_authenticated
Expand Down Expand Up @@ -150,34 +151,34 @@ def get_one(self) -> BlueprintFactoryResponse:
"""Get a specific OAuth2 connection."""

@authenticate(self.authenticator)
async def _get_one(_: Request, user: base_models.APIUser, connection_id: str) -> JSONResponse:
async def _get_one(_: Request, user: base_models.APIUser, connection_id: ULID) -> JSONResponse:
connection = await self.connected_services_repo.get_oauth2_connection(
connection_id=connection_id, user=user
)
return validated_json(apispec.Connection, connection)

return "/oauth2/connections/<connection_id>", ["GET"], _get_one
return "/oauth2/connections/<connection_id:ulid>", ["GET"], _get_one

def get_account(self) -> BlueprintFactoryResponse:
"""Get the account information for a specific OAuth2 connection."""

@authenticate(self.authenticator)
async def _get_account(_: Request, user: base_models.APIUser, connection_id: str) -> JSONResponse:
async def _get_account(_: Request, user: base_models.APIUser, connection_id: ULID) -> JSONResponse:
account = await self.connected_services_repo.get_oauth2_connected_account(
connection_id=connection_id, user=user
)
return validated_json(apispec.ConnectedAccount, account)

return "/oauth2/connections/<connection_id>/account", ["GET"], _get_account
return "/oauth2/connections/<connection_id:ulid>/account", ["GET"], _get_account

def get_token(self) -> BlueprintFactoryResponse:
"""Get the access token for a specific OAuth2 connection."""

@authenticate(self.authenticator)
async def _get_token(_: Request, user: base_models.APIUser, connection_id: str) -> JSONResponse:
async def _get_token(_: Request, user: base_models.APIUser, connection_id: ULID) -> JSONResponse:
token = await self.connected_services_repo.get_oauth2_connection_token(
connection_id=connection_id, user=user
)
return json(token.dump_for_api())

return "/oauth2/connections/<connection_id>/token", ["GET"], _get_token
return "/oauth2/connections/<connection_id:ulid>/token", ["GET"], _get_token
11 changes: 7 additions & 4 deletions components/renku_data_services/connected_services/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
from ulid import ULID

import renku_data_services.base_models as base_models
from renku_data_services import errors
Expand Down Expand Up @@ -282,7 +283,7 @@ async def get_oauth2_connections(
connections = result.all()
return [c.dump() for c in connections]

async def get_oauth2_connection(self, connection_id: str, user: base_models.APIUser) -> models.OAuth2Connection:
async def get_oauth2_connection(self, connection_id: ULID, user: base_models.APIUser) -> models.OAuth2Connection:
"""Get one OAuth2 connection from the database."""
if not user.is_authenticated or user.id is None:
raise errors.MissingResourceError(
Expand All @@ -303,7 +304,7 @@ async def get_oauth2_connection(self, connection_id: str, user: base_models.APIU
return connection.dump()

async def get_oauth2_connected_account(
self, connection_id: str, user: base_models.APIUser
self, connection_id: ULID, user: base_models.APIUser
) -> models.ConnectedAccount:
"""Get the account information from a OAuth2 connection."""
async with self.get_async_oauth2_client(connection_id=connection_id, user=user) as (oauth2_client, _, adapter):
Expand All @@ -316,7 +317,9 @@ async def get_oauth2_connected_account(
account = adapter.api_validate_account_response(response)
return account

async def get_oauth2_connection_token(self, connection_id: str, user: base_models.APIUser) -> models.OAuth2TokenSet:
async def get_oauth2_connection_token(
self, connection_id: ULID, user: base_models.APIUser
) -> models.OAuth2TokenSet:
"""Get the OAuth2 access token from one connection from the database."""
async with self.get_async_oauth2_client(connection_id=connection_id, user=user) as (oauth2_client, _, _):
await oauth2_client.ensure_active_token(oauth2_client.token)
Expand All @@ -325,7 +328,7 @@ async def get_oauth2_connection_token(self, connection_id: str, user: base_model

@asynccontextmanager
async def get_async_oauth2_client(
self, connection_id: str, user: base_models.APIUser
self, connection_id: ULID, user: base_models.APIUser
) -> AsyncGenerator[tuple[AsyncOAuth2Client, schemas.OAuth2ConnectionORM, ProviderAdapter], None]:
"""Get the AsyncOAuth2Client for the given connection_id and user."""
if not user.is_authenticated or user.id is None:
Expand Down
4 changes: 3 additions & 1 deletion components/renku_data_services/connected_services/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
from datetime import UTC, datetime
from typing import Any

from ulid import ULID

from renku_data_services.connected_services.apispec import ConnectionStatus, ProviderKind


Expand All @@ -28,7 +30,7 @@ class OAuth2Client:
class OAuth2Connection:
"""OAuth2 connection model."""

id: str
id: ULID
provider_id: str
status: ConnectionStatus

Expand Down
3 changes: 2 additions & 1 deletion components/renku_data_services/connected_services/orm.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

from renku_data_services.connected_services import models
from renku_data_services.connected_services.apispec import ConnectionStatus, ProviderKind
from renku_data_services.utils.sqlalchemy import ULIDType

JSONVariant = JSON().with_variant(JSONB(), "postgresql")

Expand Down Expand Up @@ -72,7 +73,7 @@ class OAuth2ConnectionORM(BaseORM):
"""An OAuth2 connection."""

__tablename__ = "oauth2_connections"
id: Mapped[str] = mapped_column("id", String(26), primary_key=True, default_factory=lambda: str(ULID()), init=False)
id: Mapped[ULID] = mapped_column("id", ULIDType, primary_key=True, default_factory=lambda: str(ULID()), init=False)
user_id: Mapped[str] = mapped_column("user_id", String())
client_id: Mapped[str] = mapped_column(ForeignKey(OAuth2ClientORM.id, ondelete="CASCADE"), index=True)
client: Mapped[OAuth2ClientORM] = relationship(init=False, repr=False)
Expand Down
32 changes: 18 additions & 14 deletions components/renku_data_services/message_queue/converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,14 @@ def to_events(
raise errors.EventError(
message=f"Cannot create an event of type {event_type} for a project which has no ID"
)
project_id_str = str(project.id)
match event_type:
case v2.ProjectCreated:
return [
Event(
"project.created",
v2.ProjectCreated(
id=project.id,
id=project_id_str,
name=project.name,
namespace=project.namespace.slug,
slug=project.slug,
Expand All @@ -56,7 +57,7 @@ def to_events(
Event(
"projectAuth.added",
v2.ProjectMemberAdded(
projectId=project.id,
projectId=project_id_str,
userId=project.created_by,
role=v2.MemberRole.OWNER,
),
Expand All @@ -67,7 +68,7 @@ def to_events(
Event(
"project.updated",
v2.ProjectUpdated(
id=project.id,
id=project_id_str,
name=project.name,
namespace=project.namespace.slug,
slug=project.slug,
Expand All @@ -79,7 +80,7 @@ def to_events(
)
]
case v2.ProjectRemoved:
return [Event("project.removed", v2.ProjectRemoved(id=project.id))]
return [Event("project.removed", v2.ProjectRemoved(id=project_id_str))]
case _:
raise errors.EventError(message=f"Trying to convert a project to an unknown event type {event_type}")

Expand Down Expand Up @@ -145,13 +146,14 @@ class _ProjectAuthzEventConverter:
def to_events(member_changes: list[authz_models.MembershipChange]) -> list[Event]:
output: list[Event] = []
for change in member_changes:
resource_id = str(change.member.resource_id)
match change.change:
case authz_models.Change.UPDATE:
output.append(
Event(
"projectAuth.updated",
v2.ProjectMemberUpdated(
projectId=change.member.resource_id,
projectId=resource_id,
userId=change.member.user_id,
role=_convert_member_role(change.member.role),
),
Expand All @@ -162,7 +164,7 @@ def to_events(member_changes: list[authz_models.MembershipChange]) -> list[Event
Event(
"projectAuth.removed",
v2.ProjectMemberRemoved(
projectId=change.member.resource_id,
projectId=resource_id,
userId=change.member.user_id,
),
)
Expand All @@ -172,7 +174,7 @@ def to_events(member_changes: list[authz_models.MembershipChange]) -> list[Event
Event(
"projectAuth.added",
v2.ProjectMemberAdded(
projectId=change.member.resource_id,
projectId=resource_id,
userId=change.member.user_id,
role=_convert_member_role(change.member.role),
),
Expand All @@ -191,13 +193,14 @@ class _GroupAuthzEventConverter:
def to_events(member_changes: list[authz_models.MembershipChange]) -> list[Event]:
output: list[Event] = []
for change in member_changes:
resource_id = str(change.member.resource_id)
match change.change:
case authz_models.Change.UPDATE:
output.append(
Event(
"memberGroup.updated",
v2.ProjectMemberUpdated(
projectId=change.member.resource_id,
projectId=resource_id,
userId=change.member.user_id,
role=_convert_member_role(change.member.role),
),
Expand All @@ -208,7 +211,7 @@ def to_events(member_changes: list[authz_models.MembershipChange]) -> list[Event
Event(
"memberGroup.removed",
v2.ProjectMemberRemoved(
projectId=change.member.resource_id,
projectId=resource_id,
userId=change.member.user_id,
),
)
Expand All @@ -218,7 +221,7 @@ def to_events(member_changes: list[authz_models.MembershipChange]) -> list[Event
Event(
"memberGroup.added",
v2.ProjectMemberAdded(
projectId=change.member.resource_id,
projectId=resource_id,
userId=change.member.user_id,
role=_convert_member_role(change.member.role),
),
Expand All @@ -239,32 +242,33 @@ def to_events(group: group_models.Group, event_type: type[AvroModel] | type[even
raise errors.ProgrammingError(
message="Cannot send group events to the message queue for a group that does not have an ID"
)
group_id = str(group.id)
match event_type:
case v2.GroupAdded:
return [
Event(
"group.added",
v2.GroupAdded(
id=group.id, name=group.name, description=group.description, namespace=group.slug
id=group_id, name=group.name, description=group.description, namespace=group.slug
),
),
Event(
"memberGroup.added",
v2.GroupMemberAdded(
groupId=group.id,
groupId=group_id,
userId=group.created_by,
role=v2.MemberRole.OWNER,
),
),
]
case v2.GroupRemoved:
return [Event("group.removed", v2.GroupRemoved(id=group.id))]
return [Event("group.removed", v2.GroupRemoved(id=group_id))]
case v2.GroupUpdated:
return [
Event(
"group.updated",
v2.GroupUpdated(
id=group.id, name=group.name, description=group.description, namespace=group.slug
id=group_id, name=group.name, description=group.description, namespace=group.slug
),
)
]
Expand Down
9 changes: 8 additions & 1 deletion components/renku_data_services/namespace/apispec_base.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Base models for API specifications."""

from pydantic import BaseModel
from pydantic import BaseModel, field_validator
from ulid import ULID


class BaseAPISpec(BaseModel):
Expand All @@ -13,3 +14,9 @@ class Config:
# NOTE: By default the pydantic library does not use python for regex but a rust crate
# this rust crate does not support lookahead regex syntax but we need it in this component
regex_engine = "python-re"

@field_validator("id", mode="before", check_fields=False)
@classmethod
def serialize_id(cls, id: str | ULID) -> str:
"""Custom serializer that can handle ULIDs."""
return str(id)
Loading

0 comments on commit 90b6fc8

Please sign in to comment.