From dab26b43374d85979f68726682e1de1a18bc13ec Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gabriel=20Mart=C3=ADn=20Bl=C3=A1zquez?= Date: Wed, 6 Sep 2023 15:36:06 +0200 Subject: [PATCH] feat: handle `PermissionError` in CLI (#3717) # Description This PR adds a way to handle the `PermissionError`s that could be raised when executing a command with a logged in user that doesn't have the required permissions because of his role. In addition, the module `argilla.tasks.async_typer` has been renamed to `argilla.tasks.typer_ext` (Typer extension) and the `AsyncTyper` class has been renamed to `ArgillaTyper`. **Type of change** - [x] New feature (non-breaking change which adds functionality) **How Has This Been Tested** In a local development environment: - [x] Login with an API Key linked to an annotator account - [x] Trying to execute a command that required to be owner or admin printed the message describing that the logged in user doesn't have enough permissions **Checklist** - [ ] I added relevant documentation - [x] I followed the style guidelines of this project - [x] I did a self-review of my code - [ ] I made corresponding changes to the documentation - [x] My changes generate no new warnings - [ ] I have added tests that prove my fix is effective or that my feature works - [ ] I filled out [the contributor form](https://tally.so/r/n9XrxK) (see text above) - [x] I have added relevant notes to the `CHANGELOG.md` file (See https://keepachangelog.com/) --- CHANGELOG.md | 1 + src/argilla/tasks/app.py | 24 +++++++++++++-- src/argilla/tasks/rich.py | 6 ++-- src/argilla/tasks/server/database/migrate.py | 4 +-- .../tasks/server/database/revisions.py | 4 +-- .../tasks/server/database/users/__main__.py | 4 +-- .../tasks/server/database/users/create.py | 4 +-- .../server/database/users/create_default.py | 4 +-- .../tasks/server/database/users/migrate.py | 4 +-- .../tasks/server/database/users/update.py | 4 +-- .../tasks/{async_typer.py => typer_ext.py} | 29 ++++++++++++++++--- tests/unit/tasks/conftest.py | 4 +-- 12 files changed, 68 insertions(+), 24 deletions(-) rename src/argilla/tasks/{async_typer.py => typer_ext.py} (64%) diff --git a/CHANGELOG.md b/CHANGELOG.md index 4c09007997..da79f8fe38 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -33,6 +33,7 @@ These are the section headers that we use: - Added `info` command to get info about the used Argilla client and server ([#3707](https://github.com/argilla-io/argilla/pull/3707)). - Added `datasets delete` command to delete a `FeedbackDataset` from Argilla ([#3703](https://github.com/argilla-io/argilla/pull/3703)). - Added `created_at` and `updated_at` properties to `RemoteFeedbackDataset` and `FilteredRemoteFeedbackDataset` ([#3709](https://github.com/argilla-io/argilla/pull/3709)). +- Added handling `PermissionError` when executing a command with a logged in user with not enough permissions ([#3717](https://github.com/argilla-io/argilla/pull/3717)). - Added `workspaces add-user` command to add a user to workspace ([#3712](https://github.com/argilla-io/argilla/pull/3712)). ### Changed diff --git a/src/argilla/tasks/app.py b/src/argilla/tasks/app.py index f23bfa0a57..351a55758b 100644 --- a/src/argilla/tasks/app.py +++ b/src/argilla/tasks/app.py @@ -24,12 +24,32 @@ whoami_app, workspaces_app, ) -from argilla.tasks.async_typer import AsyncTyper +from argilla.tasks.typer_ext import ArgillaTyper from argilla.utils.dependency import is_package_with_extras_installed warnings.simplefilter("ignore", UserWarning) -app = AsyncTyper(rich_help_panel=True, help="Argilla CLI", no_args_is_help=True) +app = ArgillaTyper(help="Argilla CLI", no_args_is_help=True) + + +@app.error_handler(PermissionError) +def handler_permission_error(e: PermissionError) -> None: + import sys + + from rich.console import Console + + from argilla.tasks.rich import get_argilla_themed_panel + + panel = get_argilla_themed_panel( + "Logged in user doesn't have enough permissions to execute this command", + title="Not enough permissions", + title_align="left", + success=False, + ) + + Console().print(panel) + sys.exit(1) + app.add_typer(datasets_app, name="datasets") app.add_typer(info_app, name="info") diff --git a/src/argilla/tasks/rich.py b/src/argilla/tasks/rich.py index 46ed701f9c..851c1e5ba1 100644 --- a/src/argilla/tasks/rich.py +++ b/src/argilla/tasks/rich.py @@ -28,5 +28,7 @@ def get_argilla_themed_table(title: str, **kwargs: Any) -> Table: return Table(title=title, border_style=_ARGILLA_BORDER_STYLE, **kwargs) -def get_argilla_themed_panel(renderable: "RenderableType", **kwargs: Any) -> Panel: - return Panel(renderable=renderable, border_style=_ARGILLA_BORDER_STYLE, **kwargs) +def get_argilla_themed_panel(renderable: "RenderableType", title: str, success: bool = True, **kwargs: Any) -> Panel: + if success: + title = f"[green]{title}" + return Panel(renderable=renderable, border_style=_ARGILLA_BORDER_STYLE, title=title, **kwargs) diff --git a/src/argilla/tasks/server/database/migrate.py b/src/argilla/tasks/server/database/migrate.py index c73555dc2b..52112b87a3 100644 --- a/src/argilla/tasks/server/database/migrate.py +++ b/src/argilla/tasks/server/database/migrate.py @@ -21,7 +21,7 @@ from alembic.util import CommandError from argilla.server.database import ALEMBIC_CONFIG_FILE, TAGGED_REVISIONS -from argilla.tasks import async_typer +from argilla.tasks import typer_ext from argilla.tasks.server.database import utils @@ -49,4 +49,4 @@ def migrate_db(revision: Optional[str] = typer.Option(default="head", help="DB R if __name__ == "__main__": - async_typer.run(migrate_db) + typer_ext.run(migrate_db) diff --git a/src/argilla/tasks/server/database/revisions.py b/src/argilla/tasks/server/database/revisions.py index 922184abb0..eb2f4c99e8 100644 --- a/src/argilla/tasks/server/database/revisions.py +++ b/src/argilla/tasks/server/database/revisions.py @@ -16,7 +16,7 @@ import typer from argilla.server.database import ALEMBIC_CONFIG_FILE, TAGGED_REVISIONS -from argilla.tasks import async_typer +from argilla.tasks import typer_ext from argilla.tasks.server.database import utils @@ -41,4 +41,4 @@ def revisions(): if __name__ == "__main__": - async_typer.run(revisions) + typer_ext.run(revisions) diff --git a/src/argilla/tasks/server/database/users/__main__.py b/src/argilla/tasks/server/database/users/__main__.py index 7f0ac33050..85362237dd 100644 --- a/src/argilla/tasks/server/database/users/__main__.py +++ b/src/argilla/tasks/server/database/users/__main__.py @@ -12,14 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -from argilla.tasks.async_typer import AsyncTyper +from argilla.tasks.typer_ext import ArgillaTyper from .create import create from .create_default import create_default from .migrate import migrate from .update import update -app = AsyncTyper( +app = ArgillaTyper( help="CLI commands for user and workspace management using the database connection", no_args_is_help=True ) diff --git a/src/argilla/tasks/server/database/users/create.py b/src/argilla/tasks/server/database/users/create.py index f08d307855..3dbedc9ed9 100644 --- a/src/argilla/tasks/server/database/users/create.py +++ b/src/argilla/tasks/server/database/users/create.py @@ -25,7 +25,7 @@ UserCreate, WorkspaceCreate, ) -from argilla.tasks import async_typer +from argilla.tasks import typer_ext from argilla.tasks.server.database.users.utils import get_or_new_workspace USER_API_KEY_MIN_LENGTH = 8 @@ -130,4 +130,4 @@ async def create( if __name__ == "__main__": - async_typer.run(create) + typer_ext.run(create) diff --git a/src/argilla/tasks/server/database/users/create_default.py b/src/argilla/tasks/server/database/users/create_default.py index 88c0caff0f..df4cce60ba 100644 --- a/src/argilla/tasks/server/database/users/create_default.py +++ b/src/argilla/tasks/server/database/users/create_default.py @@ -18,7 +18,7 @@ from argilla.server.contexts import accounts from argilla.server.database import AsyncSessionLocal from argilla.server.models import User, UserRole, Workspace -from argilla.tasks import async_typer +from argilla.tasks import typer_ext async def create_default( @@ -53,4 +53,4 @@ async def create_default( if __name__ == "__main__": - async_typer.run(create_default) + typer_ext.run(create_default) diff --git a/src/argilla/tasks/server/database/users/migrate.py b/src/argilla/tasks/server/database/users/migrate.py index 6efefd6c8b..17339df71d 100644 --- a/src/argilla/tasks/server/database/users/migrate.py +++ b/src/argilla/tasks/server/database/users/migrate.py @@ -22,7 +22,7 @@ from argilla.server.models import User, UserRole from argilla.server.security.auth_provider.local.settings import settings from argilla.server.security.model import USER_USERNAME_REGEX, WORKSPACE_NAME_REGEX -from argilla.tasks import async_typer +from argilla.tasks import typer_ext from argilla.tasks.server.database.users.utils import get_or_new_workspace if TYPE_CHECKING: @@ -111,4 +111,4 @@ async def migrate(): if __name__ == "__main__": - async_typer.run(migrate) + typer_ext.run(migrate) diff --git a/src/argilla/tasks/server/database/users/update.py b/src/argilla/tasks/server/database/users/update.py index 8882779ab4..d9d26a700c 100644 --- a/src/argilla/tasks/server/database/users/update.py +++ b/src/argilla/tasks/server/database/users/update.py @@ -17,7 +17,7 @@ from argilla.server.contexts import accounts from argilla.server.database import AsyncSessionLocal from argilla.server.models import UserRole -from argilla.tasks import async_typer +from argilla.tasks import typer_ext async def update( @@ -52,4 +52,4 @@ async def update( if __name__ == "__main__": - async_typer.run(update) + typer_ext.run(update) diff --git a/src/argilla/tasks/async_typer.py b/src/argilla/tasks/typer_ext.py similarity index 64% rename from src/argilla/tasks/async_typer.py rename to src/argilla/tasks/typer_ext.py index cac3855047..701b5830fa 100644 --- a/src/argilla/tasks/async_typer.py +++ b/src/argilla/tasks/typer_ext.py @@ -15,7 +15,7 @@ import asyncio import sys from functools import wraps -from typing import Any, Callable, Coroutine, TypeVar +from typing import Any, Callable, Coroutine, Dict, Type, TypeVar import typer @@ -28,9 +28,13 @@ P = ParamSpec("P") R = TypeVar("R") +HandleErrorFunc = Callable[[Exception], None] -# https://github.com/tiangolo/typer/issues/88#issuecomment-1613013597 -class AsyncTyper(typer.Typer): + +class ArgillaTyper(typer.Typer): + error_handlers: Dict[Type[Exception], HandleErrorFunc] = {} + + # https://github.com/tiangolo/typer/issues/88#issuecomment-1613013597 def command( self, *args: Any, **kwargs: Any ) -> Callable[[Callable[P, Coroutine[Any, Any, R]]], Callable[P, Coroutine[Any, Any, R]]]: @@ -50,8 +54,25 @@ def sync_func(*_args: P.args, **_kwargs: P.kwargs) -> R: return decorator + def error_handler(self, exc: Type[Exception]) -> Callable[[HandleErrorFunc], None]: + def decorator(func: HandleErrorFunc) -> None: + self.error_handlers[exc] = func + + return decorator + + def __call__(self, *args: Any, **kwargs: Any) -> Any: + try: + return super().__call__(*args, **kwargs) + except typer.Exit as e: + raise e + except Exception as e: + handler = self.error_handlers.get(type(e)) + if handler is None: + raise e + handler(e) + def run(function: Callable[..., Coroutine[Any, Any, Any]]) -> None: - app = AsyncTyper(add_completion=False) + app = ArgillaTyper(add_completion=False) app.command()(function) app() diff --git a/tests/unit/tasks/conftest.py b/tests/unit/tasks/conftest.py index 8d0d0a9342..64d6a6d1cc 100644 --- a/tests/unit/tasks/conftest.py +++ b/tests/unit/tasks/conftest.py @@ -32,7 +32,7 @@ from tests.database import SyncTestSession if TYPE_CHECKING: - from argilla.tasks.async_typer import AsyncTyper + from argilla.tasks.typer_ext import ArgillaTyper from pytest_mock import MockerFixture from sqlalchemy.engine import Connection from sqlalchemy.orm import Session @@ -44,7 +44,7 @@ def cli_runner() -> CliRunner: @pytest.fixture(scope="session") -def cli() -> "AsyncTyper": +def cli() -> "ArgillaTyper": return app