From 3c9b5d0d15fd805eda82b57248f8cf7b3853d4d8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gabriel=20Mart=C3=ADn=20Bl=C3=A1zquez?= Date: Wed, 30 Aug 2023 14:06:06 +0200 Subject: [PATCH] feat: add `datasets list` command (#3658) # Description This PR adds a new command `python -m argilla datasets list` which allows to list the datasets from the logged Argilla server in the terminal. image Closes #3591 **Type of change** - [x] New feature (non-breaking change which adds functionality) **How Has This Been Tested** I've created a custom environment, used the `python -m argilla login` command to login in it and then list the datasets from this environment using the new command. All the datasets were listed. Applying workspace filter worked. Applying dataset kind filter worked. Additionally, I've added unit tests to cover all the additions. **Checklist** - [ ] I added relevant documentation - [x] I followed the style guidelines of this project - [x] I did a self-review of my code - [ ] I made corresponding changes to the documentation - [x] My changes generate no new warnings - [x] I have added tests that prove my fix is effective or that my feature works - [ ] I filled out [the contributor form](https://tally.so/r/n9XrxK) (see text above) - [x] I have added relevant notes to the `CHANGELOG.md` file (See https://keepachangelog.com/) --------- Co-authored-by: alvarobartt --- CHANGELOG.md | 1 + src/argilla/__main__.py | 3 +- src/argilla/tasks/__init__.py | 1 + src/argilla/tasks/datasets/__init__.py | 18 +++ src/argilla/tasks/datasets/__main__.py | 29 +++++ src/argilla/tasks/datasets/list.py | 74 +++++++++++ src/argilla/tasks/rich.py | 22 ++++ src/argilla/tasks/workspaces/list.py | 12 +- tests/unit/tasks/conftest.py | 7 +- tests/unit/tasks/datasets/__init__.py | 13 ++ tests/unit/tasks/datasets/test_list.py | 155 +++++++++++++++++++++++ tests/unit/tasks/workspaces/test_list.py | 5 +- 12 files changed, 329 insertions(+), 11 deletions(-) create mode 100644 src/argilla/tasks/datasets/__init__.py create mode 100644 src/argilla/tasks/datasets/__main__.py create mode 100644 src/argilla/tasks/datasets/list.py create mode 100644 src/argilla/tasks/rich.py create mode 100644 tests/unit/tasks/datasets/__init__.py create mode 100644 tests/unit/tasks/datasets/test_list.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 5c8dfbe20e..8cfa2b1ba1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -31,6 +31,7 @@ These are the section headers that we use: - Added `list_workspaces` function (to be used as `rg.list_workspaces`, but `Workspace.list` is preferred) to list all the workspaces from an user in Argilla ([#3641](https://github.com/argilla-io/argilla/pull/3641)). - Added `list_datasets` function (to be used as `rg.list_datasets`) to list the `TextClassification`, `TokenClassification`, and `Text2Text` datasets in Argilla ([#3638](https://github.com/argilla-io/argilla/pull/3638)). - Added `workspaces list` command to list Argilla workspaces ([#3594](https://github.com/argilla-io/argilla/pull/3594)). +- Added `datasets list` command to list Argilla datasets ([#3658](https://github.com/argilla-io/argilla/pull/3658)). ### Changed diff --git a/src/argilla/__main__.py b/src/argilla/__main__.py index 382749a56d..236951f5ab 100644 --- a/src/argilla/__main__.py +++ b/src/argilla/__main__.py @@ -14,12 +14,13 @@ # limitations under the License. -from argilla.tasks import database_app, login_app, logout_app, server_app, training_app, workspaces_app +from argilla.tasks import database_app, datasets_app, login_app, logout_app, server_app, training_app, workspaces_app from argilla.tasks.async_typer import AsyncTyper app = AsyncTyper(rich_help_panel=True, help="Argilla CLI", no_args_is_help=True) app.add_typer(database_app, name="database") +app.add_typer(datasets_app, name="datasets") app.add_typer(login_app, name="login") app.add_typer(logout_app, name="logout") app.add_typer(server_app, name="server") diff --git a/src/argilla/tasks/__init__.py b/src/argilla/tasks/__init__.py index 257a7313af..6d089d4e80 100644 --- a/src/argilla/tasks/__init__.py +++ b/src/argilla/tasks/__init__.py @@ -13,6 +13,7 @@ # limitations under the License. from .database import app as database_app +from .datasets import app as datasets_app from .login import app as login_app from .logout import app as logout_app from .server import app as server_app diff --git a/src/argilla/tasks/datasets/__init__.py b/src/argilla/tasks/datasets/__init__.py new file mode 100644 index 0000000000..b0ad568f34 --- /dev/null +++ b/src/argilla/tasks/datasets/__init__.py @@ -0,0 +1,18 @@ +# Copyright 2021-present, the Recognai S.L. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .__main__ import app + +if __name__ == "__main__": + app() diff --git a/src/argilla/tasks/datasets/__main__.py b/src/argilla/tasks/datasets/__main__.py new file mode 100644 index 0000000000..406bbd0ac6 --- /dev/null +++ b/src/argilla/tasks/datasets/__main__.py @@ -0,0 +1,29 @@ +# Copyright 2021-present, the Recognai S.L. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import typer + +from argilla.tasks.callback import init_callback +from argilla.tasks.datasets.list import list_datasets + +app = typer.Typer( + help="Holds CLI commands for datasets management", invoke_without_command=True, callback=init_callback +) + + +app.command(name="list", help="List datasets linked to user's workspaces")(list_datasets) + + +if __name__ == "__main__": + app() diff --git a/src/argilla/tasks/datasets/list.py b/src/argilla/tasks/datasets/list.py new file mode 100644 index 0000000000..e0f33f61e2 --- /dev/null +++ b/src/argilla/tasks/datasets/list.py @@ -0,0 +1,74 @@ +# Copyright 2021-present, the Recognai S.L. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from enum import Enum +from typing import Dict, Optional + +import typer + + +class DatasetType(str, Enum): + feedback = "feedback" + other = "other" + + +def list_datasets( + workspace: Optional[str] = typer.Option(None, help="List datasets in this workspace"), + type_: Optional[DatasetType] = typer.Option( + None, + "--type", + help="The type of datasets to be listed. This option can be used multiple times. By default, all datasets are listed.", + ), +) -> None: + from rich.console import Console + from rich.markdown import Markdown + + from argilla.client.api import list_datasets as list_datasets_api + from argilla.client.feedback.dataset.local import FeedbackDataset + from argilla.tasks.rich import get_argilla_themed_table + + def build_tags_text(tags: Dict[str, str]) -> Markdown: + text = "" + for tag, description in tags.items(): + text += f"- **{tag}**: {description}\n" + return Markdown(text) + + table = get_argilla_themed_table(title="Datasets") + for column in ("ID", "Name", "Workspace", "Type", "Tags", "Creation Date", "Last Update Date"): + table.add_column(column, justify="center") + + if type_ is None or type_ == DatasetType.feedback: + for dataset in FeedbackDataset.list(workspace): + # TODO: add passing value for `Creation Date` and `Update Date` columns once `RemoteFeedbackDataset` has + # these attributes + table.add_row(str(dataset.id), dataset.name, dataset.workspace.name, "Feedback", None, None, None) + + if type_ is None or type_ == DatasetType.other: + for dataset in list_datasets_api(workspace): + table.add_row( + dataset.id, + dataset.name, + dataset.workspace, + dataset.task, + build_tags_text(dataset.tags), + dataset.created_at.isoformat(sep=" "), + dataset.last_updated.isoformat(sep=" "), + ) + + console = Console() + console.print(table) + + +if __name__ == "__main__": + typer.run(list_datasets) diff --git a/src/argilla/tasks/rich.py b/src/argilla/tasks/rich.py new file mode 100644 index 0000000000..6a5053f67d --- /dev/null +++ b/src/argilla/tasks/rich.py @@ -0,0 +1,22 @@ +# Copyright 2021-present, the Recognai S.L. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any + +from rich.table import Table + + +def get_argilla_themed_table(title: str, **kwargs: Any) -> Table: + # TODO: update colors after consulting it with UI expert + return Table(title=title, border_style="red", **kwargs) diff --git a/src/argilla/tasks/workspaces/list.py b/src/argilla/tasks/workspaces/list.py index b1b10015b0..791ff6cbaf 100644 --- a/src/argilla/tasks/workspaces/list.py +++ b/src/argilla/tasks/workspaces/list.py @@ -12,22 +12,20 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import List - -from argilla.tasks import async_typer +import typer def list_workspaces() -> None: """List the workspaces in Argilla and prints them on the console.""" from rich.console import Console - from rich.table import Table from argilla import Workspace + from argilla.tasks.rich import get_argilla_themed_table workspaces = Workspace.list() - table = Table(title="Workspaces") - for column in ("ID", "Name", "Creation Date", "Update Date"): + table = get_argilla_themed_table(title="Workspaces") + for column in ("ID", "Name", "Creation Date", "Last Update Date"): table.add_column(column, justify="center") for workspace in workspaces: @@ -43,4 +41,4 @@ def list_workspaces() -> None: if __name__ == "__main__": - async_typer.run(list_workspaces) + typer.run(list_workspaces) diff --git a/tests/unit/tasks/conftest.py b/tests/unit/tasks/conftest.py index 998d85bb97..7a87be338a 100644 --- a/tests/unit/tasks/conftest.py +++ b/tests/unit/tasks/conftest.py @@ -89,6 +89,11 @@ def async_db_proxy(mocker: "MockerFixture", sync_db: "Session") -> "AsyncSession @pytest.fixture -def login_mock(mocker: "MockerFixture"): +def login_mock(mocker: "MockerFixture") -> None: mocker.patch("argilla.client.login.ArgillaCredentials.exists", return_value=True) mocker.patch("argilla.client.api.ArgillaSingleton.init") + + +@pytest.fixture +def not_logged_mock(mocker: "MockerFixture") -> None: + mocker.patch("argilla.client.login.ArgillaCredentials.exists", return_value=False) diff --git a/tests/unit/tasks/datasets/__init__.py b/tests/unit/tasks/datasets/__init__.py new file mode 100644 index 0000000000..55be41799b --- /dev/null +++ b/tests/unit/tasks/datasets/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2021-present, the Recognai S.L. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/unit/tasks/datasets/test_list.py b/tests/unit/tasks/datasets/test_list.py new file mode 100644 index 0000000000..4a1d24e942 --- /dev/null +++ b/tests/unit/tasks/datasets/test_list.py @@ -0,0 +1,155 @@ +# Copyright 2021-present, the Recognai S.L. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from datetime import datetime +from typing import TYPE_CHECKING +from unittest.mock import ANY, call +from uuid import uuid4 + +import httpx +import pytest +from argilla.client.feedback.dataset.remote.dataset import RemoteFeedbackDataset +from argilla.client.feedback.schemas.fields import TextField +from argilla.client.feedback.schemas.questions import TextQuestion +from argilla.client.sdk.datasets.models import Dataset +from argilla.client.workspaces import Workspace +from rich.table import Table + +if TYPE_CHECKING: + from click.testing import CliRunner + from pytest_mock import MockerFixture + from typer import Typer + + +@pytest.fixture +def remote_feedback_dataset() -> "RemoteFeedbackDataset": + workspace = Workspace.__new__(Workspace) + workspace.__dict__.update( + { + "id": uuid4(), + "name": "unit-test", + "inserted_at": datetime.now(), + "updated_at": datetime.now(), + } + ) + return RemoteFeedbackDataset( + client=httpx.Client(), + id=uuid4(), + name="unit-test", + workspace=workspace, + fields=[TextField(name="prompt")], + questions=[TextQuestion(name="corrected")], + ) + + +@pytest.fixture +def dataset() -> Dataset: + return Dataset( + name="unit-test", + id="rg.unit-test", + task="TextClassification", + owner="unit-test", + workspace="unit-test", + created_at=datetime.now(), + last_updated=datetime.now(), + ) + + +@pytest.mark.usefixtures("login_mock") +class TestSuiteListDatasetsCommand: + def test_list_datasets( + self, + cli_runner: "CliRunner", + cli: "Typer", + mocker: "MockerFixture", + remote_feedback_dataset: RemoteFeedbackDataset, + dataset: Dataset, + ) -> None: + add_row_spy = mocker.spy(Table, "add_row") + feedback_dataset_list_mock = mocker.patch( + "argilla.client.feedback.dataset.local.FeedbackDataset.list", return_value=[remote_feedback_dataset] + ) + list_datasets_mock = mocker.patch("argilla.client.api.list_datasets", return_value=[dataset]) + + result = cli_runner.invoke(cli, "datasets list") + + assert result.exit_code == 0 + feedback_dataset_list_mock.assert_called_once_with(None) + list_datasets_mock.assert_called_once_with(None) + add_row_spy.assert_has_calls( + [ + call( + ANY, + str(remote_feedback_dataset.id), + remote_feedback_dataset.name, + remote_feedback_dataset.workspace.name, + "Feedback", + None, + None, + None, + ), + call( + ANY, + str(dataset.id), + dataset.name, + "unit-test", + "TextClassification", + ANY, + dataset.created_at.isoformat(sep=" "), + dataset.last_updated.isoformat(sep=" "), + ), + ] + ) + + def test_list_datasets_with_workspace(self, cli_runner: "CliRunner", cli: "Typer", mocker: "MockerFixture") -> None: + feedback_dataset_list_mock = mocker.patch("argilla.client.feedback.dataset.local.FeedbackDataset.list") + list_datasets_mock = mocker.patch("argilla.client.api.list_datasets") + + result = cli_runner.invoke(cli, "datasets list --workspace unit-test") + + assert result.exit_code == 0 + feedback_dataset_list_mock.assert_called_once_with("unit-test") + list_datasets_mock.assert_called_once_with("unit-test") + + def test_list_datasets_using_type_feedback_filter( + self, cli_runner: "CliRunner", cli: "Typer", mocker: "MockerFixture" + ) -> None: + feedback_dataset_list_mock = mocker.patch("argilla.client.feedback.dataset.local.FeedbackDataset.list") + list_datasets_mock = mocker.patch("argilla.client.api.list_datasets") + + result = cli_runner.invoke(cli, "datasets list --type feedback") + + assert result.exit_code == 0 + feedback_dataset_list_mock.assert_called_once_with(None) + list_datasets_mock.assert_not_called() + + def test_list_datasets_using_type_other_filter( + self, cli_runner: "CliRunner", cli: "Typer", mocker: "MockerFixture" + ) -> None: + feedback_dataset_list_mock = mocker.patch("argilla.client.feedback.dataset.local.FeedbackDataset.list") + list_datasets_mock = mocker.patch("argilla.client.api.list_datasets") + + result = cli_runner.invoke(cli, "datasets list --type other") + + assert result.exit_code == 0 + feedback_dataset_list_mock.assert_not_called() + list_datasets_mock.assert_called_once_with(None) + + +@pytest.mark.usefixtures("not_logged_mock") +def test_cli_datasets_list_needs_login(cli_runner: "CliRunner", cli: "Typer") -> None: + result = cli_runner.invoke(cli, "datasets list") + + assert "You are not logged in. Please run `argilla login` to login to an Argilla server." in result.stdout + assert result.exit_code == 1 diff --git a/tests/unit/tasks/workspaces/test_list.py b/tests/unit/tasks/workspaces/test_list.py index e660682305..89c0962f7c 100644 --- a/tests/unit/tasks/workspaces/test_list.py +++ b/tests/unit/tasks/workspaces/test_list.py @@ -53,11 +53,12 @@ def test_cli_workspaces_list(cli_runner: "CliRunner", cli: "Typer", mocker: "Moc workspace.inserted_at.isoformat(sep=" "), workspace.updated_at.isoformat(sep=" "), ) - assert all(col in result.stdout for col in ("ID", "Name", "Creation Date", "Update Date", "test_workspace")) + assert all(col in result.stdout for col in ("ID", "Name", "Creation Date", "Last Update Date", "test_workspace")) assert result.exit_code == 0 -def test_cli_workspaces_list_needs_login(cli_runner: "CliRunner", cli: "Typer", mocker: "MockerFixture"): +@pytest.mark.usefixtures("not_logged_mock") +def test_cli_workspaces_list_needs_login(cli_runner: "CliRunner", cli: "Typer"): result = cli_runner.invoke(cli, "workspaces list") assert "You are not logged in. Please run `argilla login` to login to an Argilla server." in result.stdout