From b3b897ac731d0113ea5e078b5bf5974f351935fd Mon Sep 17 00:00:00 2001 From: David Berenstein Date: Mon, 6 Mar 2023 13:59:30 +0100 Subject: [PATCH] feat: Allow passing workspace as client param for `rg.log` or `rg.load` (#2425) # Description Allow passing workspace as client param for `rg.log` or `rg.load`. I also enabled this for `rg.delete` and `rg.delete_records`. Closes #2059 **Type of change** (Please delete options that are not relevant. Remember to title the PR according to the type of change) - [X] New feature (non-breaking change which adds functionality) - [X] Refactor (change restructuring the codebase without changing functionality) - [X] Improvement (change adding some improvement to an existing functionality) **How Has This Been Tested** (Please describe the tests that you ran to verify your changes. And ideally, reference `tests`) - [Test A](https://github.com/argilla-io/argilla/blob/b227b09a423dae8792e520ce4f9da3dcb5fb0d0d/tests/client/test_api.py#L364) **Checklist** - [X] I have merged the original branch into my forked branch - [X] follows the style guidelines of this project - [X] I did a self-review of my code - [ ] I made corresponding changes to the documentation - [X] My changes generate no new warnings - [X] I have added tests that prove my fix is effective or that my feature works --- src/argilla/client/api.py | 22 +++++++++++++++-- src/argilla/client/client.py | 47 +++++++++++++++++++++++++----------- tests/client/test_api.py | 22 +++++++++++++++++ 3 files changed, 75 insertions(+), 16 deletions(-) diff --git a/src/argilla/client/api.py b/src/argilla/client/api.py index 3bce617e35..cb55e0a7cf 100644 --- a/src/argilla/client/api.py +++ b/src/argilla/client/api.py @@ -109,6 +109,7 @@ def init( def log( records: Union[Record, Iterable[Record], Dataset], name: str, + workspace: Optional[str] = None, tags: Optional[Dict[str, str]] = None, metadata: Optional[Dict[str, Any]] = None, chunk_size: int = 500, @@ -122,6 +123,8 @@ def log( Args: records: The 
record, an iterable of records, or a dataset to log. name: The dataset name. + workspace: The workspace to which records will be logged/loaded. If `None` (default) and the + env variable ``ARGILLA_WORKSPACE`` is not set, it will default to the private user workspace. tags: A dictionary of tags related to the dataset. metadata: A dictionary of extra info for the dataset. chunk_size: The chunk size for a data bulk. @@ -150,6 +153,7 @@ def log( return ArgillaSingleton.get().log( records=records, name=name, + workspace=workspace, tags=tags, metadata=metadata, chunk_size=chunk_size, @@ -161,6 +165,7 @@ def log( async def log_async( records: Union[Record, Iterable[Record], Dataset], name: str, + workspace: Optional[str] = None, tags: Optional[Dict[str, str]] = None, metadata: Optional[Dict[str, Any]] = None, chunk_size: int = 500, @@ -171,6 +176,8 @@ async def log_async( Args: records: The record, an iterable of records, or a dataset to log. name: The dataset name. + workspace: The workspace to which records will be logged/loaded. If `None` (default) and the + env variable ``ARGILLA_WORKSPACE`` is not set, it will default to the private user workspace. tags: A dictionary of tags related to the dataset. metadata: A dictionary of extra info for the dataset. chunk_size: The chunk size for a data bulk. @@ -192,6 +199,7 @@ async def log_async( return await ArgillaSingleton.get().log_async( records=records, name=name, + workspace=workspace, tags=tags, metadata=metadata, chunk_size=chunk_size, @@ -201,6 +209,7 @@ async def log_async( def load( name: str, + workspace: Optional[str] = None, query: Optional[str] = None, vector: Optional[Tuple[str, List[float]]] = None, ids: Optional[List[Union[str, int]]] = None, @@ -212,6 +221,8 @@ def load( Args: name: The dataset name. + workspace: The workspace to which records will be logged/loaded. If `None` (default) and the + env variable ``ARGILLA_WORKSPACE`` is not set, it will default to the private user workspace. 
query: An ElasticSearch query with the `query string syntax `_ vector: Vector configuration for a semantic search @@ -246,6 +257,7 @@ def load( """ return ArgillaSingleton.get().load( name=name, + workspace=workspace, query=query, vector=vector, ids=ids, @@ -280,22 +292,25 @@ def copy( ) -def delete(name: str): +def delete(name: str, workspace: Optional[str] = None): """ Deletes a dataset. Args: name: The dataset name. + workspace: The workspace to which records will be logged/loaded. If `None` (default) and the + env variable ``ARGILLA_WORKSPACE`` is not set, it will default to the private user workspace. Examples: >>> import argilla as rg >>> rg.delete(name="example-dataset") """ - ArgillaSingleton.get().delete(name) + ArgillaSingleton.get().delete(name=name, workspace=workspace) def delete_records( name: str, + workspace: Optional[str] = None, query: Optional[str] = None, ids: Optional[List[Union[str, int]]] = None, discard_only: bool = False, @@ -305,6 +320,8 @@ def delete_records( Args: name: The dataset name. + workspace: The workspace to which records will be logged/loaded. If `None` (default) and the + env variable ``ARGILLA_WORKSPACE`` is not set, it will default to the private user workspace. query: An ElasticSearch query with the `query string syntax `_ ids: If provided, deletes dataset records with given ids. 
@@ -329,6 +346,7 @@ def delete_records( """ return ArgillaSingleton.get().delete_records( name=name, + workspace=workspace, query=query, ids=ids, discard_only=discard_only, diff --git a/src/argilla/client/client.py b/src/argilla/client/client.py index 1df03aca60..62e0d18c36 100644 --- a/src/argilla/client/client.py +++ b/src/argilla/client/client.py @@ -100,9 +100,7 @@ async def __log_internal__(api: "Argilla", *args, **kwargs): return await api.log_async(*args, **kwargs) except Exception as ex: dataset = kwargs["name"] - _LOGGER.error( - f"\nCannot log data in dataset '{dataset}'\n" f"Error: {type(ex).__name__}\n" f"Details: {ex}" - ) + _LOGGER.error(f"\nCannot log data in dataset '{dataset}'\nError: {type(ex).__name__}\nDetails: {ex}") raise ex def log(self, *args, **kwargs) -> Future: @@ -169,7 +167,7 @@ def __del__(self): def client(self) -> AuthenticatedClient: """The underlying authenticated HTTP client""" warnings.warn( - message=("This prop will be removed in next release. " "Please use the http_client prop instead."), + message="This prop will be removed in next release. Please use the http_client prop instead.", category=UserWarning, ) return self._client @@ -251,18 +249,22 @@ def copy(self, dataset: str, name_of_copy: str, workspace: str = None): ), ) - def delete(self, name: str): + def delete(self, name: str, workspace: Optional[str] = None): """Deletes a dataset. Args: name: The dataset name. """ + if workspace is not None: + self.set_workspace(workspace) + datasets_api.delete_dataset(client=self._client, name=name) def log( self, records: Union[Record, Iterable[Record], Dataset], name: str, + workspace: Optional[str] = None, tags: Optional[Dict[str, str]] = None, metadata: Optional[Dict[str, Any]] = None, chunk_size: int = 500, @@ -290,6 +292,9 @@ def log( will be returned instead. 
""" + if workspace is not None: + self.set_workspace(workspace) + future = self._agent.log( records=records, name=name, @@ -310,6 +315,7 @@ async def log_async( self, records: Union[Record, Iterable[Record], Dataset], name: str, + workspace: Optional[str] = None, tags: Optional[Dict[str, str]] = None, metadata: Optional[Dict[str, Any]] = None, chunk_size: int = 500, @@ -332,6 +338,9 @@ async def log_async( tags = tags or {} metadata = metadata or {} + if workspace is not None: + self.set_workspace(workspace) + if not name: raise InputValueError("Empty dataset name has been passed as argument.") @@ -370,7 +379,7 @@ async def log_async( bulk_class = Text2TextBulkData creation_class = CreationText2TextRecord else: - raise InputValueError(f"Unknown record type {record_type}. Available values are" f" {Record.__args__}") + raise InputValueError(f"Unknown record type {record_type}. Available values are {Record.__args__}") processed, failed = 0, 0 with Progress() as progress_bar: @@ -408,6 +417,7 @@ async def log_async( def delete_records( self, name: str, + workspace: Optional[str] = None, query: Optional[str] = None, ids: Optional[List[Union[str, int]]] = None, discard_only: bool = False, @@ -432,6 +442,9 @@ def delete_records( deletion). """ + if workspace is not None: + self.set_workspace(workspace) + return self.datasets.delete_records( name=name, mark_as_discarded=discard_only, @@ -443,6 +456,7 @@ def delete_records( def load( self, name: str, + workspace: Optional[str] = None, query: Optional[str] = None, vector: Optional[Tuple[str, List[float]]] = None, ids: Optional[List[Union[str, int]]] = None, @@ -469,6 +483,9 @@ def load( A argilla dataset. """ + if workspace is not None: + self.set_workspace(workspace) + if as_pandas is False: warnings.warn( "The argument `as_pandas` is deprecated and will be removed in a future" @@ -479,7 +496,7 @@ def load( raise ValueError( "The argument `as_pandas` is deprecated and will be removed in a future" " version. 
Please adapt your code accordingly. ", - "If you want a pandas DataFrame do" " `rg.load('my_dataset').to_pandas()`.", + "If you want a pandas DataFrame do `rg.load('my_dataset').to_pandas()`.", ) try: @@ -495,13 +512,15 @@ def load( from argilla import __version__ as version warnings.warn( - message=f"Using python client argilla=={version}," - f" however deployed server version is {err.api_version}." - " This might lead to compatibility issues.\n" - f" Preferably, update your server version to {version}" - " or downgrade your Python API at the loss" - " of functionality and robustness via\n" - f"`pip install argilla=={err.api_version}`", + message=( + f"Using python client argilla=={version}," + f" however deployed server version is {err.api_version}." + " This might lead to compatibility issues.\n" + f" Preferably, update your server version to {version}" + " or downgrade your Python API at the loss" + " of functionality and robustness via\n" + f"`pip install argilla=={err.api_version}`" + ), category=UserWarning, ) diff --git a/tests/client/test_api.py b/tests/client/test_api.py index 7aaff4e86c..63299145df 100644 --- a/tests/client/test_api.py +++ b/tests/client/test_api.py @@ -372,6 +372,28 @@ def test_general_log_load(mocked_client, monkeypatch, request, records, dataset_ assert record == expected +@pytest.mark.parametrize( + "records, dataset_class", + [ + ("singlelabel_textclassification_records", rg.DatasetForTextClassification), + ], +) +def test_log_load_with_workspace(mocked_client, monkeypatch, request, records, dataset_class): + dataset_names = [ + f"test_general_log_load_{dataset_class.__name__.lower()}_" + input_type + for input_type in ["single", "list", "dataset"] + ] + for name in dataset_names: + mocked_client.delete(f"/api/datasets/{name}") + + records = request.getfixturevalue(records) + + api.log(records, name=dataset_names[0], workspace="argilla") + ds = api.load(dataset_names[0], workspace="argilla") + api.delete_records(dataset_names[0], 
ids=[rec.id for rec in ds][:1], workspace="argilla") + api.delete(dataset_names[0], workspace="argilla") + + def test_passing_wrong_iterable_data(mocked_client): dataset_name = "test_log_single_records" mocked_client.delete(f"/api/datasets/{dataset_name}")