From edb0a6648337a0ec3ed661bc0362e13f92c872e9 Mon Sep 17 00:00:00 2001 From: gabrielmbmb Date: Wed, 6 Sep 2023 12:29:25 +0200 Subject: [PATCH 1/7] feat: add `error_handler` method --- src/argilla/tasks/async_typer.py | 29 +++++++++++++++++++++++++---- 1 file changed, 25 insertions(+), 4 deletions(-) diff --git a/src/argilla/tasks/async_typer.py b/src/argilla/tasks/async_typer.py index cac3855047..701b5830fa 100644 --- a/src/argilla/tasks/async_typer.py +++ b/src/argilla/tasks/async_typer.py @@ -15,7 +15,7 @@ import asyncio import sys from functools import wraps -from typing import Any, Callable, Coroutine, TypeVar +from typing import Any, Callable, Coroutine, Dict, Type, TypeVar import typer @@ -28,9 +28,13 @@ P = ParamSpec("P") R = TypeVar("R") +HandleErrorFunc = Callable[[Exception], None] -# https://github.com/tiangolo/typer/issues/88#issuecomment-1613013597 -class AsyncTyper(typer.Typer): + +class ArgillaTyper(typer.Typer): + error_handlers: Dict[Type[Exception], HandleErrorFunc] = {} + + # https://github.com/tiangolo/typer/issues/88#issuecomment-1613013597 def command( self, *args: Any, **kwargs: Any ) -> Callable[[Callable[P, Coroutine[Any, Any, R]]], Callable[P, Coroutine[Any, Any, R]]]: @@ -50,8 +54,25 @@ def sync_func(*_args: P.args, **_kwargs: P.kwargs) -> R: return decorator + def error_handler(self, exc: Type[Exception]) -> Callable[[HandleErrorFunc], None]: + def decorator(func: HandleErrorFunc) -> None: + self.error_handlers[exc] = func + + return decorator + + def __call__(self, *args: Any, **kwargs: Any) -> Any: + try: + return super().__call__(*args, **kwargs) + except typer.Exit as e: + raise e + except Exception as e: + handler = self.error_handlers.get(type(e)) + if handler is None: + raise e + handler(e) + def run(function: Callable[..., Coroutine[Any, Any, Any]]) -> None: - app = AsyncTyper(add_completion=False) + app = ArgillaTyper(add_completion=False) app.command()(function) app() From 6a752474d12e4633432eecadeeeb5d41f90f69f1 Mon Sep 17 00:00:00 2001 From: gabrielmbmb Date: Wed, 6 Sep 2023 13:16:38 +0200 Subject: [PATCH 2/7] refactor: rename module to `typer_ext` --- src/argilla/tasks/server/database/users/__main__.py | 4 ++-- src/argilla/tasks/{async_typer.py => typer_ext.py} | 0 tests/unit/tasks/conftest.py | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) rename src/argilla/tasks/{async_typer.py => typer_ext.py} (100%) diff --git a/src/argilla/tasks/server/database/users/__main__.py b/src/argilla/tasks/server/database/users/__main__.py index 7f0ac33050..85362237dd 100644 --- a/src/argilla/tasks/server/database/users/__main__.py +++ b/src/argilla/tasks/server/database/users/__main__.py @@ -12,14 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -from argilla.tasks.async_typer import AsyncTyper +from argilla.tasks.typer_ext import ArgillaTyper from .create import create from .create_default import create_default from .migrate import migrate from .update import update -app = AsyncTyper( +app = ArgillaTyper( help="CLI commands for user and workspace management using the database connection", no_args_is_help=True ) diff --git a/src/argilla/tasks/async_typer.py b/src/argilla/tasks/typer_ext.py similarity index 100% rename from src/argilla/tasks/async_typer.py rename to src/argilla/tasks/typer_ext.py diff --git a/tests/unit/tasks/conftest.py b/tests/unit/tasks/conftest.py index 8d0d0a9342..64d6a6d1cc 100644 --- a/tests/unit/tasks/conftest.py +++ b/tests/unit/tasks/conftest.py @@ -32,7 +32,7 @@ from tests.database import SyncTestSession if TYPE_CHECKING: - from argilla.tasks.async_typer import AsyncTyper + from argilla.tasks.typer_ext import ArgillaTyper from pytest_mock import MockerFixture from sqlalchemy.engine import Connection from sqlalchemy.orm import Session @@ -44,7 +44,7 @@ def cli_runner() -> CliRunner: @pytest.fixture(scope="session") -def cli() -> "AsyncTyper": +def cli() -> "ArgillaTyper": return app From d3fd452dedacb7b121c2e28514bd70f78ab5ecd2 Mon Sep 17 00:00:00 2001 From: gabrielmbmb Date: Wed, 6 Sep 2023 13:17:11 +0200 Subject: [PATCH 3/7] feat: add `PermissionError` error handler --- src/argilla/tasks/app.py | 22 ++++++++++++++++++++-- 1 file changed, 20 insertions(+), 2 deletions(-) diff --git a/src/argilla/tasks/app.py b/src/argilla/tasks/app.py index f23bfa0a57..17dc7435cf 100644 --- a/src/argilla/tasks/app.py +++ b/src/argilla/tasks/app.py @@ -14,6 +14,8 @@ import warnings +import typer + from argilla.tasks import ( datasets_app, info_app, @@ -24,12 +26,28 @@ whoami_app, workspaces_app, ) -from argilla.tasks.async_typer import AsyncTyper +from argilla.tasks.typer_ext import ArgillaTyper from argilla.utils.dependency import is_package_with_extras_installed warnings.simplefilter("ignore", UserWarning) -app = AsyncTyper(rich_help_panel=True, help="Argilla CLI", no_args_is_help=True) +app = ArgillaTyper(help="Argilla CLI", no_args_is_help=True) + + +@app.error_handler(PermissionError) +def handler_permission_error(e: PermissionError) -> None: + from rich.console import Console + + from argilla.tasks.rich import get_argilla_themed_panel + + panel = get_argilla_themed_panel( + "Logged in user doesn't have enough permissions to execute this command.", + title="Not enough permissions", + success=False, + ) + + Console().print(panel) + app.add_typer(datasets_app, name="datasets") app.add_typer(info_app, name="info") From af421b4078ea5099c9851a5ee8a8f19582ccda51 Mon Sep 17 00:00:00 2001 From: gabrielmbmb Date: Wed, 6 Sep 2023 13:21:42 +0200 Subject: [PATCH 4/7] fix: `ImportError` after rename --- src/argilla/tasks/server/database/migrate.py | 4 ++-- src/argilla/tasks/server/database/revisions.py | 4 ++-- src/argilla/tasks/server/database/users/create.py | 4 ++-- src/argilla/tasks/server/database/users/create_default.py | 4 ++-- src/argilla/tasks/server/database/users/migrate.py | 4 ++-- src/argilla/tasks/server/database/users/update.py | 4 ++-- 6 files changed, 12 insertions(+), 12 deletions(-) diff --git a/src/argilla/tasks/server/database/migrate.py b/src/argilla/tasks/server/database/migrate.py index c73555dc2b..52112b87a3 100644 --- a/src/argilla/tasks/server/database/migrate.py +++ b/src/argilla/tasks/server/database/migrate.py @@ -21,7 +21,7 @@ from alembic.util import CommandError from argilla.server.database import ALEMBIC_CONFIG_FILE, TAGGED_REVISIONS -from argilla.tasks import async_typer +from argilla.tasks import typer_ext from argilla.tasks.server.database import utils @@ -49,4 +49,4 @@ def migrate_db(revision: Optional[str] = typer.Option(default="head", help="DB R if __name__ == "__main__": - async_typer.run(migrate_db) + typer_ext.run(migrate_db) diff --git a/src/argilla/tasks/server/database/revisions.py b/src/argilla/tasks/server/database/revisions.py index 922184abb0..eb2f4c99e8 100644 --- a/src/argilla/tasks/server/database/revisions.py +++ b/src/argilla/tasks/server/database/revisions.py @@ -16,7 +16,7 @@ import typer from argilla.server.database import ALEMBIC_CONFIG_FILE, TAGGED_REVISIONS -from argilla.tasks import async_typer +from argilla.tasks import typer_ext from argilla.tasks.server.database import utils @@ -41,4 +41,4 @@ def revisions(): if __name__ == "__main__": - async_typer.run(revisions) + typer_ext.run(revisions) diff --git a/src/argilla/tasks/server/database/users/create.py b/src/argilla/tasks/server/database/users/create.py index f08d307855..3dbedc9ed9 100644 --- a/src/argilla/tasks/server/database/users/create.py +++ b/src/argilla/tasks/server/database/users/create.py @@ -25,7 +25,7 @@ UserCreate, WorkspaceCreate, ) -from argilla.tasks import async_typer +from argilla.tasks import typer_ext from argilla.tasks.server.database.users.utils import get_or_new_workspace USER_API_KEY_MIN_LENGTH = 8 @@ -130,4 +130,4 @@ async def create( if __name__ == "__main__": - async_typer.run(create) + typer_ext.run(create) diff --git a/src/argilla/tasks/server/database/users/create_default.py b/src/argilla/tasks/server/database/users/create_default.py index 88c0caff0f..df4cce60ba 100644 --- a/src/argilla/tasks/server/database/users/create_default.py +++ b/src/argilla/tasks/server/database/users/create_default.py @@ -18,7 +18,7 @@ from argilla.server.contexts import accounts from argilla.server.database import AsyncSessionLocal from argilla.server.models import User, UserRole, Workspace -from argilla.tasks import async_typer +from argilla.tasks import typer_ext async def create_default( @@ -53,4 +53,4 @@ async def create_default( if __name__ == "__main__": - async_typer.run(create_default) + typer_ext.run(create_default) diff --git a/src/argilla/tasks/server/database/users/migrate.py b/src/argilla/tasks/server/database/users/migrate.py index 6efefd6c8b..17339df71d 100644 --- a/src/argilla/tasks/server/database/users/migrate.py +++ b/src/argilla/tasks/server/database/users/migrate.py @@ -22,7 +22,7 @@ from argilla.server.models import User, UserRole from argilla.server.security.auth_provider.local.settings import settings from argilla.server.security.model import USER_USERNAME_REGEX, WORKSPACE_NAME_REGEX -from argilla.tasks import async_typer +from argilla.tasks import typer_ext from argilla.tasks.server.database.users.utils import get_or_new_workspace if TYPE_CHECKING: @@ -111,4 +111,4 @@ async def migrate(): if __name__ == "__main__": - async_typer.run(migrate) + typer_ext.run(migrate) diff --git a/src/argilla/tasks/server/database/users/update.py b/src/argilla/tasks/server/database/users/update.py index 8882779ab4..d9d26a700c 100644 --- a/src/argilla/tasks/server/database/users/update.py +++ b/src/argilla/tasks/server/database/users/update.py @@ -17,7 +17,7 @@ from argilla.server.contexts import accounts from argilla.server.database import AsyncSessionLocal from argilla.server.models import UserRole -from argilla.tasks import async_typer +from argilla.tasks import typer_ext async def update( @@ -52,4 +52,4 @@ async def update( if __name__ == "__main__": - async_typer.run(update) + typer_ext.run(update) From 9d9623e0ed9cc3713422af1b90b099cb2f431435 Mon Sep 17 00:00:00 2001 From: gabrielmbmb Date: Wed, 6 Sep 2023 13:24:50 +0200 Subject: [PATCH 5/7] feat: add `success` argument --- src/argilla/tasks/rich.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/argilla/tasks/rich.py b/src/argilla/tasks/rich.py index 46ed701f9c..851c1e5ba1 100644 --- a/src/argilla/tasks/rich.py +++ b/src/argilla/tasks/rich.py @@ -28,5 +28,7 @@ def get_argilla_themed_table(title: str, **kwargs: Any) -> Table: return Table(title=title, border_style=_ARGILLA_BORDER_STYLE, **kwargs) -def get_argilla_themed_panel(renderable: "RenderableType", **kwargs: Any) -> Panel: - return Panel(renderable=renderable, border_style=_ARGILLA_BORDER_STYLE, **kwargs) +def get_argilla_themed_panel(renderable: "RenderableType", title: str, success: bool = True, **kwargs: Any) -> Panel: + if success: + title = f"[green]{title}" + return Panel(renderable=renderable, border_style=_ARGILLA_BORDER_STYLE, title=title, **kwargs) From 3eaad962d12428b4f38c931a6055ce3c7e8981a8 Mon Sep 17 00:00:00 2001 From: gabrielmbmb Date: Wed, 6 Sep 2023 14:04:30 +0200 Subject: [PATCH 6/7] docs: update changelog --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 8beba27d14..9b240a5101 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -33,6 +33,7 @@ These are the section headers that we use: - Added `info` command to get info about the used Argilla client and server ([#3707](https://github.com/argilla-io/argilla/pull/3707)). - Added `datasets delete` command to delete a `FeedbackDataset` from Argilla ([#3703](https://github.com/argilla-io/argilla/pull/3703)). - Added `created_at` and `updated_at` properties to `RemoteFeedbackDataset` and `FilteredRemoteFeedbackDataset` ([#3709](https://github.com/argilla-io/argilla/pull/3709)). +- Added handling `PermissionError` when executing a command with a logged in user with not enough permissions ([#3717](https://github.com/argilla-io/argilla/pull/3717)). ### Changed From 0ff36b5375d3585905c2a5f58fae58f03fd4d0f2 Mon Sep 17 00:00:00 2001 From: gabrielmbmb Date: Wed, 6 Sep 2023 14:09:32 +0200 Subject: [PATCH 7/7] feat: add exit code --- src/argilla/tasks/app.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/argilla/tasks/app.py b/src/argilla/tasks/app.py index 17dc7435cf..351a55758b 100644 --- a/src/argilla/tasks/app.py +++ b/src/argilla/tasks/app.py @@ -14,8 +14,6 @@ import warnings -import typer - from argilla.tasks import ( datasets_app, info_app, @@ -36,17 +34,21 @@ @app.error_handler(PermissionError) def handler_permission_error(e: PermissionError) -> None: + import sys + from rich.console import Console from argilla.tasks.rich import get_argilla_themed_panel panel = get_argilla_themed_panel( - "Logged in user doesn't have enough permissions to execute this command.", + "Logged in user doesn't have enough permissions to execute this command", title="Not enough permissions", + title_align="left", success=False, ) Console().print(panel) + sys.exit(1) app.add_typer(datasets_app, name="datasets")