From 2958eb33bde4392c1546540618185419cc3391f8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20Francisco=20Calvo?= Date: Tue, 28 Feb 2023 12:31:32 +0100 Subject: [PATCH] Rename table users_workspaces to workspaces_users, related models and schemas --- ...ee58fbb4_create_workspaces_users_table.py} | 10 ++-- .../server/apis/v0/handlers/workspaces.py | 14 ++--- src/argilla/server/contexts/accounts.py | 27 +++++---- src/argilla/server/models.py | 12 ++-- src/argilla/server/security/model.py | 2 +- tests/client/test_api.py | 4 +- tests/conftest.py | 4 +- tests/factories.py | 18 +++--- .../test_log_for_text_classification.py | 4 +- tests/server/api/v0/test_workspaces.py | 59 ++++++++++--------- 10 files changed, 78 insertions(+), 76 deletions(-) rename src/argilla/server/alembic/versions/{1769ee58fbb4_create_users_workspaces_table.py => 1769ee58fbb4_create_workspaces_users_table.py} (88%) diff --git a/src/argilla/server/alembic/versions/1769ee58fbb4_create_users_workspaces_table.py b/src/argilla/server/alembic/versions/1769ee58fbb4_create_workspaces_users_table.py similarity index 88% rename from src/argilla/server/alembic/versions/1769ee58fbb4_create_users_workspaces_table.py rename to src/argilla/server/alembic/versions/1769ee58fbb4_create_workspaces_users_table.py index acb0a99e3b..fc203fbd15 100644 --- a/src/argilla/server/alembic/versions/1769ee58fbb4_create_users_workspaces_table.py +++ b/src/argilla/server/alembic/versions/1769ee58fbb4_create_workspaces_users_table.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""create users_workspaces table +"""create workspaces_users table Revision ID: 1769ee58fbb4 Revises: 82a5a88a3fa5 @@ -31,17 +31,17 @@ def upgrade() -> None: op.create_table( - "users_workspaces", + "workspaces_users", sa.Column("id", sa.Uuid, primary_key=True), - sa.Column("user_id", sa.Uuid, sa.ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True), sa.Column( "workspace_id", sa.Uuid, sa.ForeignKey("workspaces.id", ondelete="CASCADE"), nullable=False, index=True ), + sa.Column("user_id", sa.Uuid, sa.ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True), sa.Column("inserted_at", sa.DateTime, nullable=False), sa.Column("updated_at", sa.DateTime, nullable=False), - sa.UniqueConstraint("user_id", "workspace_id", name="user_id_workspace_id_uq"), + sa.UniqueConstraint("workspace_id", "user_id", name="workspace_id_user_id_uq"), ) def downgrade() -> None: - op.drop_table("users_workspaces") + op.drop_table("workspaces_users") diff --git a/src/argilla/server/apis/v0/handlers/workspaces.py b/src/argilla/server/apis/v0/handlers/workspaces.py index d24ea31e4b..7d4ed65f79 100644 --- a/src/argilla/server/apis/v0/handlers/workspaces.py +++ b/src/argilla/server/apis/v0/handlers/workspaces.py @@ -25,9 +25,9 @@ from argilla.server.security import auth from argilla.server.security.model import ( User, - UserWorkspaceCreate, Workspace, WorkspaceCreate, + WorkspaceUserCreate, ) router = APIRouter(tags=["workspaces"]) @@ -95,9 +95,9 @@ def create_workspace_user( if not user: raise EntityNotFoundError(name=str(user_id), type=User) - user_workspace = accounts.create_user_workspace(db, UserWorkspaceCreate(user_id=user_id, workspace_id=workspace_id)) + workspace_user = accounts.create_workspace_user(db, WorkspaceUserCreate(workspace_id=workspace_id, user_id=user_id)) - return User.from_orm(user_workspace.user) + return User.from_orm(workspace_user.user) @router.delete("/workspaces/{workspace_id}/users/{user_id}", response_model=User, response_model_exclude_none=True) @@ -108,11 +108,11 @@ def delete_workspace_user( user_id: UUID, current_user: User = Security(auth.get_user, scopes=[]), ): - user_workspace = accounts.get_user_workspace_by_user_id_and_workspace_id(db, user_id, workspace_id) - if not user_workspace: + workspace_user = accounts.get_workspace_user_by_workspace_id_and_user_id(db, workspace_id, user_id) + if not workspace_user: raise EntityNotFoundError(name=str(user_id), type=User) - user = user_workspace.user - accounts.delete_user_workspace(db, user_workspace) + user = workspace_user.user + accounts.delete_workspace_user(db, workspace_user) return User.from_orm(user) diff --git a/src/argilla/server/contexts/accounts.py b/src/argilla/server/contexts/accounts.py index 15b6d85e0b..13493cbb6c 100644 --- a/src/argilla/server/contexts/accounts.py +++ b/src/argilla/server/contexts/accounts.py @@ -14,11 +14,11 @@ from uuid import UUID -from argilla.server.models import User, UserWorkspace, Workspace +from argilla.server.models import User, Workspace, WorkspaceUser from argilla.server.security.model import ( UserCreate, - UserWorkspaceCreate, WorkspaceCreate, + WorkspaceUserCreate, ) from passlib.context import CryptContext from sqlalchemy.orm import Session @@ -26,27 +26,28 @@ _CRYPT_CONTEXT = CryptContext(schemes=["bcrypt"], deprecated="auto") -def get_user_workspace_by_user_id_and_workspace_id(db: Session, user_id: UUID, workspace_id: UUID): - return db.query(UserWorkspace).filter_by(user_id=user_id, workspace_id=workspace_id).first() +def get_workspace_user_by_workspace_id_and_user_id(db: Session, workspace_id: UUID, user_id: UUID): + return db.query(WorkspaceUser).filter_by(workspace_id=workspace_id, user_id=user_id).first() -def create_user_workspace(db: Session, user_workspace_create: UserWorkspaceCreate): - user_workspace = UserWorkspace( - user_id=user_workspace_create.user_id, workspace_id=user_workspace_create.workspace_id +def create_workspace_user(db: Session, workspace_user_create: WorkspaceUserCreate): + workspace_user = WorkspaceUser( + workspace_id=workspace_user_create.workspace_id, + user_id=workspace_user_create.user_id, ) - db.add(user_workspace) + db.add(workspace_user) db.commit() - db.refresh(user_workspace) + db.refresh(workspace_user) - return user_workspace + return workspace_user -def delete_user_workspace(db: Session, user_workspace: UserWorkspace): - db.delete(user_workspace) +def delete_workspace_user(db: Session, workspace_user: WorkspaceUser): + db.delete(workspace_user) db.commit() - return user_workspace + return workspace_user def get_workspace_by_id(db: Session, workspace_id: UUID): diff --git a/src/argilla/server/models.py b/src/argilla/server/models.py index f7ab8c3487..30443b9f87 100644 --- a/src/argilla/server/models.py +++ b/src/argilla/server/models.py @@ -33,18 +33,18 @@ def default_inserted_at(context): return context.get_current_parameters()["inserted_at"] -class UserWorkspace(Base): - __tablename__ = "users_workspaces" +class WorkspaceUser(Base): + __tablename__ = "workspaces_users" id: Mapped[UUID] = mapped_column(primary_key=True, default=uuid4) - user_id: Mapped[UUID] = mapped_column(ForeignKey("users.id")) workspace_id: Mapped[UUID] = mapped_column(ForeignKey("workspaces.id")) + user_id: Mapped[UUID] = mapped_column(ForeignKey("users.id")) inserted_at: Mapped[datetime] = mapped_column(default=datetime.utcnow) updated_at: Mapped[datetime] = mapped_column(default=default_inserted_at, onupdate=datetime.utcnow) - user: Mapped["User"] = relationship(viewonly=True) workspace: Mapped["Workspace"] = relationship(viewonly=True) + user: Mapped["User"] = relationship(viewonly=True) class Workspace(Base): @@ -57,7 +57,7 @@ class Workspace(Base): updated_at: Mapped[datetime] = mapped_column(default=default_inserted_at, onupdate=datetime.utcnow) users: Mapped[List["User"]] = relationship( - secondary="users_workspaces", back_populates="workspaces", order_by=UserWorkspace.inserted_at.asc() + secondary="workspaces_users", back_populates="workspaces", order_by=WorkspaceUser.inserted_at.asc() ) @@ -75,5 +75,5 @@ class User(Base): updated_at: Mapped[datetime] = mapped_column(default=default_inserted_at, onupdate=datetime.utcnow) workspaces: Mapped[List["Workspace"]] = relationship( - secondary="users_workspaces", back_populates="users", order_by=UserWorkspace.inserted_at.asc() + secondary="workspaces_users", back_populates="users", order_by=WorkspaceUser.inserted_at.asc() ) diff --git a/src/argilla/server/security/model.py b/src/argilla/server/security/model.py index 6b24bb260f..9baa4930c0 100644 --- a/src/argilla/server/security/model.py +++ b/src/argilla/server/security/model.py @@ -34,7 +34,7 @@ WORKSPACE_NAME_PATTERN = re.compile(_WORKSPACE_NAME_REGEX) -class UserWorkspaceCreate(BaseModel): +class WorkspaceUserCreate(BaseModel): user_id: UUID workspace_id: UUID diff --git a/tests/client/test_api.py b/tests/client/test_api.py index 52d4f6c105..6e25d9c0c3 100644 --- a/tests/client/test_api.py +++ b/tests/client/test_api.py @@ -47,7 +47,7 @@ TextClassificationSearchResults, ) from argilla.server.contexts import accounts -from argilla.server.security.model import UserWorkspaceCreate, WorkspaceCreate +from argilla.server.security.model import WorkspaceCreate, WorkspaceUserCreate from sqlalchemy.orm import Session from tests.helpers import SecuredClient @@ -454,7 +454,7 @@ def test_dataset_copy_to_another_workspace(mocked_client, admin: User, db: Sessi new_workspace_name = "my-fun-workspace" workspace = accounts.create_workspace(db, WorkspaceCreate(name=new_workspace_name)) - accounts.create_user_workspace(db, UserWorkspaceCreate(user_id=admin.id, workspace_id=workspace.id)) + accounts.create_workspace_user(db, WorkspaceUserCreate(workspace_id=workspace.id, user_id=admin.id)) mocked_client.delete(f"/api/datasets/{dataset_name}") mocked_client.delete(f"/api/datasets/{dataset_copy}") diff --git a/tests/conftest.py b/tests/conftest.py index e843381cc1..2ae9c7a57c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -21,7 +21,7 @@ from argilla.server.commons import telemetry from argilla.server.contexts import accounts from argilla.server.database import SessionLocal -from argilla.server.models import User, UserWorkspace, Workspace +from argilla.server.models import User, Workspace, WorkspaceUser from argilla.server.seeds import test_seeds try: @@ -50,7 +50,7 @@ def db(): session.query(User).delete() session.query(Workspace).delete() - session.query(UserWorkspace).delete() + session.query(WorkspaceUser).delete() session.commit() diff --git a/tests/factories.py b/tests/factories.py index 4fae222825..bdaef42e9b 100644 --- a/tests/factories.py +++ b/tests/factories.py @@ -14,7 +14,14 @@ import factory from argilla.server.database import SessionLocal -from argilla.server.models import User, UserWorkspace, Workspace +from argilla.server.models import User, Workspace, WorkspaceUser + + +class WorkspaceUserFactory(factory.alchemy.SQLAlchemyModelFactory): + class Meta: + model = WorkspaceUser + sqlalchemy_session = SessionLocal() + sqlalchemy_session_persistence = "commit" class WorkspaceFactory(factory.alchemy.SQLAlchemyModelFactory): @@ -36,12 +43,3 @@ class Meta: username = factory.Sequence(lambda n: f"username-{n}") api_key = factory.Sequence(lambda n: f"api-key-{n}") password_hash = "$2y$05$eaw.j2Kaw8s8vpscVIZMfuqSIX3OLmxA21WjtWicDdn0losQ91Hw." - - -class UserWorkspaceFactory(factory.alchemy.SQLAlchemyModelFactory): - class Meta: - model = UserWorkspace - sqlalchemy_session = SessionLocal() - sqlalchemy_session_persistence = "commit" - - # TODO: Define relationships with user and workspace diff --git a/tests/functional_tests/test_log_for_text_classification.py b/tests/functional_tests/test_log_for_text_classification.py index 0a02c59bdb..6ec52366d7 100644 --- a/tests/functional_tests/test_log_for_text_classification.py +++ b/tests/functional_tests/test_log_for_text_classification.py @@ -23,7 +23,7 @@ ) from argilla.server.contexts import accounts from argilla.server.models import User -from argilla.server.security.model import UserWorkspaceCreate, WorkspaceCreate +from argilla.server.security.model import WorkspaceCreate, WorkspaceUserCreate from argilla.server.settings import settings from sqlalchemy.orm import Session @@ -196,7 +196,7 @@ def test_log_data_in_several_workspaces(mocked_client: SecuredClient, admin: Use text = "This is a text" workspace = accounts.create_workspace(db, WorkspaceCreate(name=workspace_name)) - accounts.create_user_workspace(db, UserWorkspaceCreate(user_id=admin.id, workspace_id=workspace.id)) + accounts.create_workspace_user(db, WorkspaceUserCreate(workspace_id=workspace.id, user_id=admin.id)) api = Argilla() diff --git a/tests/server/api/v0/test_workspaces.py b/tests/server/api/v0/test_workspaces.py index 209077b8d5..02a6949d83 100644 --- a/tests/server/api/v0/test_workspaces.py +++ b/tests/server/api/v0/test_workspaces.py @@ -15,11 +15,11 @@ from uuid import UUID, uuid4 import pytest -from argilla.server.models import User, UserWorkspace, Workspace +from argilla.server.models import User, Workspace, WorkspaceUser from fastapi.testclient import TestClient from sqlalchemy.orm import Session -from tests.factories import UserFactory, UserWorkspaceFactory, WorkspaceFactory +from tests.factories import UserFactory, WorkspaceFactory, WorkspaceUserFactory def test_list_workspaces(client: TestClient, admin_auth_header: dict): @@ -89,16 +89,16 @@ def test_delete_workspace_with_nonexistent_workspace_id(): def test_list_workspace_users(client: TestClient, db: Session, admin_auth_header: dict): workspace_a = WorkspaceFactory.create() - UserWorkspaceFactory.create(user_id=UserFactory.create(username="username-a").id, workspace_id=workspace_a.id) - UserWorkspaceFactory.create(user_id=UserFactory.create(username="username-b").id, workspace_id=workspace_a.id) - UserWorkspaceFactory.create(user_id=UserFactory.create(username="username-c").id, workspace_id=workspace_a.id) + WorkspaceUserFactory.create(workspace_id=workspace_a.id, user_id=UserFactory.create(username="username-a").id) + WorkspaceUserFactory.create(workspace_id=workspace_a.id, user_id=UserFactory.create(username="username-b").id) + WorkspaceUserFactory.create(workspace_id=workspace_a.id, user_id=UserFactory.create(username="username-c").id) WorkspaceFactory.create(users=[UserFactory.build(), UserFactory.build()]) response = client.get(f"/api/workspaces/{workspace_a.id}/users", headers=admin_auth_header) assert response.status_code == 200 - assert db.query(UserWorkspace).count() == 5 + assert db.query(WorkspaceUser).count() == 5 response_body = response.json() assert list(map(lambda u: u["username"], response_body)) == ["username-a", "username-b", "username-c"] @@ -118,8 +118,8 @@ def test_create_workspace_user(client: TestClient, db: Session, admin: User, adm response = client.post(f"/api/workspaces/{workspace.id}/users/{admin.id}", headers=admin_auth_header) assert response.status_code == 200 - assert db.query(UserWorkspace).count() == 1 - assert db.query(UserWorkspace).filter_by(user_id=admin.id, workspace_id=workspace.id).first() + assert db.query(WorkspaceUser).count() == 1 + assert db.query(WorkspaceUser).filter_by(workspace_id=workspace.id, user_id=admin.id).first() response_body = response.json() assert response_body["id"] == str(admin.id) @@ -132,7 +132,7 @@ def test_create_workspace_user_without_authentication(client: TestClient, db: Se response = client.post(f"/api/workspaces/{workspace.id}/users/{admin.id}") assert response.status_code == 401 - assert db.query(UserWorkspace).count() == 0 + assert db.query(WorkspaceUser).count() == 0 def test_create_workspace_user_with_nonexistent_workspace_id( @@ -141,7 +141,7 @@ def test_create_workspace_user_with_nonexistent_workspace_id( response = client.post(f"/api/workspaces/{uuid4()}/users/{admin.id}", headers=admin_auth_header) assert response.status_code == 404 - assert db.query(UserWorkspace).count() == 0 + assert db.query(WorkspaceUser).count() == 0 def test_create_workspace_user_with_nonexistent_user_id(client: TestClient, db: Session, admin_auth_header: dict): @@ -150,55 +150,58 @@ def test_create_workspace_user_with_nonexistent_user_id(client: TestClient, db: response = client.post(f"/api/workspaces/{workspace.id}/users/{uuid4()}", headers=admin_auth_header) assert response.status_code == 404 - assert db.query(UserWorkspace).count() == 0 + assert db.query(WorkspaceUser).count() == 0 def test_delete_workspace_user(client: TestClient, db: Session, admin_auth_header: dict): - user_workspace = UserWorkspaceFactory.create( - user_id=UserFactory.create().id, workspace_id=WorkspaceFactory.create().id + workspace_user = WorkspaceUserFactory.create( + workspace_id=WorkspaceFactory.create().id, + user_id=UserFactory.create().id, ) response = client.delete( - f"/api/workspaces/{user_workspace.workspace_id}/users/{user_workspace.user_id}", headers=admin_auth_header + f"/api/workspaces/{workspace_user.workspace_id}/users/{workspace_user.user_id}", headers=admin_auth_header ) assert response.status_code == 200 - assert db.query(UserWorkspace).count() == 0 + assert db.query(WorkspaceUser).count() == 0 response_body = response.json() - assert response_body["id"] == str(user_workspace.user_id) + assert response_body["id"] == str(workspace_user.user_id) def test_delete_workspace_user_without_authentication(client: TestClient, db: Session): - user_workspace = UserWorkspaceFactory.create( - user_id=UserFactory.create().id, workspace_id=WorkspaceFactory.create().id + workspace_user = WorkspaceUserFactory.create( + workspace_id=WorkspaceFactory.create().id, + user_id=UserFactory.create().id, ) - response = client.delete(f"/api/workspaces/{user_workspace.workspace_id}/users/{user_workspace.user_id}") + response = client.delete(f"/api/workspaces/{workspace_user.workspace_id}/users/{workspace_user.user_id}") assert response.status_code == 401 - assert db.query(UserWorkspace).count() == 1 + assert db.query(WorkspaceUser).count() == 1 def test_delete_workspace_user_with_nonexistent_workspace_id(client: TestClient, db: Session, admin_auth_header: dict): - user_workspace = UserWorkspaceFactory.create( - user_id=UserFactory.create().id, workspace_id=WorkspaceFactory.create().id + workspace_user = WorkspaceUserFactory.create( + workspace_id=WorkspaceFactory.create().id, user_id=UserFactory.create().id ) - response = client.delete(f"/api/workspaces/{uuid4()}/users/{user_workspace.user_id}", headers=admin_auth_header) + response = client.delete(f"/api/workspaces/{uuid4()}/users/{workspace_user.user_id}", headers=admin_auth_header) assert response.status_code == 404 - assert db.query(UserWorkspace).count() == 1 + assert db.query(WorkspaceUser).count() == 1 def test_delete_workspace_user_with_nonexistent_user_id(client: TestClient, db: Session, admin_auth_header: dict): - user_workspace = UserWorkspaceFactory.create( - user_id=UserFactory.create().id, workspace_id=WorkspaceFactory.create().id + workspace_user = WorkspaceUserFactory.create( + workspace_id=WorkspaceFactory.create().id, + user_id=UserFactory.create().id, ) response = client.delete( - f"/api/workspaces/{user_workspace.workspace_id}/users/{uuid4()}", headers=admin_auth_header + f"/api/workspaces/{workspace_user.workspace_id}/users/{uuid4()}", headers=admin_auth_header ) assert response.status_code == 404 - assert db.query(UserWorkspace).count() == 1 + assert db.query(WorkspaceUser).count() == 1