diff --git a/src/argilla/client/api.py b/src/argilla/client/api.py
index 3bce617e35..74f0c1a16d 100644
--- a/src/argilla/client/api.py
+++ b/src/argilla/client/api.py
@@ -111,9 +111,10 @@ def log(
     name: str,
     tags: Optional[Dict[str, str]] = None,
     metadata: Optional[Dict[str, Any]] = None,
-    chunk_size: int = 500,
+    batch_size: int = 500,
     verbose: bool = True,
     background: bool = False,
+    chunk_size: Optional[int] = None,
 ) -> Union[BulkResponse, Future]:
     """Logs Records to argilla.
 
@@ -124,10 +125,11 @@ def log(
         records: The record, an iterable of records, or a dataset to log.
         name: The dataset name.
         tags: A dictionary of tags related to the dataset.
         metadata: A dictionary of extra info for the dataset.
-        chunk_size: The chunk size for a data bulk.
+        batch_size: The batch size for a data bulk.
         verbose: If True, shows a progress bar and prints out a quick summary at the end.
         background: If True, we will NOT wait for the logging process to finish and return an ``asyncio.Future``
             object. You probably want to set ``verbose`` to False in that case.
+        chunk_size: DEPRECATED! Use `batch_size` instead.
 
     Returns:
         Summary of the response from the REST API.
@@ -152,9 +154,10 @@ def log(
         name=name,
         tags=tags,
         metadata=metadata,
-        chunk_size=chunk_size,
+        batch_size=batch_size,
         verbose=verbose,
         background=background,
+        chunk_size=chunk_size,
     )
 
 
@@ -163,8 +166,9 @@ async def log_async(
     name: str,
     tags: Optional[Dict[str, str]] = None,
     metadata: Optional[Dict[str, Any]] = None,
-    chunk_size: int = 500,
+    batch_size: int = 500,
     verbose: bool = True,
+    chunk_size: Optional[int] = None,
 ) -> BulkResponse:
     """Logs Records to argilla with asyncio.
 
@@ -173,8 +177,9 @@ async def log_async(
         records: The record, an iterable of records, or a dataset to log.
         name: The dataset name.
         tags: A dictionary of tags related to the dataset.
         metadata: A dictionary of extra info for the dataset.
-        chunk_size: The chunk size for a data bulk.
+        batch_size: The batch size for a data bulk.
         verbose: If True, shows a progress bar and prints out a quick summary at the end.
+        chunk_size: DEPRECATED! Use `batch_size` instead.
 
     Returns:
         Summary of the response from the REST API
@@ -194,8 +199,9 @@ async def log_async(
         name=name,
         tags=tags,
         metadata=metadata,
-        chunk_size=chunk_size,
+        batch_size=batch_size,
         verbose=verbose,
+        chunk_size=chunk_size,
     )
 
 
diff --git a/src/argilla/client/client.py b/src/argilla/client/client.py
index 66003c8729..28e67fe2c4 100644
--- a/src/argilla/client/client.py
+++ b/src/argilla/client/client.py
@@ -115,7 +115,7 @@ class Argilla:
     """
 
     # Larger sizes will trigger a warning
-    _MAX_CHUNK_SIZE = 5000
+    _MAX_BATCH_SIZE = 5000
 
     def __init__(
         self,
@@ -256,9 +256,10 @@ def log(
         name: str,
         tags: Optional[Dict[str, str]] = None,
         metadata: Optional[Dict[str, Any]] = None,
-        chunk_size: int = 500,
+        batch_size: int = 500,
         verbose: bool = True,
         background: bool = False,
+        chunk_size: Optional[int] = None,
     ) -> Union[BulkResponse, Future]:
         """Logs Records to argilla.
 
@@ -269,11 +270,12 @@ def log(
             records: The record, an iterable of records, or a dataset to log.
             name: The dataset name.
             tags: A dictionary of tags related to the dataset.
             metadata: A dictionary of extra info for the dataset.
-            chunk_size: The chunk size for a data bulk.
+            batch_size: The batch size for a data bulk.
             verbose: If True, shows a progress bar and prints out a quick summary at the end.
             background: If True, we will NOT wait for the logging process to finish and return an ``asyncio.Future``
                 object. You probably want to set ``verbose`` to False in that case.
+            chunk_size: DEPRECATED! Use `batch_size` instead.
 
         Returns:
             Summary of the response from the REST API.
@@ -286,8 +288,9 @@ def log(
             name=name,
             tags=tags,
             metadata=metadata,
-            chunk_size=chunk_size,
+            batch_size=batch_size,
             verbose=verbose,
+            chunk_size=chunk_size,
         )
         if background:
             return future
@@ -303,8 +306,9 @@ async def log_async(
         name: str,
         tags: Optional[Dict[str, str]] = None,
         metadata: Optional[Dict[str, Any]] = None,
-        chunk_size: int = 500,
+        batch_size: int = 500,
         verbose: bool = True,
+        chunk_size: Optional[int] = None,
     ) -> BulkResponse:
         """Logs Records to argilla with asyncio.
 
@@ -313,8 +317,9 @@ async def log_async(
             records: The record, an iterable of records, or a dataset to log.
             name: The dataset name.
             tags: A dictionary of tags related to the dataset.
             metadata: A dictionary of extra info for the dataset.
-            chunk_size: The chunk size for a data bulk.
+            batch_size: The batch size for a data bulk.
             verbose: If True, shows a progress bar and prints out a quick summary at the end.
+            chunk_size: DEPRECATED! Use `batch_size` instead.
 
         Returns:
             Summary of the response from the REST API
@@ -335,11 +340,18 @@ async def log_async(
                 " https://www.elastic.co/guide/en/elasticsearch/reference/current/indices-create-index.html"
             )
 
-        if chunk_size > self._MAX_CHUNK_SIZE:
+        if chunk_size is not None:
+            warnings.warn(
+                "The argument `chunk_size` is deprecated and will be removed in a future"
+                " version. Please use `batch_size` instead.",
+                FutureWarning,
+            )
+            batch_size = chunk_size
+
+        if batch_size > self._MAX_BATCH_SIZE:
             _LOGGER.warning(
-                """The introduced chunk size is noticeably large, timeout errors may occur.
-                Consider a chunk size smaller than %s""",
-                self._MAX_CHUNK_SIZE,
+                "The requested batch size is noticeably large, timeout errors may occur. "
+                f"Consider a batch size smaller than {self._MAX_BATCH_SIZE}",
             )
 
         if isinstance(records, Record.__args__):
@@ -367,8 +379,8 @@ async def log_async(
         with Progress() as progress_bar:
             task = progress_bar.add_task("Logging...", total=len(records), visible=verbose)
 
-            for i in range(0, len(records), chunk_size):
-                chunk = records[i : i + chunk_size]
+            for i in range(0, len(records), batch_size):
+                batch = records[i : i + batch_size]
 
                 response = await async_bulk(
                     client=self._client,
@@ -376,14 +388,14 @@ async def log_async(
                     json_body=bulk_class(
                         tags=tags,
                         metadata=metadata,
-                        records=[creation_class.from_client(r) for r in chunk],
+                        records=[creation_class.from_client(r) for r in batch],
                     ),
                 )
 
                 processed += response.parsed.processed
                 failed += response.parsed.failed
 
-                progress_bar.update(task, advance=len(chunk))
+                progress_bar.update(task, advance=len(batch))
 
         # TODO: improve logging policy in library
         if verbose:
diff --git a/tests/client/test_api.py b/tests/client/test_api.py
index 822a8e6607..ba4edddd83 100644
--- a/tests/client/test_api.py
+++ b/tests/client/test_api.py
@@ -228,6 +228,23 @@ def test_log_passing_empty_records_list(mocked_client):
         api.log(records=[], name="ds")
 
 
+def test_log_deprecated_chunk_size(mocked_client):
+    dataset_name = "test_log_deprecated_chunk_size"
+    mocked_client.delete(f"/api/datasets/{dataset_name}")
+    record = rg.TextClassificationRecord(text="My text")
+    with pytest.warns(FutureWarning, match="`chunk_size`.*`batch_size`"):
+        api.log(records=[record], name=dataset_name, chunk_size=100)
+
+
+def test_large_batch_size_warning(mocked_client, caplog: pytest.LogCaptureFixture):
+    dataset_name = "test_large_batch_size_warning"
+    mocked_client.delete(f"/api/datasets/{dataset_name}")
+    record = rg.TextClassificationRecord(text="My text")
+    api.log(records=[record], name=dataset_name, batch_size=10000)
+    assert len(caplog.record_tuples) == 1
+    assert "batch size is noticeably large" in caplog.record_tuples[0][2]
+
+
 def test_log_background(mocked_client):
     """Verify that logs can be delayed via the background parameter."""
     dataset_name = "test_log_background"
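
For review context, here is a minimal sketch of what the rename looks like from the caller's side. It assumes a running, configured Argilla instance; the dataset name `my-dataset` is illustrative, while `rg.log` and `rg.TextClassificationRecord` are the public API exercised by this diff and its tests:

```python
import argilla as rg  # assumes rg.init()/env vars point at a running instance

record = rg.TextClassificationRecord(text="My text")

# New spelling: `batch_size` sets how many records go into each bulk request.
rg.log(records=[record], name="my-dataset", batch_size=100)

# Old spelling: still accepted, but now emits
#   FutureWarning: The argument `chunk_size` is deprecated and will be removed
#   in a future version. Please use `batch_size` instead.
# and the value is copied onto `batch_size` before logging proceeds.
rg.log(records=[record], name="my-dataset", chunk_size=100)
```

One nice property of appending `chunk_size: Optional[int] = None` at the end of the signature instead of leaving it in place: callers who passed the old fifth argument positionally now bind that value to `batch_size`, which has the same meaning, so positional call sites keep working without even hitting the deprecation path.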