Skip to content

Commit

Permalink
feat: add datasets list command (#3658)
Browse files Browse the repository at this point in the history
# Description

This PR adds a new command, `python -m argilla datasets list`, which
lists in the terminal the datasets of the Argilla server the user is
logged in to.

<img width="1512" alt="image"
src="https://github.com/argilla-io/argilla/assets/29572918/5322b963-66aa-4926-bbb7-8fb9844a3365">

Closes #3591

**Type of change**

- [x] New feature (non-breaking change which adds functionality)

**How Has This Been Tested**

I've created a custom environment, used the `python -m argilla login`
command to log in to it, and then listed the datasets from this environment
using the new command. All the datasets were listed. Applying the workspace
filter worked. Applying the dataset kind filter worked. Additionally, I've
added unit tests to cover all the additions.

**Checklist**

- [ ] I added relevant documentation
- [x] I followed the style guidelines of this project
- [x] I did a self-review of my code
- [ ] I made corresponding changes to the documentation
- [x] My changes generate no new warnings
- [x] I have added tests that prove my fix is effective or that my
feature works
- [ ] I filled out [the contributor form](https://tally.so/r/n9XrxK)
(see text above)
- [x] I have added relevant notes to the `CHANGELOG.md` file (See
https://keepachangelog.com/)

---------

Co-authored-by: alvarobartt <[email protected]>
  • Loading branch information
gabrielmbmb and alvarobartt authored Aug 30, 2023
1 parent bfea5c1 commit 3c9b5d0
Show file tree
Hide file tree
Showing 12 changed files with 329 additions and 11 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ These are the section headers that we use:
- Added `list_workspaces` function (to be used as `rg.list_workspaces`, but `Workspace.list` is preferred) to list all the workspaces from a user in Argilla ([#3641](https://github.com/argilla-io/argilla/pull/3641)).
- Added `list_datasets` function (to be used as `rg.list_datasets`) to list the `TextClassification`, `TokenClassification`, and `Text2Text` datasets in Argilla ([#3638](https://github.com/argilla-io/argilla/pull/3638)).
- Added `workspaces list` command to list Argilla workspaces ([#3594](https://github.com/argilla-io/argilla/pull/3594)).
- Added `datasets list` command to list Argilla datasets ([#3658](https://github.com/argilla-io/argilla/pull/3658)).

### Changed

Expand Down
3 changes: 2 additions & 1 deletion src/argilla/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,13 @@
# limitations under the License.


from argilla.tasks import database_app, login_app, logout_app, server_app, training_app, workspaces_app
from argilla.tasks import database_app, datasets_app, login_app, logout_app, server_app, training_app, workspaces_app
from argilla.tasks.async_typer import AsyncTyper

app = AsyncTyper(rich_help_panel=True, help="Argilla CLI", no_args_is_help=True)

app.add_typer(database_app, name="database")
app.add_typer(datasets_app, name="datasets")
app.add_typer(login_app, name="login")
app.add_typer(logout_app, name="logout")
app.add_typer(server_app, name="server")
Expand Down
1 change: 1 addition & 0 deletions src/argilla/tasks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

from .database import app as database_app
from .datasets import app as datasets_app
from .login import app as login_app
from .logout import app as logout_app
from .server import app as server_app
Expand Down
18 changes: 18 additions & 0 deletions src/argilla/tasks/datasets/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# Copyright 2021-present, the Recognai S.L. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .__main__ import app

# Run the `datasets` Typer application when this module is executed as a script.
if __name__ == "__main__":
    app()
29 changes: 29 additions & 0 deletions src/argilla/tasks/datasets/__main__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# Copyright 2021-present, the Recognai S.L. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import typer

from argilla.tasks.callback import init_callback
from argilla.tasks.datasets.list import list_datasets

# Typer sub-application that groups the dataset-related CLI commands. `init_callback` is
# installed as the application callback.
app = typer.Typer(
    help="Holds CLI commands for datasets management",
    invoke_without_command=True,
    callback=init_callback,
)


# Register `list_datasets` under the `list` command name.
app.command(name="list", help="List datasets linked to user's workspaces")(list_datasets)


# Allow running this module directly as a script.
if __name__ == "__main__":
    app()
74 changes: 74 additions & 0 deletions src/argilla/tasks/datasets/list.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
# Copyright 2021-present, the Recognai S.L. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from enum import Enum
from typing import Dict, Optional

import typer


class DatasetType(str, Enum):
    """Kinds of datasets that can be selected with the `--type` option of `list_datasets`."""

    # Datasets obtained through `FeedbackDataset.list`.
    feedback = "feedback"
    # Datasets obtained through the `list_datasets` function of the client API.
    other = "other"


def list_datasets(
    workspace: Optional[str] = typer.Option(None, help="List datasets in this workspace"),
    type_: Optional[DatasetType] = typer.Option(
        None,
        "--type",
        # The parameter holds a single optional value, so the help text must not claim that the
        # option can be repeated.
        help="The type of datasets to be listed. By default, all datasets are listed.",
    ),
) -> None:
    """List the datasets in Argilla and print them as a table on the console.

    Args:
        workspace: Name of the workspace used to filter the datasets. The value (including
            `None`) is forwarded unchanged to `FeedbackDataset.list` and to the client API
            `list_datasets` function.
        type_: Kind of datasets to list. `None` lists both the `feedback` and the `other`
            datasets.
    """
    # Imported lazily, inside the command body.
    from rich.console import Console
    from rich.markdown import Markdown

    from argilla.client.api import list_datasets as list_datasets_api
    from argilla.client.feedback.dataset.local import FeedbackDataset
    from argilla.tasks.rich import get_argilla_themed_table

    def build_tags_text(tags: Dict[str, str]) -> Markdown:
        """Render the tags as a Markdown bullet list with one `- **tag**: description` item per tag."""
        return Markdown("".join(f"- **{tag}**: {description}\n" for tag, description in tags.items()))

    table = get_argilla_themed_table(title="Datasets")
    for column in ("ID", "Name", "Workspace", "Type", "Tags", "Creation Date", "Last Update Date"):
        table.add_column(column, justify="center")

    if type_ is None or type_ == DatasetType.feedback:
        for dataset in FeedbackDataset.list(workspace):
            # TODO: add passing value for `Creation Date` and `Update Date` columns once `RemoteFeedbackDataset` has
            # these attributes
            table.add_row(str(dataset.id), dataset.name, dataset.workspace.name, "Feedback", None, None, None)

    if type_ is None or type_ == DatasetType.other:
        for dataset in list_datasets_api(workspace):
            # NOTE(review): assumes `created_at` and `last_updated` are always set on the returned
            # datasets; confirm against the client model.
            table.add_row(
                dataset.id,
                dataset.name,
                dataset.workspace,
                dataset.task,
                build_tags_text(dataset.tags),
                dataset.created_at.isoformat(sep=" "),
                dataset.last_updated.isoformat(sep=" "),
            )

    console = Console()
    console.print(table)


# Allow running this module directly: Typer builds a single-command CLI from `list_datasets`.
if __name__ == "__main__":
    typer.run(list_datasets)
22 changes: 22 additions & 0 deletions src/argilla/tasks/rich.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# Copyright 2021-present, the Recognai S.L. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any

from rich.table import Table


def get_argilla_themed_table(title: str, **kwargs: Any) -> Table:
    """Create a `rich` table with the border style used by the Argilla CLI.

    Args:
        title: Title of the table.
        **kwargs: Extra keyword arguments forwarded unchanged to the `Table` constructor.

    Returns:
        A `Table` created with the given title and a red border style.
    """
    # TODO: update colors after consulting it with UI expert
    themed_table = Table(title=title, border_style="red", **kwargs)
    return themed_table
12 changes: 5 additions & 7 deletions src/argilla/tasks/workspaces/list.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,22 +12,20 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List

from argilla.tasks import async_typer
import typer


def list_workspaces() -> None:
"""List the workspaces in Argilla and prints them on the console."""
from rich.console import Console
from rich.table import Table

from argilla import Workspace
from argilla.tasks.rich import get_argilla_themed_table

workspaces = Workspace.list()

table = Table(title="Workspaces")
for column in ("ID", "Name", "Creation Date", "Update Date"):
table = get_argilla_themed_table(title="Workspaces")
for column in ("ID", "Name", "Creation Date", "Last Update Date"):
table.add_column(column, justify="center")

for workspace in workspaces:
Expand All @@ -43,4 +41,4 @@ def list_workspaces() -> None:


if __name__ == "__main__":
async_typer.run(list_workspaces)
typer.run(list_workspaces)
7 changes: 6 additions & 1 deletion tests/unit/tasks/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,11 @@ def async_db_proxy(mocker: "MockerFixture", sync_db: "Session") -> "AsyncSession


@pytest.fixture
def login_mock(mocker: "MockerFixture"):
def login_mock(mocker: "MockerFixture") -> None:
mocker.patch("argilla.client.login.ArgillaCredentials.exists", return_value=True)
mocker.patch("argilla.client.api.ArgillaSingleton.init")


@pytest.fixture
def not_logged_mock(mocker: "MockerFixture") -> None:
    """Patch `ArgillaCredentials.exists` so that it returns `False`."""
    mocker.patch(
        "argilla.client.login.ArgillaCredentials.exists",
        return_value=False,
    )
13 changes: 13 additions & 0 deletions tests/unit/tasks/datasets/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Copyright 2021-present, the Recognai S.L. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
155 changes: 155 additions & 0 deletions tests/unit/tasks/datasets/test_list.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
# Copyright 2021-present, the Recognai S.L. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from datetime import datetime
from typing import TYPE_CHECKING
from unittest.mock import ANY, call
from uuid import uuid4

import httpx
import pytest
from argilla.client.feedback.dataset.remote.dataset import RemoteFeedbackDataset
from argilla.client.feedback.schemas.fields import TextField
from argilla.client.feedback.schemas.questions import TextQuestion
from argilla.client.sdk.datasets.models import Dataset
from argilla.client.workspaces import Workspace
from rich.table import Table

if TYPE_CHECKING:
from click.testing import CliRunner
from pytest_mock import MockerFixture
from typer import Typer


@pytest.fixture
def remote_feedback_dataset() -> "RemoteFeedbackDataset":
    """Build a `RemoteFeedbackDataset` named `unit-test` with one text field and one text question."""
    # The workspace is created with `__new__` and its attributes are written straight into
    # `__dict__`, so `Workspace.__init__` is not executed.
    unit_test_workspace = Workspace.__new__(Workspace)
    workspace_attributes = {
        "id": uuid4(),
        "name": "unit-test",
        "inserted_at": datetime.now(),
        "updated_at": datetime.now(),
    }
    unit_test_workspace.__dict__.update(workspace_attributes)

    feedback_dataset = RemoteFeedbackDataset(
        client=httpx.Client(),
        id=uuid4(),
        name="unit-test",
        workspace=unit_test_workspace,
        fields=[TextField(name="prompt")],
        questions=[TextQuestion(name="corrected")],
    )
    return feedback_dataset


@pytest.fixture
def dataset() -> Dataset:
    """Build a `TextClassification` `Dataset` named `unit-test` with both timestamps set."""
    # Two separate `now()` calls, evaluated in the same order as the keyword arguments below.
    creation_date = datetime.now()
    last_update_date = datetime.now()
    return Dataset(
        name="unit-test",
        id="rg.unit-test",
        task="TextClassification",
        owner="unit-test",
        workspace="unit-test",
        created_at=creation_date,
        last_updated=last_update_date,
    )


@pytest.mark.usefixtures("login_mock")
class TestSuiteListDatasetsCommand:
    """Tests for the `datasets list` CLI command, run with the `login_mock` fixture applied."""

    def test_list_datasets(
        self,
        cli_runner: "CliRunner",
        cli: "Typer",
        mocker: "MockerFixture",
        remote_feedback_dataset: RemoteFeedbackDataset,
        dataset: Dataset,
    ) -> None:
        """Without options, both listing functions are called and one row is added per dataset."""
        # Spy on `Table.add_row` to inspect the rows added by the command.
        add_row_spy = mocker.spy(Table, "add_row")
        feedback_dataset_list_mock = mocker.patch(
            "argilla.client.feedback.dataset.local.FeedbackDataset.list", return_value=[remote_feedback_dataset]
        )
        list_datasets_mock = mocker.patch("argilla.client.api.list_datasets", return_value=[dataset])

        result = cli_runner.invoke(cli, "datasets list")

        assert result.exit_code == 0
        # No `--workspace` option was given, so `None` is forwarded to both listing functions.
        feedback_dataset_list_mock.assert_called_once_with(None)
        list_datasets_mock.assert_called_once_with(None)
        # The first positional argument of each recorded call is matched with `ANY`
        # (presumably the table instance, as the spy is set on the class).
        add_row_spy.assert_has_calls(
            [
                # Row of the feedback dataset: the tags and both dates are `None`.
                call(
                    ANY,
                    str(remote_feedback_dataset.id),
                    remote_feedback_dataset.name,
                    remote_feedback_dataset.workspace.name,
                    "Feedback",
                    None,
                    None,
                    None,
                ),
                # Row of the other dataset: the tags cell is matched with `ANY`.
                call(
                    ANY,
                    str(dataset.id),
                    dataset.name,
                    "unit-test",
                    "TextClassification",
                    ANY,
                    dataset.created_at.isoformat(sep=" "),
                    dataset.last_updated.isoformat(sep=" "),
                ),
            ]
        )

    def test_list_datasets_with_workspace(self, cli_runner: "CliRunner", cli: "Typer", mocker: "MockerFixture") -> None:
        """The value of `--workspace` is forwarded to both listing functions."""
        feedback_dataset_list_mock = mocker.patch("argilla.client.feedback.dataset.local.FeedbackDataset.list")
        list_datasets_mock = mocker.patch("argilla.client.api.list_datasets")

        result = cli_runner.invoke(cli, "datasets list --workspace unit-test")

        assert result.exit_code == 0
        feedback_dataset_list_mock.assert_called_once_with("unit-test")
        list_datasets_mock.assert_called_once_with("unit-test")

    def test_list_datasets_using_type_feedback_filter(
        self, cli_runner: "CliRunner", cli: "Typer", mocker: "MockerFixture"
    ) -> None:
        """With `--type feedback`, only `FeedbackDataset.list` is called."""
        feedback_dataset_list_mock = mocker.patch("argilla.client.feedback.dataset.local.FeedbackDataset.list")
        list_datasets_mock = mocker.patch("argilla.client.api.list_datasets")

        result = cli_runner.invoke(cli, "datasets list --type feedback")

        assert result.exit_code == 0
        feedback_dataset_list_mock.assert_called_once_with(None)
        list_datasets_mock.assert_not_called()

    def test_list_datasets_using_type_other_filter(
        self, cli_runner: "CliRunner", cli: "Typer", mocker: "MockerFixture"
    ) -> None:
        """With `--type other`, only the client API `list_datasets` function is called."""
        feedback_dataset_list_mock = mocker.patch("argilla.client.feedback.dataset.local.FeedbackDataset.list")
        list_datasets_mock = mocker.patch("argilla.client.api.list_datasets")

        result = cli_runner.invoke(cli, "datasets list --type other")

        assert result.exit_code == 0
        feedback_dataset_list_mock.assert_not_called()
        list_datasets_mock.assert_called_once_with(None)


@pytest.mark.usefixtures("not_logged_mock")
def test_cli_datasets_list_needs_login(cli_runner: "CliRunner", cli: "Typer") -> None:
    """With the `not_logged_mock` fixture applied, the command prints the login hint and exits with code 1."""
    expected_message = "You are not logged in. Please run `argilla login` to login to an Argilla server."

    outcome = cli_runner.invoke(cli, "datasets list")

    assert expected_message in outcome.stdout
    assert outcome.exit_code == 1
Loading

0 comments on commit 3c9b5d0

Please sign in to comment.