diff --git a/CHANGELOG.md b/CHANGELOG.md index 5c8dfbe20e..8cfa2b1ba1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -31,6 +31,7 @@ These are the section headers that we use: - Added `list_workspaces` function (to be used as `rg.list_workspaces`, but `Workspace.list` is preferred) to list all the workspaces from an user in Argilla ([#3641](https://github.com/argilla-io/argilla/pull/3641)). - Added `list_datasets` function (to be used as `rg.list_datasets`) to list the `TextClassification`, `TokenClassification`, and `Text2Text` datasets in Argilla ([#3638](https://github.com/argilla-io/argilla/pull/3638)). - Added `workspaces list` command to list Argilla workspaces ([#3594](https://github.com/argilla-io/argilla/pull/3594)). +- Added `datasets list` command to list Argilla datasets ([#3658](https://github.com/argilla-io/argilla/pull/3658)). ### Changed diff --git a/src/argilla/__main__.py b/src/argilla/__main__.py index 382749a56d..236951f5ab 100644 --- a/src/argilla/__main__.py +++ b/src/argilla/__main__.py @@ -14,12 +14,13 @@ # limitations under the License. -from argilla.tasks import database_app, login_app, logout_app, server_app, training_app, workspaces_app +from argilla.tasks import database_app, datasets_app, login_app, logout_app, server_app, training_app, workspaces_app from argilla.tasks.async_typer import AsyncTyper app = AsyncTyper(rich_help_panel=True, help="Argilla CLI", no_args_is_help=True) app.add_typer(database_app, name="database") +app.add_typer(datasets_app, name="datasets") app.add_typer(login_app, name="login") app.add_typer(logout_app, name="logout") app.add_typer(server_app, name="server") diff --git a/src/argilla/tasks/__init__.py b/src/argilla/tasks/__init__.py index 257a7313af..6d089d4e80 100644 --- a/src/argilla/tasks/__init__.py +++ b/src/argilla/tasks/__init__.py @@ -13,6 +13,7 @@ # limitations under the License. from .database import app as database_app +from .datasets import app as datasets_app from .login import app as login_app from .logout import app as logout_app from .server import app as server_app diff --git a/src/argilla/tasks/datasets/__init__.py b/src/argilla/tasks/datasets/__init__.py new file mode 100644 index 0000000000..b0ad568f34 --- /dev/null +++ b/src/argilla/tasks/datasets/__init__.py @@ -0,0 +1,18 @@ +# Copyright 2021-present, the Recognai S.L. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .__main__ import app + +if __name__ == "__main__": + app() diff --git a/src/argilla/tasks/datasets/__main__.py b/src/argilla/tasks/datasets/__main__.py new file mode 100644 index 0000000000..406bbd0ac6 --- /dev/null +++ b/src/argilla/tasks/datasets/__main__.py @@ -0,0 +1,29 @@ +# Copyright 2021-present, the Recognai S.L. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import typer + +from argilla.tasks.callback import init_callback +from argilla.tasks.datasets.list import list_datasets + +app = typer.Typer( + help="Holds CLI commands for datasets management", invoke_without_command=True, callback=init_callback +) + + +app.command(name="list", help="List datasets linked to user's workspaces")(list_datasets) + + +if __name__ == "__main__": + app() diff --git a/src/argilla/tasks/datasets/list.py b/src/argilla/tasks/datasets/list.py new file mode 100644 index 0000000000..e0f33f61e2 --- /dev/null +++ b/src/argilla/tasks/datasets/list.py @@ -0,0 +1,74 @@ +# Copyright 2021-present, the Recognai S.L. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from enum import Enum +from typing import Dict, Optional + +import typer + + +class DatasetType(str, Enum): + feedback = "feedback" + other = "other" + + +def list_datasets( + workspace: Optional[str] = typer.Option(None, help="List datasets in this workspace"), + type_: Optional[DatasetType] = typer.Option( + None, + "--type", + help="The type of datasets to be listed. This option can be used multiple times. By default, all datasets are listed.", + ), +) -> None: + from rich.console import Console + from rich.markdown import Markdown + + from argilla.client.api import list_datasets as list_datasets_api + from argilla.client.feedback.dataset.local import FeedbackDataset + from argilla.tasks.rich import get_argilla_themed_table + + def build_tags_text(tags: Dict[str, str]) -> Markdown: + text = "" + for tag, description in tags.items(): + text += f"- **{tag}**: {description}\n" + return Markdown(text) + + table = get_argilla_themed_table(title="Datasets") + for column in ("ID", "Name", "Workspace", "Type", "Tags", "Creation Date", "Last Update Date"): + table.add_column(column, justify="center") + + if type_ is None or type_ == DatasetType.feedback: + for dataset in FeedbackDataset.list(workspace): + # TODO: add passing value for `Creation Date` and `Update Date` columns once `RemoteFeedbackDataset` has + # these attributes + table.add_row(str(dataset.id), dataset.name, dataset.workspace.name, "Feedback", None, None, None) + + if type_ is None or type_ == DatasetType.other: + for dataset in list_datasets_api(workspace): + table.add_row( + dataset.id, + dataset.name, + dataset.workspace, + dataset.task, + build_tags_text(dataset.tags), + dataset.created_at.isoformat(sep=" "), + dataset.last_updated.isoformat(sep=" "), + ) + + console = Console() + console.print(table) + + +if __name__ == "__main__": + typer.run(list_datasets) diff --git a/src/argilla/tasks/rich.py b/src/argilla/tasks/rich.py new file mode 100644 index 0000000000..6a5053f67d --- /dev/null +++ b/src/argilla/tasks/rich.py @@ -0,0 +1,22 @@ +# Copyright 2021-present, the Recognai S.L. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any + +from rich.table import Table + + +def get_argilla_themed_table(title: str, **kwargs: Any) -> Table: + # TODO: update colors after consulting it with UI expert + return Table(title=title, border_style="red", **kwargs) diff --git a/src/argilla/tasks/workspaces/list.py b/src/argilla/tasks/workspaces/list.py index b1b10015b0..791ff6cbaf 100644 --- a/src/argilla/tasks/workspaces/list.py +++ b/src/argilla/tasks/workspaces/list.py @@ -12,22 +12,20 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import List - -from argilla.tasks import async_typer +import typer def list_workspaces() -> None: """List the workspaces in Argilla and prints them on the console.""" from rich.console import Console - from rich.table import Table from argilla import Workspace + from argilla.tasks.rich import get_argilla_themed_table workspaces = Workspace.list() - table = Table(title="Workspaces") - for column in ("ID", "Name", "Creation Date", "Update Date"): + table = get_argilla_themed_table(title="Workspaces") + for column in ("ID", "Name", "Creation Date", "Last Update Date"): table.add_column(column, justify="center") for workspace in workspaces: @@ -43,4 +41,4 @@ def list_workspaces() -> None: if __name__ == "__main__": - async_typer.run(list_workspaces) + typer.run(list_workspaces) diff --git a/tests/unit/tasks/conftest.py b/tests/unit/tasks/conftest.py index 998d85bb97..7a87be338a 100644 --- a/tests/unit/tasks/conftest.py +++ b/tests/unit/tasks/conftest.py @@ -89,6 +89,11 @@ def async_db_proxy(mocker: "MockerFixture", sync_db: "Session") -> "AsyncSession @pytest.fixture -def login_mock(mocker: "MockerFixture"): +def login_mock(mocker: "MockerFixture") -> None: mocker.patch("argilla.client.login.ArgillaCredentials.exists", return_value=True) mocker.patch("argilla.client.api.ArgillaSingleton.init") + + +@pytest.fixture +def not_logged_mock(mocker: "MockerFixture") -> None: + mocker.patch("argilla.client.login.ArgillaCredentials.exists", return_value=False) diff --git a/tests/unit/tasks/datasets/__init__.py b/tests/unit/tasks/datasets/__init__.py new file mode 100644 index 0000000000..55be41799b --- /dev/null +++ b/tests/unit/tasks/datasets/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2021-present, the Recognai S.L. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/unit/tasks/datasets/test_list.py b/tests/unit/tasks/datasets/test_list.py new file mode 100644 index 0000000000..4a1d24e942 --- /dev/null +++ b/tests/unit/tasks/datasets/test_list.py @@ -0,0 +1,155 @@ +# Copyright 2021-present, the Recognai S.L. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from datetime import datetime +from typing import TYPE_CHECKING +from unittest.mock import ANY, call +from uuid import uuid4 + +import httpx +import pytest +from argilla.client.feedback.dataset.remote.dataset import RemoteFeedbackDataset +from argilla.client.feedback.schemas.fields import TextField +from argilla.client.feedback.schemas.questions import TextQuestion +from argilla.client.sdk.datasets.models import Dataset +from argilla.client.workspaces import Workspace +from rich.table import Table + +if TYPE_CHECKING: + from click.testing import CliRunner + from pytest_mock import MockerFixture + from typer import Typer + + +@pytest.fixture +def remote_feedback_dataset() -> "RemoteFeedbackDataset": + workspace = Workspace.__new__(Workspace) + workspace.__dict__.update( + { + "id": uuid4(), + "name": "unit-test", + "inserted_at": datetime.now(), + "updated_at": datetime.now(), + } + ) + return RemoteFeedbackDataset( + client=httpx.Client(), + id=uuid4(), + name="unit-test", + workspace=workspace, + fields=[TextField(name="prompt")], + questions=[TextQuestion(name="corrected")], + ) + + +@pytest.fixture +def dataset() -> Dataset: + return Dataset( + name="unit-test", + id="rg.unit-test", + task="TextClassification", + owner="unit-test", + workspace="unit-test", + created_at=datetime.now(), + last_updated=datetime.now(), + ) + + +@pytest.mark.usefixtures("login_mock") +class TestSuiteListDatasetsCommand: + def test_list_datasets( + self, + cli_runner: "CliRunner", + cli: "Typer", + mocker: "MockerFixture", + remote_feedback_dataset: RemoteFeedbackDataset, + dataset: Dataset, + ) -> None: + add_row_spy = mocker.spy(Table, "add_row") + feedback_dataset_list_mock = mocker.patch( + "argilla.client.feedback.dataset.local.FeedbackDataset.list", return_value=[remote_feedback_dataset] + ) + list_datasets_mock = mocker.patch("argilla.client.api.list_datasets", return_value=[dataset]) + + result = cli_runner.invoke(cli, "datasets list") + + assert result.exit_code == 0 + feedback_dataset_list_mock.assert_called_once_with(None) + list_datasets_mock.assert_called_once_with(None) + add_row_spy.assert_has_calls( + [ + call( + ANY, + str(remote_feedback_dataset.id), + remote_feedback_dataset.name, + remote_feedback_dataset.workspace.name, + "Feedback", + None, + None, + None, + ), + call( + ANY, + str(dataset.id), + dataset.name, + "unit-test", + "TextClassification", + ANY, + dataset.created_at.isoformat(sep=" "), + dataset.last_updated.isoformat(sep=" "), + ), + ] + ) + + def test_list_datasets_with_workspace(self, cli_runner: "CliRunner", cli: "Typer", mocker: "MockerFixture") -> None: + feedback_dataset_list_mock = mocker.patch("argilla.client.feedback.dataset.local.FeedbackDataset.list") + list_datasets_mock = mocker.patch("argilla.client.api.list_datasets") + + result = cli_runner.invoke(cli, "datasets list --workspace unit-test") + + assert result.exit_code == 0 + feedback_dataset_list_mock.assert_called_once_with("unit-test") + list_datasets_mock.assert_called_once_with("unit-test") + + def test_list_datasets_using_type_feedback_filter( + self, cli_runner: "CliRunner", cli: "Typer", mocker: "MockerFixture" + ) -> None: + feedback_dataset_list_mock = mocker.patch("argilla.client.feedback.dataset.local.FeedbackDataset.list") + list_datasets_mock = mocker.patch("argilla.client.api.list_datasets") + + result = cli_runner.invoke(cli, "datasets list --type feedback") + + assert result.exit_code == 0 + feedback_dataset_list_mock.assert_called_once_with(None) + list_datasets_mock.assert_not_called() + + def test_list_datasets_using_type_other_filter( + self, cli_runner: "CliRunner", cli: "Typer", mocker: "MockerFixture" + ) -> None: + feedback_dataset_list_mock = mocker.patch("argilla.client.feedback.dataset.local.FeedbackDataset.list") + list_datasets_mock = mocker.patch("argilla.client.api.list_datasets") + + result = cli_runner.invoke(cli, "datasets list --type other") + + assert result.exit_code == 0 + feedback_dataset_list_mock.assert_not_called() + list_datasets_mock.assert_called_once_with(None) + + +@pytest.mark.usefixtures("not_logged_mock") +def test_cli_datasets_list_needs_login(cli_runner: "CliRunner", cli: "Typer") -> None: + result = cli_runner.invoke(cli, "datasets list") + + assert "You are not logged in. Please run `argilla login` to login to an Argilla server." in result.stdout + assert result.exit_code == 1 diff --git a/tests/unit/tasks/workspaces/test_list.py b/tests/unit/tasks/workspaces/test_list.py index e660682305..89c0962f7c 100644 --- a/tests/unit/tasks/workspaces/test_list.py +++ b/tests/unit/tasks/workspaces/test_list.py @@ -53,11 +53,12 @@ def test_cli_workspaces_list(cli_runner: "CliRunner", cli: "Typer", mocker: "Moc workspace.inserted_at.isoformat(sep=" "), workspace.updated_at.isoformat(sep=" "), ) - assert all(col in result.stdout for col in ("ID", "Name", "Creation Date", "Update Date", "test_workspace")) + assert all(col in result.stdout for col in ("ID", "Name", "Creation Date", "Last Update Date", "test_workspace")) assert result.exit_code == 0 -def test_cli_workspaces_list_needs_login(cli_runner: "CliRunner", cli: "Typer", mocker: "MockerFixture"): +@pytest.mark.usefixtures("not_logged_mock") +def test_cli_workspaces_list_needs_login(cli_runner: "CliRunner", cli: "Typer"): result = cli_runner.invoke(cli, "workspaces list") assert "You are not logged in. Please run `argilla login` to login to an Argilla server." in result.stdout