Skip to content

Commit

Permalink
refactor: add validation to project, storage, repo and session bluepr…
Browse files Browse the repository at this point in the history
…ints (#347)
  • Loading branch information
Panaetius authored and olevski committed Nov 8, 2024
1 parent 0d977e8 commit ce6a6b4
Show file tree
Hide file tree
Showing 22 changed files with 164 additions and 161 deletions.
2 changes: 1 addition & 1 deletion components/renku_data_services/authz/authz.py
Original file line number Diff line number Diff line change
Expand Up @@ -586,7 +586,7 @@ async def _get_members_helper(
member = Member(
user_id=response.relationship.subject.object.object_id,
role=member_role,
resource_id=response.relationship.resource.object_id,
resource_id=ULID.from_str(response.relationship.resource.object_id),
)

yield member
Expand Down
2 changes: 1 addition & 1 deletion components/renku_data_services/authz/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def with_group(self, group_id: ULID) -> "Member":
class Member(UnsavedMember):
"""Member stored in the database."""

resource_id: str | ULID
resource_id: ULID


class Change(Enum):
Expand Down
24 changes: 0 additions & 24 deletions components/renku_data_services/base_api/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,30 +71,6 @@ async def decorated_function(request: Request, *args: _P.args, **kwargs: _P.kwar
return decorator


def validate_path_project_id(
f: Callable[Concatenate[Request, _P], Coroutine[Any, Any, _T]],
) -> Callable[Concatenate[Request, _P], Coroutine[Any, Any, _T]]:
"""Decorator for a Sanic handler that validates the project_id path parameter."""
_path_project_id_regex = re.compile(r"^[A-Za-z0-9]{26}$")

@wraps(f)
async def decorated_function(request: Request, *args: _P.args, **kwargs: _P.kwargs) -> _T:
project_id = cast(str | None, kwargs.get("project_id"))
if not project_id:
raise errors.ProgrammingError(
message="Could not find 'project_id' in the keyword arguments for the handler in order to validate it."
)
if not _path_project_id_regex.match(project_id):
raise errors.ValidationError(
message=f"The 'project_id' path parameter {project_id} does not match the required "
f"regex {_path_project_id_regex}"
)

return await f(request, *args, **kwargs)

return decorated_function


def validate_path_user_id(
f: Callable[Concatenate[Request, _P], Coroutine[Any, Any, _T]],
) -> Callable[Concatenate[Request, _P], Coroutine[Any, Any, _T]]:
Expand Down
2 changes: 1 addition & 1 deletion components/renku_data_services/crc/apispec.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# generated by datamodel-codegen:
# filename: api.spec.yaml
# timestamp: 2024-10-18T11:06:20+00:00
# timestamp: 2024-08-20T07:15:17+00:00

from __future__ import annotations

Expand Down
2 changes: 1 addition & 1 deletion components/renku_data_services/project/api.spec.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ paths:
$ref: "#/components/responses/Error"
tags:
- projects
/projects/{namespace}/{slug}:
/namespaces/{namespace}/projects/{slug}:
get:
summary: Get a project by namespace and project slug
parameters:
Expand Down
9 changes: 8 additions & 1 deletion components/renku_data_services/project/apispec_base.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Base models for API specifications."""

from pydantic import BaseModel
from pydantic import BaseModel, field_validator
from ulid import ULID


class BaseAPISpec(BaseModel):
Expand All @@ -13,3 +14,9 @@ class Config:
# NOTE: By default the pydantic library does not use python for regex but a rust crate
# this rust crate does not support lookahead regex syntax but we need it in this component
regex_engine = "python-re"

@field_validator("id", mode="before", check_fields=False)
@classmethod
def serialize_id(cls, id: str | ULID) -> str:
"""Custom serializer that can handle ULIDs."""
return str(id)
42 changes: 18 additions & 24 deletions components/renku_data_services/project/blueprints.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from dataclasses import dataclass
from typing import Any

from sanic import HTTPResponse, Request, json
from sanic import HTTPResponse, Request
from sanic.response import JSONResponse
from sanic_ext import validate
from ulid import ULID
Expand All @@ -13,7 +13,6 @@
from renku_data_services.base_api.auth import (
authenticate,
only_authenticated,
validate_path_project_id,
validate_path_user_id,
)
from renku_data_services.base_api.blueprint import BlueprintFactoryResponse, CustomBlueprint
Expand Down Expand Up @@ -94,7 +93,7 @@ async def _get_one(
headers = {"ETag": project.etag} if project.etag is not None else None
return validated_json(apispec.Project, self._dump_project(project), headers=headers)

return "/projects/<project_id>", ["GET"], _get_one
return "/projects/<project_id:ulid>", ["GET"], _get_one

def get_one_by_namespace_slug(self) -> BlueprintFactoryResponse:
"""Get a specific project by namespace/slug."""
Expand All @@ -112,34 +111,32 @@ async def _get_one_by_namespace_slug(
headers = {"ETag": project.etag} if project.etag is not None else None
return validated_json(apispec.Project, self._dump_project(project), headers=headers)

return "/projects/<namespace>/<slug:renku_slug>", ["GET"], _get_one_by_namespace_slug
return "/namespaces/<namespace>/projects/<slug:renku_slug>", ["GET"], _get_one_by_namespace_slug

def delete(self) -> BlueprintFactoryResponse:
"""Delete a specific project."""

@authenticate(self.authenticator)
@only_authenticated
@validate_path_project_id
async def _delete(_: Request, user: base_models.APIUser, project_id: str) -> HTTPResponse:
await self.project_repo.delete_project(user=user, project_id=ULID.from_str(project_id))
async def _delete(_: Request, user: base_models.APIUser, project_id: ULID) -> HTTPResponse:
await self.project_repo.delete_project(user=user, project_id=project_id)
return HTTPResponse(status=204)

return "/projects/<project_id>", ["DELETE"], _delete
return "/projects/<project_id:ulid>", ["DELETE"], _delete

def patch(self) -> BlueprintFactoryResponse:
"""Partially update a specific project."""

@authenticate(self.authenticator)
@only_authenticated
@validate_path_project_id
@if_match_required
@validate(json=apispec.ProjectPatch)
async def _patch(
_: Request, user: base_models.APIUser, project_id: str, body: apispec.ProjectPatch, etag: str
_: Request, user: base_models.APIUser, project_id: ULID, body: apispec.ProjectPatch, etag: str
) -> JSONResponse:
project_patch = validate_project_patch(body)
project_update = await self.project_repo.update_project(
user=user, project_id=ULID.from_str(project_id), etag=etag, patch=project_patch
user=user, project_id=project_id, etag=etag, patch=project_patch
)

if not isinstance(project_update, project_models.ProjectUpdate):
Expand All @@ -151,15 +148,14 @@ async def _patch(
updated_project = project_update.new
return validated_json(apispec.Project, self._dump_project(updated_project))

return "/projects/<project_id>", ["PATCH"], _patch
return "/projects/<project_id:ulid>", ["PATCH"], _patch

def get_all_members(self) -> BlueprintFactoryResponse:
"""List all project members."""

@authenticate(self.authenticator)
@validate_path_project_id
async def _get_all_members(_: Request, user: base_models.APIUser, project_id: str) -> JSONResponse:
members = await self.project_member_repo.get_members(user, ULID.from_str(project_id))
async def _get_all_members(_: Request, user: base_models.APIUser, project_id: ULID) -> JSONResponse:
members = await self.project_member_repo.get_members(user, project_id)

users = []

Expand All @@ -179,35 +175,33 @@ async def _get_all_members(_: Request, user: base_models.APIUser, project_id: st
).model_dump(exclude_none=True, mode="json")
users.append(user_with_id)

return json(users)
return validated_json(apispec.ProjectMemberListResponse, users)

return "/projects/<project_id>/members", ["GET"], _get_all_members
return "/projects/<project_id:ulid>/members", ["GET"], _get_all_members

def update_members(self) -> BlueprintFactoryResponse:
"""Update or add project members."""

@authenticate(self.authenticator)
@validate_path_project_id
@validate_body_root_model(json=apispec.ProjectMemberListPatchRequest)
async def _update_members(
_: Request, user: base_models.APIUser, project_id: str, body: apispec.ProjectMemberListPatchRequest
_: Request, user: base_models.APIUser, project_id: ULID, body: apispec.ProjectMemberListPatchRequest
) -> HTTPResponse:
members = [Member(Role(i.role.value), i.id, project_id) for i in body.root]
await self.project_member_repo.update_members(user, ULID.from_str(project_id), members)
await self.project_member_repo.update_members(user, project_id, members)
return HTTPResponse(status=200)

return "/projects/<project_id>/members", ["PATCH"], _update_members
return "/projects/<project_id:ulid>/members", ["PATCH"], _update_members

def delete_member(self) -> BlueprintFactoryResponse:
"""Delete a specific project."""

@authenticate(self.authenticator)
@validate_path_project_id
@validate_path_user_id
async def _delete_member(
_: Request, user: base_models.APIUser, project_id: str, member_id: str
_: Request, user: base_models.APIUser, project_id: ULID, member_id: str
) -> HTTPResponse:
await self.project_member_repo.delete_members(user, ULID.from_str(project_id), [member_id])
await self.project_member_repo.delete_members(user, project_id, [member_id])
return HTTPResponse(status=204)

return "/projects/<project_id>/members/<member_id>", ["DELETE"], _delete_member
Expand Down
1 change: 0 additions & 1 deletion components/renku_data_services/project/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,6 @@ async def update_project(
session: AsyncSession | None = None,
) -> models.ProjectUpdate:
"""Update a project entry."""
project_id_str: str = str(project_id)
if not session:
raise errors.ProgrammingError(message="A database session is required")
result = await session.scalars(select(schemas.ProjectORM).where(schemas.ProjectORM.id == project_id))
Expand Down
2 changes: 1 addition & 1 deletion components/renku_data_services/project/orm.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ class ProjectRepositoryORM(BaseORM):

id: Mapped[int] = mapped_column("id", Integer, Identity(always=True), primary_key=True, default=None, init=False)
url: Mapped[str] = mapped_column("url", String(2000))
project_id: Mapped[Optional[str]] = mapped_column(
project_id: Mapped[Optional[ULID]] = mapped_column(
ForeignKey("projects.id", ondelete="CASCADE"), default=None, index=True
)
project: Mapped[Optional[ProjectORM]] = relationship(back_populates="repositories", default=None, repr=False)
8 changes: 3 additions & 5 deletions components/renku_data_services/repositories/blueprints.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,15 @@
from dataclasses import dataclass
from urllib.parse import unquote

from sanic import HTTPResponse, Request, json
from sanic import HTTPResponse, Request
from sanic.response import JSONResponse

import renku_data_services.base_models as base_models
from renku_data_services import errors
from renku_data_services.base_api.auth import authenticate_2
from renku_data_services.base_api.blueprint import BlueprintFactoryResponse, CustomBlueprint
from renku_data_services.base_api.etag import extract_if_none_match
from renku_data_services.base_models.validation import validated_json
from renku_data_services.repositories import apispec
from renku_data_services.repositories.apispec_base import RepositoryParams
from renku_data_services.repositories.db import GitRepositoriesRepository
Expand Down Expand Up @@ -53,10 +54,7 @@ async def _get_one_repository(
if result.repository_metadata and result.repository_metadata.etag is not None
else None
)
return json(
apispec.RepositoryProviderMatch.model_validate(result).model_dump(exclude_none=True, mode="json"),
headers=headers,
)
return validated_json(apispec.RepositoryProviderMatch, result, headers=headers)

return "/repositories/<repository_url>", ["GET"], _get_one_repository

Expand Down
2 changes: 1 addition & 1 deletion components/renku_data_services/secrets/apispec.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# generated by datamodel-codegen:
# filename: api.spec.yaml
# timestamp: 2024-08-13T13:29:49+00:00
# timestamp: 2024-08-20T07:15:21+00:00

from __future__ import annotations

Expand Down
24 changes: 24 additions & 0 deletions components/renku_data_services/session/apispec_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
from pydantic import BaseModel, field_validator
from ulid import ULID

from renku_data_services.session import models


class BaseAPISpec(BaseModel):
"""Base API specification."""
Expand All @@ -23,6 +25,28 @@ def serialize_ulid(cls, value: Any) -> Any:
return str(value)
return value

@field_validator("project_id", mode="before", check_fields=False)
@classmethod
def serialize_project_id(cls, project_id: str | ULID) -> str:
"""Custom serializer that can handle ULIDs."""
return str(project_id)

@field_validator("environment_id", mode="before", check_fields=False)
@classmethod
def serialize_environment_id(cls, environment_id: str | ULID | None) -> str | None:
"""Custom serializer that can handle ULIDs."""
if environment_id is None:
return None
return str(environment_id)

@field_validator("environment_kind", mode="before", check_fields=False)
@classmethod
def serialize_environment_kind(cls, environment_kind: models.EnvironmentKind | str) -> str:
"""Custom serializer that can handle ULIDs."""
if isinstance(environment_kind, models.EnvironmentKind):
return environment_kind.value
return environment_kind

@field_validator("working_directory", "mount_directory", check_fields=False, mode="before")
@classmethod
def convert_path_to_string(cls, val: str | PurePosixPath) -> str:
Expand Down
11 changes: 8 additions & 3 deletions components/renku_data_services/session/blueprints.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ def post(self) -> BlueprintFactoryResponse:
"""Create a new session environment."""

@authenticate(self.authenticator)
@only_authenticated
@validate(json=apispec.EnvironmentPost)
async def _post(_: Request, user: base_models.APIUser, body: apispec.EnvironmentPost) -> JSONResponse:
new_environment = validate_unsaved_environment(body)
Expand All @@ -62,6 +63,7 @@ def patch(self) -> BlueprintFactoryResponse:
"""Partially update a specific session environment."""

@authenticate(self.authenticator)
@only_authenticated
@validate(json=apispec.EnvironmentPatch)
async def _patch(
_: Request, user: base_models.APIUser, environment_id: ULID, body: apispec.EnvironmentPatch
Expand All @@ -78,6 +80,7 @@ def delete(self) -> BlueprintFactoryResponse:
"""Delete a specific session environment."""

@authenticate(self.authenticator)
@only_authenticated
async def _delete(_: Request, user: base_models.APIUser, environment_id: ULID) -> HTTPResponse:
await self.session_repo.delete_environment(user=user, environment_id=environment_id)
return HTTPResponse(status=204)
Expand Down Expand Up @@ -116,6 +119,7 @@ def post(self) -> BlueprintFactoryResponse:
"""Create a new session launcher."""

@authenticate(self.authenticator)
@only_authenticated
@validate(json=apispec.SessionLauncherPost)
async def _post(_: Request, user: base_models.APIUser, body: apispec.SessionLauncherPost) -> JSONResponse:
new_launcher = validate_unsaved_session_launcher(body)
Expand All @@ -128,6 +132,7 @@ def patch(self) -> BlueprintFactoryResponse:
"""Partially update a specific session launcher."""

@authenticate(self.authenticator)
@only_authenticated
@validate(json=apispec.SessionLauncherPatch)
async def _patch(
_: Request, user: base_models.APIUser, launcher_id: ULID, body: apispec.SessionLauncherPatch
Expand All @@ -146,6 +151,7 @@ def delete(self) -> BlueprintFactoryResponse:
"""Delete a specific session launcher."""

@authenticate(self.authenticator)
@only_authenticated
async def _delete(_: Request, user: base_models.APIUser, launcher_id: ULID) -> HTTPResponse:
await self.session_repo.delete_launcher(user=user, launcher_id=launcher_id)
return HTTPResponse(status=204)
Expand All @@ -156,9 +162,8 @@ def get_project_launchers(self) -> BlueprintFactoryResponse:
"""Get all launchers belonging to a project."""

@authenticate(self.authenticator)
@validate_path_project_id
async def _get_launcher(_: Request, user: base_models.APIUser, project_id: str) -> JSONResponse:
async def _get_launcher(_: Request, user: base_models.APIUser, project_id: ULID) -> JSONResponse:
launchers = await self.session_repo.get_project_launchers(user=user, project_id=project_id)
return validated_json(apispec.SessionLaunchersList, launchers)

return "/projects/<project_id>/session_launchers", ["GET"], _get_launcher
return "/projects/<project_id:ulid>/session_launchers", ["GET"], _get_launcher
Loading

0 comments on commit ce6a6b4

Please sign in to comment.