Skip to content

Commit

Permalink
feat: handle PermissionError in CLI (#3717)
Browse files Browse the repository at this point in the history
# Description

This PR adds a way to handle the `PermissionError`s that could be raised
when executing a command with a logged in user that doesn't have the
required permissions because of his role. In addition, the module
`argilla.tasks.async_typer` has been renamed to
`argilla.tasks.typer_ext` (Typer extension) and the `AsyncTyper` class
has been renamed to `ArgillaTyper`.

**Type of change**

- [x] New feature (non-breaking change which adds functionality)

**How Has This Been Tested**

In a local development environment:

- [x] Login with an API Key linked to an annotator account
- [x] Trying to execute a command that required to be owner or admin
printed the message describing that the logged in user doesn't have
enough permissions

**Checklist**

- [ ] I added relevant documentation
- [x] I followed the style guidelines of this project
- [x] I did a self-review of my code
- [ ] I made corresponding changes to the documentation
- [x] My changes generate no new warnings
- [ ] I have added tests that prove my fix is effective or that my
feature works
- [ ] I filled out [the contributor form](https://tally.so/r/n9XrxK)
(see text above)
- [x] I have added relevant notes to the `CHANGELOG.md` file (See
https://keepachangelog.com/)
  • Loading branch information
gabrielmbmb authored Sep 6, 2023
1 parent b7338cb commit dab26b4
Show file tree
Hide file tree
Showing 12 changed files with 68 additions and 24 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ These are the section headers that we use:
- Added `info` command to get info about the used Argilla client and server ([#3707](https://github.com/argilla-io/argilla/pull/3707)).
- Added `datasets delete` command to delete a `FeedbackDataset` from Argilla ([#3703](https://github.com/argilla-io/argilla/pull/3703)).
- Added `created_at` and `updated_at` properties to `RemoteFeedbackDataset` and `FilteredRemoteFeedbackDataset` ([#3709](https://github.com/argilla-io/argilla/pull/3709)).
- Added handling `PermissionError` when executing a command with a logged in user with not enough permissions ([#3717](https://github.com/argilla-io/argilla/pull/3717)).
- Added `workspaces add-user` command to add a user to workspace ([#3712](https://github.com/argilla-io/argilla/pull/3712)).

### Changed
Expand Down
24 changes: 22 additions & 2 deletions src/argilla/tasks/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,32 @@
whoami_app,
workspaces_app,
)
from argilla.tasks.async_typer import AsyncTyper
from argilla.tasks.typer_ext import ArgillaTyper
from argilla.utils.dependency import is_package_with_extras_installed

warnings.simplefilter("ignore", UserWarning)

app = AsyncTyper(rich_help_panel=True, help="Argilla CLI", no_args_is_help=True)
app = ArgillaTyper(help="Argilla CLI", no_args_is_help=True)


@app.error_handler(PermissionError)
def handler_permission_error(e: PermissionError) -> None:
import sys

from rich.console import Console

from argilla.tasks.rich import get_argilla_themed_panel

panel = get_argilla_themed_panel(
"Logged in user doesn't have enough permissions to execute this command",
title="Not enough permissions",
title_align="left",
success=False,
)

Console().print(panel)
sys.exit(1)


app.add_typer(datasets_app, name="datasets")
app.add_typer(info_app, name="info")
Expand Down
6 changes: 4 additions & 2 deletions src/argilla/tasks/rich.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,5 +28,7 @@ def get_argilla_themed_table(title: str, **kwargs: Any) -> Table:
return Table(title=title, border_style=_ARGILLA_BORDER_STYLE, **kwargs)


def get_argilla_themed_panel(renderable: "RenderableType", **kwargs: Any) -> Panel:
return Panel(renderable=renderable, border_style=_ARGILLA_BORDER_STYLE, **kwargs)
def get_argilla_themed_panel(renderable: "RenderableType", title: str, success: bool = True, **kwargs: Any) -> Panel:
if success:
title = f"[green]{title}"
return Panel(renderable=renderable, border_style=_ARGILLA_BORDER_STYLE, title=title, **kwargs)
4 changes: 2 additions & 2 deletions src/argilla/tasks/server/database/migrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from alembic.util import CommandError

from argilla.server.database import ALEMBIC_CONFIG_FILE, TAGGED_REVISIONS
from argilla.tasks import async_typer
from argilla.tasks import typer_ext
from argilla.tasks.server.database import utils


Expand Down Expand Up @@ -49,4 +49,4 @@ def migrate_db(revision: Optional[str] = typer.Option(default="head", help="DB R


if __name__ == "__main__":
async_typer.run(migrate_db)
typer_ext.run(migrate_db)
4 changes: 2 additions & 2 deletions src/argilla/tasks/server/database/revisions.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import typer

from argilla.server.database import ALEMBIC_CONFIG_FILE, TAGGED_REVISIONS
from argilla.tasks import async_typer
from argilla.tasks import typer_ext
from argilla.tasks.server.database import utils


Expand All @@ -41,4 +41,4 @@ def revisions():


if __name__ == "__main__":
async_typer.run(revisions)
typer_ext.run(revisions)
4 changes: 2 additions & 2 deletions src/argilla/tasks/server/database/users/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from argilla.tasks.async_typer import AsyncTyper
from argilla.tasks.typer_ext import ArgillaTyper

from .create import create
from .create_default import create_default
from .migrate import migrate
from .update import update

app = AsyncTyper(
app = ArgillaTyper(
help="CLI commands for user and workspace management using the database connection", no_args_is_help=True
)

Expand Down
4 changes: 2 additions & 2 deletions src/argilla/tasks/server/database/users/create.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
UserCreate,
WorkspaceCreate,
)
from argilla.tasks import async_typer
from argilla.tasks import typer_ext
from argilla.tasks.server.database.users.utils import get_or_new_workspace

USER_API_KEY_MIN_LENGTH = 8
Expand Down Expand Up @@ -130,4 +130,4 @@ async def create(


if __name__ == "__main__":
async_typer.run(create)
typer_ext.run(create)
4 changes: 2 additions & 2 deletions src/argilla/tasks/server/database/users/create_default.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from argilla.server.contexts import accounts
from argilla.server.database import AsyncSessionLocal
from argilla.server.models import User, UserRole, Workspace
from argilla.tasks import async_typer
from argilla.tasks import typer_ext


async def create_default(
Expand Down Expand Up @@ -53,4 +53,4 @@ async def create_default(


if __name__ == "__main__":
async_typer.run(create_default)
typer_ext.run(create_default)
4 changes: 2 additions & 2 deletions src/argilla/tasks/server/database/users/migrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from argilla.server.models import User, UserRole
from argilla.server.security.auth_provider.local.settings import settings
from argilla.server.security.model import USER_USERNAME_REGEX, WORKSPACE_NAME_REGEX
from argilla.tasks import async_typer
from argilla.tasks import typer_ext
from argilla.tasks.server.database.users.utils import get_or_new_workspace

if TYPE_CHECKING:
Expand Down Expand Up @@ -111,4 +111,4 @@ async def migrate():


if __name__ == "__main__":
async_typer.run(migrate)
typer_ext.run(migrate)
4 changes: 2 additions & 2 deletions src/argilla/tasks/server/database/users/update.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from argilla.server.contexts import accounts
from argilla.server.database import AsyncSessionLocal
from argilla.server.models import UserRole
from argilla.tasks import async_typer
from argilla.tasks import typer_ext


async def update(
Expand Down Expand Up @@ -52,4 +52,4 @@ async def update(


if __name__ == "__main__":
async_typer.run(update)
typer_ext.run(update)
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import asyncio
import sys
from functools import wraps
from typing import Any, Callable, Coroutine, TypeVar
from typing import Any, Callable, Coroutine, Dict, Type, TypeVar

import typer

Expand All @@ -28,9 +28,13 @@
P = ParamSpec("P")
R = TypeVar("R")

HandleErrorFunc = Callable[[Exception], None]

# https://github.com/tiangolo/typer/issues/88#issuecomment-1613013597
class AsyncTyper(typer.Typer):

class ArgillaTyper(typer.Typer):
error_handlers: Dict[Type[Exception], HandleErrorFunc] = {}

# https://github.com/tiangolo/typer/issues/88#issuecomment-1613013597
def command(
self, *args: Any, **kwargs: Any
) -> Callable[[Callable[P, Coroutine[Any, Any, R]]], Callable[P, Coroutine[Any, Any, R]]]:
Expand All @@ -50,8 +54,25 @@ def sync_func(*_args: P.args, **_kwargs: P.kwargs) -> R:

return decorator

def error_handler(self, exc: Type[Exception]) -> Callable[[HandleErrorFunc], None]:
def decorator(func: HandleErrorFunc) -> None:
self.error_handlers[exc] = func

return decorator

def __call__(self, *args: Any, **kwargs: Any) -> Any:
try:
return super().__call__(*args, **kwargs)
except typer.Exit as e:
raise e
except Exception as e:
handler = self.error_handlers.get(type(e))
if handler is None:
raise e
handler(e)


def run(function: Callable[..., Coroutine[Any, Any, Any]]) -> None:
app = AsyncTyper(add_completion=False)
app = ArgillaTyper(add_completion=False)
app.command()(function)
app()
4 changes: 2 additions & 2 deletions tests/unit/tasks/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from tests.database import SyncTestSession

if TYPE_CHECKING:
from argilla.tasks.async_typer import AsyncTyper
from argilla.tasks.typer_ext import ArgillaTyper
from pytest_mock import MockerFixture
from sqlalchemy.engine import Connection
from sqlalchemy.orm import Session
Expand All @@ -44,7 +44,7 @@ def cli_runner() -> CliRunner:


@pytest.fixture(scope="session")
def cli() -> "AsyncTyper":
def cli() -> "ArgillaTyper":
return app


Expand Down

0 comments on commit dab26b4

Please sign in to comment.