feat: Allow passing workspace as client param for rg.log or `rg.load` (#2425)

# Description

Allow passing workspace as a client param for `rg.log` or `rg.load` (see the usage sketch below). I also
enabled this for `rg.delete` and `rg.delete_records`.

Closes #2059 
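A minimal usage sketch of the new parameter (the dataset name, workspace name, `api_url`, and `api_key` below are illustrative placeholders, assuming a running Argilla instance with an existing workspace called `my-workspace`):

```python
import argilla as rg

# Assumption: a running Argilla server and an existing workspace "my-workspace"
rg.init(api_url="http://localhost:6900", api_key="argilla.apikey")

record = rg.TextClassificationRecord(text="example text")

# Log to an explicit workspace instead of the default private one
rg.log(record, name="example-dataset", workspace="my-workspace")

# The same parameter is now accepted by load, delete_records, and delete
ds = rg.load("example-dataset", workspace="my-workspace")
rg.delete_records("example-dataset", ids=[rec.id for rec in ds], workspace="my-workspace")
rg.delete("example-dataset", workspace="my-workspace")
```

When `workspace` is omitted, the previous behavior still applies: the `ARGILLA_WORKSPACE` env variable is used if set, otherwise the private user workspace.
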
**Type of change**

(Please delete options that are not relevant. Remember to title the PR
according to the type of change)


- [X] New feature (non-breaking change which adds functionality)
- [X] Refactor (change restructuring the codebase without changing
functionality)
- [X] Improvement (change adding some improvement to an existing
functionality)


**How Has This Been Tested**

(Please describe the tests that you ran to verify your changes. And
ideally, reference `tests`)

- [Test A](https://github.com/argilla-io/argilla/blob/b227b09a423dae8792e520ce4f9da3dcb5fb0d0d/tests/client/test_api.py#L364)


**Checklist**

- [X] I have merged the original branch into my forked branch

- [X] My code follows the style guidelines of this project
- [X] I did a self-review of my code
- [ ] I made corresponding changes to the documentation
- [X] My changes generate no new warnings
- [X] I have added tests that prove my fix is effective or that my
feature works
davidberenstein1957 authored Mar 6, 2023
1 parent 40ca933 commit b3b897a
Showing 3 changed files with 75 additions and 16 deletions.
22 changes: 20 additions & 2 deletions src/argilla/client/api.py
@@ -109,6 +109,7 @@ def init(
def log(
records: Union[Record, Iterable[Record], Dataset],
name: str,
workspace: Optional[str] = None,
tags: Optional[Dict[str, str]] = None,
metadata: Optional[Dict[str, Any]] = None,
chunk_size: int = 500,
@@ -122,6 +123,8 @@ def log(
Args:
records: The record, an iterable of records, or a dataset to log.
name: The dataset name.
workspace: The workspace to which records will be logged/loaded. If `None` (default) and the
env variable ``ARGILLA_WORKSPACE`` is not set, it will default to the private user workspace.
tags: A dictionary of tags related to the dataset.
metadata: A dictionary of extra info for the dataset.
chunk_size: The chunk size for a data bulk.
@@ -150,6 +153,7 @@ def log(
return ArgillaSingleton.get().log(
records=records,
name=name,
workspace=workspace,
tags=tags,
metadata=metadata,
chunk_size=chunk_size,
@@ -161,6 +165,7 @@ def log(
async def log_async(
records: Union[Record, Iterable[Record], Dataset],
name: str,
workspace: Optional[str] = None,
tags: Optional[Dict[str, str]] = None,
metadata: Optional[Dict[str, Any]] = None,
chunk_size: int = 500,
@@ -171,6 +176,8 @@ async def log_async(
Args:
records: The record, an iterable of records, or a dataset to log.
name: The dataset name.
workspace: The workspace to which records will be logged/loaded. If `None` (default) and the
env variable ``ARGILLA_WORKSPACE`` is not set, it will default to the private user workspace.
tags: A dictionary of tags related to the dataset.
metadata: A dictionary of extra info for the dataset.
chunk_size: The chunk size for a data bulk.
@@ -192,6 +199,7 @@ async def log_async(
return await ArgillaSingleton.get().log_async(
records=records,
name=name,
workspace=workspace,
tags=tags,
metadata=metadata,
chunk_size=chunk_size,
@@ -201,6 +209,7 @@

def load(
name: str,
workspace: Optional[str] = None,
query: Optional[str] = None,
vector: Optional[Tuple[str, List[float]]] = None,
ids: Optional[List[Union[str, int]]] = None,
@@ -212,6 +221,8 @@ def load(
Args:
name: The dataset name.
workspace: The workspace to which records will be logged/loaded. If `None` (default) and the
env variable ``ARGILLA_WORKSPACE`` is not set, it will default to the private user workspace.
query: An ElasticSearch query with the `query string
syntax <https://argilla.readthedocs.io/en/stable/guides/queries.html>`_
vector: Vector configuration for a semantic search
@@ -246,6 +257,7 @@ def load(
"""
return ArgillaSingleton.get().load(
name=name,
workspace=workspace,
query=query,
vector=vector,
ids=ids,
@@ -280,22 +292,25 @@ def copy(
)


def delete(name: str):
def delete(name: str, workspace: Optional[str] = None):
"""
Deletes a dataset.
Args:
name: The dataset name.
workspace: The workspace to which records will be logged/loaded. If `None` (default) and the
env variable ``ARGILLA_WORKSPACE`` is not set, it will default to the private user workspace.
Examples:
>>> import argilla as rg
>>> rg.delete(name="example-dataset")
"""
ArgillaSingleton.get().delete(name)
ArgillaSingleton.get().delete(name=name, workspace=workspace)


def delete_records(
name: str,
workspace: Optional[str] = None,
query: Optional[str] = None,
ids: Optional[List[Union[str, int]]] = None,
discard_only: bool = False,
@@ -305,6 +320,8 @@ def delete_records(
Args:
name: The dataset name.
workspace: The workspace to which records will be logged/loaded. If `None` (default) and the
env variable ``ARGILLA_WORKSPACE`` is not set, it will default to the private user workspace.
query: An ElasticSearch query with the `query string syntax
<https://rubrix.readthedocs.io/en/stable/guides/queries.html>`_
ids: If provided, deletes dataset records with given ids.
@@ -329,6 +346,7 @@
"""
return ArgillaSingleton.get().delete_records(
name=name,
workspace=workspace,
query=query,
ids=ids,
discard_only=discard_only,
47 changes: 33 additions & 14 deletions src/argilla/client/client.py
@@ -100,9 +100,7 @@ async def __log_internal__(api: "Argilla", *args, **kwargs):
return await api.log_async(*args, **kwargs)
except Exception as ex:
dataset = kwargs["name"]
_LOGGER.error(
f"\nCannot log data in dataset '{dataset}'\n" f"Error: {type(ex).__name__}\n" f"Details: {ex}"
)
_LOGGER.error(f"\nCannot log data in dataset '{dataset}'\nError: {type(ex).__name__}\nDetails: {ex}")
raise ex

def log(self, *args, **kwargs) -> Future:
@@ -169,7 +167,7 @@ def __del__(self):
def client(self) -> AuthenticatedClient:
"""The underlying authenticated HTTP client"""
warnings.warn(
message=("This prop will be removed in next release. " "Please use the http_client prop instead."),
message="This prop will be removed in next release. Please use the http_client prop instead.",
category=UserWarning,
)
return self._client
@@ -251,18 +249,22 @@ def copy(self, dataset: str, name_of_copy: str, workspace: str = None):
),
)

def delete(self, name: str):
def delete(self, name: str, workspace: Optional[str] = None):
"""Deletes a dataset.
Args:
name: The dataset name.
"""
if workspace is not None:
self.set_workspace(workspace)

datasets_api.delete_dataset(client=self._client, name=name)

def log(
self,
records: Union[Record, Iterable[Record], Dataset],
name: str,
workspace: Optional[str] = None,
tags: Optional[Dict[str, str]] = None,
metadata: Optional[Dict[str, Any]] = None,
chunk_size: int = 500,
@@ -290,6 +292,9 @@ def log(
will be returned instead.
"""
if workspace is not None:
self.set_workspace(workspace)

future = self._agent.log(
records=records,
name=name,
@@ -310,6 +315,7 @@ async def log_async(
self,
records: Union[Record, Iterable[Record], Dataset],
name: str,
workspace: Optional[str] = None,
tags: Optional[Dict[str, str]] = None,
metadata: Optional[Dict[str, Any]] = None,
chunk_size: int = 500,
@@ -332,6 +338,9 @@ async def log_async(
tags = tags or {}
metadata = metadata or {}

if workspace is not None:
self.set_workspace(workspace)

if not name:
raise InputValueError("Empty dataset name has been passed as argument.")

@@ -370,7 +379,7 @@ async def log_async(
bulk_class = Text2TextBulkData
creation_class = CreationText2TextRecord
else:
raise InputValueError(f"Unknown record type {record_type}. Available values are" f" {Record.__args__}")
raise InputValueError(f"Unknown record type {record_type}. Available values are {Record.__args__}")

processed, failed = 0, 0
with Progress() as progress_bar:
@@ -408,6 +417,7 @@ async def log_async(
def delete_records(
self,
name: str,
workspace: Optional[str] = None,
query: Optional[str] = None,
ids: Optional[List[Union[str, int]]] = None,
discard_only: bool = False,
@@ -432,6 +442,9 @@ def delete_records(
deletion).
"""
if workspace is not None:
self.set_workspace(workspace)

return self.datasets.delete_records(
name=name,
mark_as_discarded=discard_only,
@@ -443,6 +456,7 @@ def load(
def load(
self,
name: str,
workspace: Optional[str] = None,
query: Optional[str] = None,
vector: Optional[Tuple[str, List[float]]] = None,
ids: Optional[List[Union[str, int]]] = None,
@@ -469,6 +483,9 @@ def load(
A argilla dataset.
"""
if workspace is not None:
self.set_workspace(workspace)

if as_pandas is False:
warnings.warn(
"The argument `as_pandas` is deprecated and will be removed in a future"
@@ -479,7 +496,7 @@ def load(
raise ValueError(
"The argument `as_pandas` is deprecated and will be removed in a future"
" version. Please adapt your code accordingly. ",
"If you want a pandas DataFrame do" " `rg.load('my_dataset').to_pandas()`.",
"If you want a pandas DataFrame do `rg.load('my_dataset').to_pandas()`.",
)

try:
@@ -495,13 +512,15 @@ def load(
from argilla import __version__ as version

warnings.warn(
message=f"Using python client argilla=={version},"
f" however deployed server version is {err.api_version}."
" This might lead to compatibility issues.\n"
f" Preferably, update your server version to {version}"
" or downgrade your Python API at the loss"
" of functionality and robustness via\n"
f"`pip install argilla=={err.api_version}`",
message=(
f"Using python client argilla=={version},"
f" however deployed server version is {err.api_version}."
" This might lead to compatibility issues.\n"
f" Preferably, update your server version to {version}"
" or downgrade your Python API at the loss"
" of functionality and robustness via\n"
f"`pip install argilla=={err.api_version}`"
),
category=UserWarning,
)

22 changes: 22 additions & 0 deletions tests/client/test_api.py
@@ -372,6 +372,28 @@ def test_general_log_load(mocked_client, monkeypatch, request, records, dataset_
assert record == expected


@pytest.mark.parametrize(
"records, dataset_class",
[
("singlelabel_textclassification_records", rg.DatasetForTextClassification),
],
)
def test_log_load_with_workspace(mocked_client, monkeypatch, request, records, dataset_class):
dataset_names = [
f"test_general_log_load_{dataset_class.__name__.lower()}_" + input_type
for input_type in ["single", "list", "dataset"]
]
for name in dataset_names:
mocked_client.delete(f"/api/datasets/{name}")

records = request.getfixturevalue(records)

api.log(records, name=dataset_names[0], workspace="argilla")
ds = api.load(dataset_names[0], workspace="argilla")
api.delete_records(dataset_names[0], ids=[rec.id for rec in ds][:1], workspace="argilla")
api.delete(dataset_names[0], workspace="argilla")


def test_passing_wrong_iterable_data(mocked_client):
dataset_name = "test_log_single_records"
mocked_client.delete(f"/api/datasets/{dataset_name}")
