Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: handle PermissionError in CLI #3717

Merged
merged 8 commits into from
Sep 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ These are the section headers that we use:
- Added `info` command to get info about the used Argilla client and server ([#3707](https://github.com/argilla-io/argilla/pull/3707)).
- Added `datasets delete` command to delete a `FeedbackDataset` from Argilla ([#3703](https://github.com/argilla-io/argilla/pull/3703)).
- Added `created_at` and `updated_at` properties to `RemoteFeedbackDataset` and `FilteredRemoteFeedbackDataset` ([#3709](https://github.com/argilla-io/argilla/pull/3709)).
- Added handling `PermissionError` when executing a command with a logged in user with not enough permissions ([#3717](https://github.com/argilla-io/argilla/pull/3717)).
- Added `workspaces add-user` command to add a user to workspace ([#3712](https://github.com/argilla-io/argilla/pull/3712)).

### Changed
Expand Down
24 changes: 22 additions & 2 deletions src/argilla/tasks/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,32 @@
whoami_app,
workspaces_app,
)
from argilla.tasks.async_typer import AsyncTyper
from argilla.tasks.typer_ext import ArgillaTyper
from argilla.utils.dependency import is_package_with_extras_installed

warnings.simplefilter("ignore", UserWarning)

app = AsyncTyper(rich_help_panel=True, help="Argilla CLI", no_args_is_help=True)
app = ArgillaTyper(help="Argilla CLI", no_args_is_help=True)


@app.error_handler(PermissionError)
def handler_permission_error(e: PermissionError) -> None:
import sys

from rich.console import Console

from argilla.tasks.rich import get_argilla_themed_panel

panel = get_argilla_themed_panel(
"Logged in user doesn't have enough permissions to execute this command",
title="Not enough permissions",
title_align="left",
success=False,
)

Console().print(panel)
sys.exit(1)


app.add_typer(datasets_app, name="datasets")
app.add_typer(info_app, name="info")
Expand Down
6 changes: 4 additions & 2 deletions src/argilla/tasks/rich.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,5 +28,7 @@ def get_argilla_themed_table(title: str, **kwargs: Any) -> Table:
return Table(title=title, border_style=_ARGILLA_BORDER_STYLE, **kwargs)


def get_argilla_themed_panel(renderable: "RenderableType", **kwargs: Any) -> Panel:
return Panel(renderable=renderable, border_style=_ARGILLA_BORDER_STYLE, **kwargs)
def get_argilla_themed_panel(renderable: "RenderableType", title: str, success: bool = True, **kwargs: Any) -> Panel:
if success:
title = f"[green]{title}"
return Panel(renderable=renderable, border_style=_ARGILLA_BORDER_STYLE, title=title, **kwargs)
4 changes: 2 additions & 2 deletions src/argilla/tasks/server/database/migrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from alembic.util import CommandError

from argilla.server.database import ALEMBIC_CONFIG_FILE, TAGGED_REVISIONS
from argilla.tasks import async_typer
from argilla.tasks import typer_ext
from argilla.tasks.server.database import utils


Expand Down Expand Up @@ -49,4 +49,4 @@ def migrate_db(revision: Optional[str] = typer.Option(default="head", help="DB R


if __name__ == "__main__":
async_typer.run(migrate_db)
typer_ext.run(migrate_db)
4 changes: 2 additions & 2 deletions src/argilla/tasks/server/database/revisions.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import typer

from argilla.server.database import ALEMBIC_CONFIG_FILE, TAGGED_REVISIONS
from argilla.tasks import async_typer
from argilla.tasks import typer_ext
from argilla.tasks.server.database import utils


Expand All @@ -41,4 +41,4 @@ def revisions():


if __name__ == "__main__":
async_typer.run(revisions)
typer_ext.run(revisions)
4 changes: 2 additions & 2 deletions src/argilla/tasks/server/database/users/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from argilla.tasks.async_typer import AsyncTyper
from argilla.tasks.typer_ext import ArgillaTyper

from .create import create
from .create_default import create_default
from .migrate import migrate
from .update import update

app = AsyncTyper(
app = ArgillaTyper(
help="CLI commands for user and workspace management using the database connection", no_args_is_help=True
)

Expand Down
4 changes: 2 additions & 2 deletions src/argilla/tasks/server/database/users/create.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
UserCreate,
WorkspaceCreate,
)
from argilla.tasks import async_typer
from argilla.tasks import typer_ext
from argilla.tasks.server.database.users.utils import get_or_new_workspace

USER_API_KEY_MIN_LENGTH = 8
Expand Down Expand Up @@ -130,4 +130,4 @@ async def create(


if __name__ == "__main__":
async_typer.run(create)
typer_ext.run(create)
4 changes: 2 additions & 2 deletions src/argilla/tasks/server/database/users/create_default.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from argilla.server.contexts import accounts
from argilla.server.database import AsyncSessionLocal
from argilla.server.models import User, UserRole, Workspace
from argilla.tasks import async_typer
from argilla.tasks import typer_ext


async def create_default(
Expand Down Expand Up @@ -53,4 +53,4 @@ async def create_default(


if __name__ == "__main__":
async_typer.run(create_default)
typer_ext.run(create_default)
4 changes: 2 additions & 2 deletions src/argilla/tasks/server/database/users/migrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from argilla.server.models import User, UserRole
from argilla.server.security.auth_provider.local.settings import settings
from argilla.server.security.model import USER_USERNAME_REGEX, WORKSPACE_NAME_REGEX
from argilla.tasks import async_typer
from argilla.tasks import typer_ext
from argilla.tasks.server.database.users.utils import get_or_new_workspace

if TYPE_CHECKING:
Expand Down Expand Up @@ -111,4 +111,4 @@ async def migrate():


if __name__ == "__main__":
async_typer.run(migrate)
typer_ext.run(migrate)
4 changes: 2 additions & 2 deletions src/argilla/tasks/server/database/users/update.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from argilla.server.contexts import accounts
from argilla.server.database import AsyncSessionLocal
from argilla.server.models import UserRole
from argilla.tasks import async_typer
from argilla.tasks import typer_ext


async def update(
Expand Down Expand Up @@ -52,4 +52,4 @@ async def update(


if __name__ == "__main__":
async_typer.run(update)
typer_ext.run(update)
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import asyncio
import sys
from functools import wraps
from typing import Any, Callable, Coroutine, TypeVar
from typing import Any, Callable, Coroutine, Dict, Type, TypeVar

import typer

Expand All @@ -28,9 +28,13 @@
P = ParamSpec("P")
R = TypeVar("R")

HandleErrorFunc = Callable[[Exception], None]

# https://github.com/tiangolo/typer/issues/88#issuecomment-1613013597
class AsyncTyper(typer.Typer):

class ArgillaTyper(typer.Typer):
error_handlers: Dict[Type[Exception], HandleErrorFunc] = {}

# https://github.com/tiangolo/typer/issues/88#issuecomment-1613013597
def command(
self, *args: Any, **kwargs: Any
) -> Callable[[Callable[P, Coroutine[Any, Any, R]]], Callable[P, Coroutine[Any, Any, R]]]:
Expand All @@ -50,8 +54,25 @@ def sync_func(*_args: P.args, **_kwargs: P.kwargs) -> R:

return decorator

def error_handler(self, exc: Type[Exception]) -> Callable[[HandleErrorFunc], None]:
def decorator(func: HandleErrorFunc) -> None:
self.error_handlers[exc] = func

return decorator

def __call__(self, *args: Any, **kwargs: Any) -> Any:
try:
return super().__call__(*args, **kwargs)
except typer.Exit as e:
raise e
except Exception as e:
handler = self.error_handlers.get(type(e))
if handler is None:
raise e
handler(e)


def run(function: Callable[..., Coroutine[Any, Any, Any]]) -> None:
app = AsyncTyper(add_completion=False)
app = ArgillaTyper(add_completion=False)
app.command()(function)
app()
4 changes: 2 additions & 2 deletions tests/unit/tasks/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from tests.database import SyncTestSession

if TYPE_CHECKING:
from argilla.tasks.async_typer import AsyncTyper
from argilla.tasks.typer_ext import ArgillaTyper
from pytest_mock import MockerFixture
from sqlalchemy.engine import Connection
from sqlalchemy.orm import Session
Expand All @@ -44,7 +44,7 @@ def cli_runner() -> CliRunner:


@pytest.fixture(scope="session")
def cli() -> "AsyncTyper":
def cli() -> "ArgillaTyper":
return app


Expand Down
Loading