Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add datasets list command #3658

Merged
merged 9 commits into from
Aug 30, 2023
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ These are the section headers that we use:
- Added `list_workspaces` function (to be used as `rg.list_workspaces`, but `Workspace.list` is preferred) to list all the workspaces from an user in Argilla ([#3641](https://github.com/argilla-io/argilla/pull/3641)).
- Added `list_datasets` function (to be used as `rg.list_datasets`) to list the `TextClassification`, `TokenClassification`, and `Text2Text` datasets in Argilla ([#3638](https://github.com/argilla-io/argilla/pull/3638)).
- Added `workspaces list` command to list Argilla workspaces ([#3594](https://github.com/argilla-io/argilla/pull/3594)).
- Added `datasets list` command to list Argilla datasets ([#3658](https://github.com/argilla-io/argilla/pull/3658)).
gabrielmbmb marked this conversation as resolved.
Show resolved Hide resolved

### Changed

Expand Down
3 changes: 2 additions & 1 deletion src/argilla/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,13 @@
# limitations under the License.


from argilla.tasks import database_app, login_app, logout_app, server_app, training_app, workspaces_app
from argilla.tasks import database_app, datasets_app, login_app, logout_app, server_app, training_app, workspaces_app
from argilla.tasks.async_typer import AsyncTyper

app = AsyncTyper(rich_help_panel=True, help="Argilla CLI", no_args_is_help=True)

app.add_typer(database_app, name="database")
app.add_typer(datasets_app, name="datasets")
app.add_typer(login_app, name="login")
app.add_typer(logout_app, name="logout")
app.add_typer(server_app, name="server")
Expand Down
1 change: 1 addition & 0 deletions src/argilla/tasks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

from .database import app as database_app
from .datasets import app as datasets_app
from .login import app as login_app
from .logout import app as logout_app
from .server import app as server_app
Expand Down
18 changes: 18 additions & 0 deletions src/argilla/tasks/datasets/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# Copyright 2021-present, the Recognai S.L. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .__main__ import app

if __name__ == "__main__":
app()
29 changes: 29 additions & 0 deletions src/argilla/tasks/datasets/__main__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# Copyright 2021-present, the Recognai S.L. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import typer

from argilla.tasks.callback import init_callback
from argilla.tasks.datasets.list import list_datasets

app = typer.Typer(
help="Holds CLI commands for datasets management", invoke_without_command=True, callback=init_callback
)


app.command(name="list", help="List datasets linked to user's workspaces")(list_datasets)


if __name__ == "__main__":
app()
74 changes: 74 additions & 0 deletions src/argilla/tasks/datasets/list.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
# Copyright 2021-present, the Recognai S.L. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from enum import Enum
from typing import Dict, Optional

import typer


class DatasetType(str, Enum):
feedback = "feedback"
other = "other"


def list_datasets(
workspace: Optional[str] = typer.Option(None, help="List datasets in this workspace"),
type_: Optional[DatasetType] = typer.Option(
None,
"--type",
help="The type of datasets to be listed. This option can be used multiple times. By default, all datasets are listed.",
),
) -> None:
from rich.console import Console
from rich.markdown import Markdown

from argilla.client.api import list_datasets as list_datasets_api
from argilla.client.feedback.dataset.local import FeedbackDataset
from argilla.tasks.rich import get_argilla_themed_table

def build_tags_text(tags: Dict[str, str]) -> Markdown:
text = ""
for tag, description in tags.items():
text += f"- **{tag}**: {description}\n"

Check warning on line 44 in src/argilla/tasks/datasets/list.py

View check run for this annotation

Codecov / codecov/patch

src/argilla/tasks/datasets/list.py#L44

Added line #L44 was not covered by tests
return Markdown(text)

table = get_argilla_themed_table(title="Datasets")
for column in ("ID", "Name", "Workspace", "Type", "Tags", "Creation Date", "Last Update Date"):
table.add_column(column, justify="center")

if type_ is None or type_ == DatasetType.feedback:
for dataset in FeedbackDataset.list(workspace):
# TODO: add passing value for `Creation Date` and `Update Date` columns once `RemoteFeedbackDataset` has
gabrielmbmb marked this conversation as resolved.
Show resolved Hide resolved
# these attributes
table.add_row(str(dataset.id), dataset.name, dataset.workspace.name, "Feedback", None, None, None)

if type_ is None or type_ == DatasetType.other:
for dataset in list_datasets_api(workspace):
table.add_row(
dataset.id,
dataset.name,
dataset.workspace,
dataset.task,
build_tags_text(dataset.tags),
dataset.created_at.isoformat(sep=" "),
dataset.last_updated.isoformat(sep=" "),
)

console = Console()
console.print(table)


if __name__ == "__main__":
typer.run(list_datasets)
22 changes: 22 additions & 0 deletions src/argilla/tasks/rich.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# Copyright 2021-present, the Recognai S.L. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any

from rich.table import Table


def get_argilla_themed_table(title: str, **kwargs: Any) -> Table:
# TODO: update colors after consulting it with UI expert
return Table(title=title, border_style="red", **kwargs)
12 changes: 5 additions & 7 deletions src/argilla/tasks/workspaces/list.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,22 +12,20 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List

from argilla.tasks import async_typer
import typer


def list_workspaces() -> None:
"""List the workspaces in Argilla and prints them on the console."""
from rich.console import Console
from rich.table import Table

from argilla import Workspace
from argilla.tasks.rich import get_argilla_themed_table

workspaces = Workspace.list()

table = Table(title="Workspaces")
for column in ("ID", "Name", "Creation Date", "Update Date"):
table = get_argilla_themed_table(title="Workspaces")
for column in ("ID", "Name", "Creation Date", "Last Update Date"):
table.add_column(column, justify="center")

for workspace in workspaces:
Expand All @@ -43,4 +41,4 @@ def list_workspaces() -> None:


if __name__ == "__main__":
async_typer.run(list_workspaces)
typer.run(list_workspaces)
7 changes: 6 additions & 1 deletion tests/unit/tasks/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,11 @@ def async_db_proxy(mocker: "MockerFixture", sync_db: "Session") -> "AsyncSession


@pytest.fixture
def login_mock(mocker: "MockerFixture"):
def login_mock(mocker: "MockerFixture") -> None:
mocker.patch("argilla.client.login.ArgillaCredentials.exists", return_value=True)
mocker.patch("argilla.client.api.ArgillaSingleton.init")


@pytest.fixture
def not_logged_mock(mocker: "MockerFixture") -> None:
mocker.patch("argilla.client.login.ArgillaCredentials.exists", return_value=False)
13 changes: 13 additions & 0 deletions tests/unit/tasks/datasets/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Copyright 2021-present, the Recognai S.L. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
155 changes: 155 additions & 0 deletions tests/unit/tasks/datasets/test_list.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
# Copyright 2021-present, the Recognai S.L. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from datetime import datetime
from typing import TYPE_CHECKING
from unittest.mock import ANY, call
from uuid import uuid4

import httpx
import pytest
from argilla.client.feedback.dataset.remote.dataset import RemoteFeedbackDataset
from argilla.client.feedback.schemas.fields import TextField
from argilla.client.feedback.schemas.questions import TextQuestion
from argilla.client.sdk.datasets.models import Dataset
from argilla.client.workspaces import Workspace
from rich.table import Table

if TYPE_CHECKING:
from click.testing import CliRunner
from pytest_mock import MockerFixture
from typer import Typer


@pytest.fixture
def remote_feedback_dataset() -> "RemoteFeedbackDataset":
workspace = Workspace.__new__(Workspace)
workspace.__dict__.update(
{
"id": uuid4(),
"name": "unit-test",
"inserted_at": datetime.now(),
"updated_at": datetime.now(),
}
)
return RemoteFeedbackDataset(
client=httpx.Client(),
id=uuid4(),
name="unit-test",
workspace=workspace,
fields=[TextField(name="prompt")],
questions=[TextQuestion(name="corrected")],
)


@pytest.fixture
def dataset() -> Dataset:
return Dataset(
name="unit-test",
id="rg.unit-test",
task="TextClassification",
owner="unit-test",
workspace="unit-test",
created_at=datetime.now(),
last_updated=datetime.now(),
)


@pytest.mark.usefixtures("login_mock")
class TestSuiteListDatasetsCommand:
def test_list_datasets(
self,
cli_runner: "CliRunner",
cli: "Typer",
mocker: "MockerFixture",
remote_feedback_dataset: RemoteFeedbackDataset,
dataset: Dataset,
) -> None:
add_row_spy = mocker.spy(Table, "add_row")
feedback_dataset_list_mock = mocker.patch(
"argilla.client.feedback.dataset.local.FeedbackDataset.list", return_value=[remote_feedback_dataset]
)
list_datasets_mock = mocker.patch("argilla.client.api.list_datasets", return_value=[dataset])

result = cli_runner.invoke(cli, "datasets list")

assert result.exit_code == 0
feedback_dataset_list_mock.assert_called_once_with(None)
list_datasets_mock.assert_called_once_with(None)
add_row_spy.assert_has_calls(
[
call(
ANY,
str(remote_feedback_dataset.id),
remote_feedback_dataset.name,
remote_feedback_dataset.workspace.name,
"Feedback",
None,
None,
None,
),
call(
ANY,
str(dataset.id),
dataset.name,
"unit-test",
"TextClassification",
ANY,
dataset.created_at.isoformat(sep=" "),
dataset.last_updated.isoformat(sep=" "),
),
]
)

def test_list_datasets_with_workspace(self, cli_runner: "CliRunner", cli: "Typer", mocker: "MockerFixture") -> None:
feedback_dataset_list_mock = mocker.patch("argilla.client.feedback.dataset.local.FeedbackDataset.list")
list_datasets_mock = mocker.patch("argilla.client.api.list_datasets")

result = cli_runner.invoke(cli, "datasets list --workspace unit-test")

assert result.exit_code == 0
feedback_dataset_list_mock.assert_called_once_with("unit-test")
list_datasets_mock.assert_called_once_with("unit-test")

def test_list_datasets_using_type_feedback_filter(
self, cli_runner: "CliRunner", cli: "Typer", mocker: "MockerFixture"
) -> None:
feedback_dataset_list_mock = mocker.patch("argilla.client.feedback.dataset.local.FeedbackDataset.list")
list_datasets_mock = mocker.patch("argilla.client.api.list_datasets")

result = cli_runner.invoke(cli, "datasets list --type feedback")

assert result.exit_code == 0
feedback_dataset_list_mock.assert_called_once_with(None)
list_datasets_mock.assert_not_called()

def test_list_datasets_using_type_other_filter(
self, cli_runner: "CliRunner", cli: "Typer", mocker: "MockerFixture"
) -> None:
feedback_dataset_list_mock = mocker.patch("argilla.client.feedback.dataset.local.FeedbackDataset.list")
list_datasets_mock = mocker.patch("argilla.client.api.list_datasets")

result = cli_runner.invoke(cli, "datasets list --type other")

assert result.exit_code == 0
feedback_dataset_list_mock.assert_not_called()
list_datasets_mock.assert_called_once_with(None)


@pytest.mark.usefixtures("not_logged_mock")
def test_cli_datasets_list_needs_login(cli_runner: "CliRunner", cli: "Typer") -> None:
result = cli_runner.invoke(cli, "datasets list")

assert "You are not logged in. Please run `argilla login` to login to an Argilla server." in result.stdout
assert result.exit_code == 1
Loading
Loading