-
Notifications
You must be signed in to change notification settings - Fork 377
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: add
datasets list
command (#3658)
# Description This PR adds a new command `python -m argilla datasets list` which lists, in the terminal, the datasets from the Argilla server the user is logged in to. <img width="1512" alt="image" src="https://github.com/argilla-io/argilla/assets/29572918/5322b963-66aa-4926-bbb7-8fb9844a3365"> Closes #3591 **Type of change** - [x] New feature (non-breaking change which adds functionality) **How Has This Been Tested** I've created a custom environment, used the `python -m argilla login` command to log in to it, and then listed the datasets from this environment using the new command. All the datasets were listed. Applying the workspace filter worked. Applying the dataset kind filter worked. Additionally, I've added unit tests to cover all the additions. **Checklist** - [ ] I added relevant documentation - [x] I followed the style guidelines of this project - [x] I did a self-review of my code - [ ] I made corresponding changes to the documentation - [x] My changes generate no new warnings - [x] I have added tests that prove my fix is effective or that my feature works - [ ] I filled out [the contributor form](https://tally.so/r/n9XrxK) (see text above) - [x] I have added relevant notes to the `CHANGELOG.md` file (See https://keepachangelog.com/) --------- Co-authored-by: alvarobartt <[email protected]>
- Loading branch information
1 parent
bfea5c1
commit 3c9b5d0
Showing
12 changed files
with
329 additions
and
11 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
# Copyright 2021-present, the Recognai S.L. team. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from .__main__ import app | ||
|
||
# Run the Typer application only when this module is executed as a script.
if __name__ == "__main__":
    app()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
# Copyright 2021-present, the Recognai S.L. team. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import typer | ||
|
||
from argilla.tasks.callback import init_callback | ||
from argilla.tasks.datasets.list import list_datasets | ||
|
||
# Typer sub-application grouping the dataset commands. `init_callback` is
# installed as the group callback, and `invoke_without_command=True` lets the
# group be invoked without naming a subcommand.
app = typer.Typer(
    help="Holds CLI commands for datasets management",
    invoke_without_command=True,
    callback=init_callback,
)

# `app.command(...)` returns a decorator; applying it directly registers
# `list_datasets` under the `list` name without decorating its definition.
app.command(
    name="list",
    help="List datasets linked to user's workspaces",
)(list_datasets)


# Run the Typer application only when this module is executed as a script.
if __name__ == "__main__":
    app()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
# Copyright 2021-present, the Recognai S.L. team. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from enum import Enum | ||
from typing import Dict, Optional | ||
|
||
import typer | ||
|
||
|
||
class DatasetType(str, Enum):
    """Kinds of datasets that the `--type` option of `list_datasets` can select.

    Because `str` is mixed in, every member is also a string and compares
    equal to its plain value (e.g. `DatasetType.feedback == "feedback"`).
    """

    feedback = "feedback"
    other = "other"
|
||
|
||
def list_datasets(
    workspace: Optional[str] = typer.Option(None, help="List datasets in this workspace"),
    type_: Optional[DatasetType] = typer.Option(
        None,
        "--type",
        # The option is declared as a single optional value (not a list), so the
        # help text must not advertise that it can be passed several times.
        help="The type of datasets to be listed. By default, all datasets are listed.",
    ),
) -> None:
    """List the available datasets and print them as a table in the terminal."""
    # These imports are resolved when the command runs rather than when the
    # module is imported, so the names are looked up at call time.
    from rich.console import Console
    from rich.markdown import Markdown

    from argilla.client.api import list_datasets as list_datasets_api
    from argilla.client.feedback.dataset.local import FeedbackDataset
    from argilla.tasks.rich import get_argilla_themed_table

    def build_tags_text(tags: Dict[str, str]) -> Markdown:
        # One Markdown bullet per tag; every bullet (including the last one)
        # is terminated by a newline.
        return Markdown("".join(f"- **{tag}**: {description}\n" for tag, description in tags.items()))

    table = get_argilla_themed_table(title="Datasets")
    for column in ("ID", "Name", "Workspace", "Type", "Tags", "Creation Date", "Last Update Date"):
        table.add_column(column, justify="center")

    # No `--type` filter means both kinds of datasets are listed.
    if type_ is None or type_ == DatasetType.feedback:
        for dataset in FeedbackDataset.list(workspace):
            # TODO: add passing value for `Creation Date` and `Update Date` columns once `RemoteFeedbackDataset` has
            # these attributes
            table.add_row(str(dataset.id), dataset.name, dataset.workspace.name, "Feedback", None, None, None)

    if type_ is None or type_ == DatasetType.other:
        for dataset in list_datasets_api(workspace):
            table.add_row(
                dataset.id,
                dataset.name,
                dataset.workspace,
                dataset.task,
                build_tags_text(dataset.tags),
                dataset.created_at.isoformat(sep=" "),
                dataset.last_updated.isoformat(sep=" "),
            )

    console = Console()
    console.print(table)
|
||
|
||
# Allow running this module directly: Typer wraps `list_datasets` in a
# single-command CLI.
if __name__ == "__main__":
    typer.run(list_datasets)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
# Copyright 2021-present, the Recognai S.L. team. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from typing import Any | ||
|
||
from rich.table import Table | ||
|
||
|
||
def get_argilla_themed_table(title: str, **kwargs: Any) -> Table:
    """Create a `rich` table with the given title and a red border.

    Args:
        title: Title of the table.
        **kwargs: Extra keyword arguments forwarded unchanged to `Table`.
            `border_style` is already set here, so passing it again raises a
            `TypeError` for the duplicated keyword.

    Returns:
        The newly created table.
    """
    # TODO: update colors after consulting it with UI expert
    themed_table = Table(title=title, border_style="red", **kwargs)
    return themed_table
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
# Copyright 2021-present, the Recognai S.L. team. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,155 @@ | ||
# Copyright 2021-present, the Recognai S.L. team. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from datetime import datetime | ||
from typing import TYPE_CHECKING | ||
from unittest.mock import ANY, call | ||
from uuid import uuid4 | ||
|
||
import httpx | ||
import pytest | ||
from argilla.client.feedback.dataset.remote.dataset import RemoteFeedbackDataset | ||
from argilla.client.feedback.schemas.fields import TextField | ||
from argilla.client.feedback.schemas.questions import TextQuestion | ||
from argilla.client.sdk.datasets.models import Dataset | ||
from argilla.client.workspaces import Workspace | ||
from rich.table import Table | ||
|
||
if TYPE_CHECKING: | ||
from click.testing import CliRunner | ||
from pytest_mock import MockerFixture | ||
from typer import Typer | ||
|
||
|
||
@pytest.fixture
def remote_feedback_dataset() -> "RemoteFeedbackDataset":
    """Provide a `RemoteFeedbackDataset` named `unit-test` with one text field and one text question."""
    # `__new__` creates the instance without running `Workspace.__init__`; the
    # attributes are then written straight into the instance dictionary.
    # NOTE(review): presumably done to avoid whatever the constructor requires
    # (e.g. a server connection) - confirm against `Workspace`.
    unit_test_workspace = Workspace.__new__(Workspace)
    unit_test_workspace.__dict__.update(
        id=uuid4(),
        name="unit-test",
        inserted_at=datetime.now(),
        updated_at=datetime.now(),
    )

    http_client = httpx.Client()
    prompt_field = TextField(name="prompt")
    corrected_question = TextQuestion(name="corrected")
    return RemoteFeedbackDataset(
        client=http_client,
        id=uuid4(),
        name="unit-test",
        workspace=unit_test_workspace,
        fields=[prompt_field],
        questions=[corrected_question],
    )
|
||
|
||
@pytest.fixture
def dataset() -> Dataset:
    """Provide a `TextClassification` dataset named `unit-test` with both timestamps set."""
    attributes = dict(
        name="unit-test",
        id="rg.unit-test",
        task="TextClassification",
        owner="unit-test",
        workspace="unit-test",
        created_at=datetime.now(),
        last_updated=datetime.now(),
    )
    return Dataset(**attributes)
|
||
|
||
@pytest.mark.usefixtures("login_mock")
class TestSuiteListDatasetsCommand:
    """Tests for the `datasets list` CLI command, run with the `login_mock` fixture active."""

    # Dotted paths of the two listing callables that the command invokes and
    # that every test below replaces with a mock.
    FEEDBACK_LIST_TARGET = "argilla.client.feedback.dataset.local.FeedbackDataset.list"
    OTHER_LIST_TARGET = "argilla.client.api.list_datasets"

    def test_list_datasets(
        self,
        cli_runner: "CliRunner",
        cli: "Typer",
        mocker: "MockerFixture",
        remote_feedback_dataset: RemoteFeedbackDataset,
        dataset: Dataset,
    ) -> None:
        """Without filters, both listings are queried and one table row is added per dataset."""
        table_rows = mocker.spy(Table, "add_row")
        feedback_listing = mocker.patch(self.FEEDBACK_LIST_TARGET, return_value=[remote_feedback_dataset])
        other_listing = mocker.patch(self.OTHER_LIST_TARGET, return_value=[dataset])

        outcome = cli_runner.invoke(cli, "datasets list")

        assert outcome.exit_code == 0
        feedback_listing.assert_called_once_with(None)
        other_listing.assert_called_once_with(None)

        # The leading `ANY` in each expected call stands for the first
        # positional argument recorded by the spy on `Table.add_row`.
        feedback_row = call(
            ANY,
            str(remote_feedback_dataset.id),
            remote_feedback_dataset.name,
            remote_feedback_dataset.workspace.name,
            "Feedback",
            None,
            None,
            None,
        )
        # The tags cell is matched with `ANY` as well.
        other_row = call(
            ANY,
            str(dataset.id),
            dataset.name,
            "unit-test",
            "TextClassification",
            ANY,
            dataset.created_at.isoformat(sep=" "),
            dataset.last_updated.isoformat(sep=" "),
        )
        table_rows.assert_has_calls([feedback_row, other_row])

    def test_list_datasets_with_workspace(self, cli_runner: "CliRunner", cli: "Typer", mocker: "MockerFixture") -> None:
        """`--workspace` is forwarded to both listing calls."""
        feedback_listing = mocker.patch(self.FEEDBACK_LIST_TARGET)
        other_listing = mocker.patch(self.OTHER_LIST_TARGET)

        outcome = cli_runner.invoke(cli, "datasets list --workspace unit-test")

        assert outcome.exit_code == 0
        feedback_listing.assert_called_once_with("unit-test")
        other_listing.assert_called_once_with("unit-test")

    def test_list_datasets_using_type_feedback_filter(
        self, cli_runner: "CliRunner", cli: "Typer", mocker: "MockerFixture"
    ) -> None:
        """`--type feedback` queries only the feedback datasets listing."""
        feedback_listing = mocker.patch(self.FEEDBACK_LIST_TARGET)
        other_listing = mocker.patch(self.OTHER_LIST_TARGET)

        outcome = cli_runner.invoke(cli, "datasets list --type feedback")

        assert outcome.exit_code == 0
        feedback_listing.assert_called_once_with(None)
        other_listing.assert_not_called()

    def test_list_datasets_using_type_other_filter(
        self, cli_runner: "CliRunner", cli: "Typer", mocker: "MockerFixture"
    ) -> None:
        """`--type other` queries only the non-feedback datasets listing."""
        feedback_listing = mocker.patch(self.FEEDBACK_LIST_TARGET)
        other_listing = mocker.patch(self.OTHER_LIST_TARGET)

        outcome = cli_runner.invoke(cli, "datasets list --type other")

        assert outcome.exit_code == 0
        feedback_listing.assert_not_called()
        other_listing.assert_called_once_with(None)
|
||
|
||
@pytest.mark.usefixtures("not_logged_mock")
def test_cli_datasets_list_needs_login(cli_runner: "CliRunner", cli: "Typer") -> None:
    """With the `not_logged_mock` fixture active, the command prints the login hint and exits with code 1."""
    expected_message = "You are not logged in. Please run `argilla login` to login to an Argilla server."

    outcome = cli_runner.invoke(cli, "datasets list")

    assert expected_message in outcome.stdout
    assert outcome.exit_code == 1
Oops, something went wrong.