Skip to content

Commit

Permalink
feat: Deprecate chunk_size in favor of batch_size for rg.log (#2455)

Browse files Browse the repository at this point in the history

Closes #2453

## Pull Request overview
* Deprecate `chunk_size` in favor of `batch_size` in `rg.log`:
  * Move `chunk_size` to the end of all related signatures.
  * Set `chunk_size` default to None and update the typing accordingly.
  * Introduce `batch_size` in the old position in the signature.
  * Update docstrings accordingly.
  * Introduce a `FutureWarning` if `chunk_size` is used.
* Introduce test showing that `rg.log(..., chunk_size=100)` indeed
throws a warning.
* Update a warning to no longer include a newline and a lot of spaces
(see first comment of this PR)

## Details
Note that this deprecation is non-breaking: Code that uses `chunk_size`
will still work, as `batch_size = chunk_size` after a FutureWarning is
given, if `chunk_size` is not `None`.

## Discussion
* Should I use a FutureWarning? Or a DeprecationWarning? Or a
PendingDeprecationWarning? The last two make sense, but they are
[ignored by
default](https://docs.python.org/3/library/warnings.html#default-warning-filter),
I'm afraid.
* Is the deprecation message in the format that we like?

---

**Type of change**
- [x] New feature (non-breaking change which adds functionality)

**How Has This Been Tested**

I introduced a test, and ran all tests.

**Checklist**

- [x] I have merged the original branch into my forked branch
- [x] I added relevant documentation
- [x] My code follows the style guidelines of this project
- [x] I did a self-review of my code
- [x] I made corresponding changes to the documentation
- [x] My changes generate no new warnings
- [x] I have added tests that prove my fix is effective or that my
feature works

---

- Tom Aarsen
  • Loading branch information
tomaarsen authored Mar 6, 2023
1 parent b3b897a commit 3ebea76
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 20 deletions.
18 changes: 12 additions & 6 deletions src/argilla/client/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,9 +112,10 @@ def log(
workspace: Optional[str] = None,
tags: Optional[Dict[str, str]] = None,
metadata: Optional[Dict[str, Any]] = None,
chunk_size: int = 500,
batch_size: int = 500,
verbose: bool = True,
background: bool = False,
chunk_size: Optional[int] = None,
) -> Union[BulkResponse, Future]:
"""Logs Records to argilla.
Expand All @@ -127,10 +128,11 @@ def log(
env variable ``ARGILLA_WORKSPACE`` is not set, it will default to the private user workspace.
tags: A dictionary of tags related to the dataset.
metadata: A dictionary of extra info for the dataset.
chunk_size: The chunk size for a data bulk.
batch_size: The batch size for a data bulk.
verbose: If True, shows a progress bar and prints out a quick summary at the end.
background: If True, we will NOT wait for the logging process to finish and return an ``asyncio.Future``
object. You probably want to set ``verbose`` to False in that case.
chunk_size: DEPRECATED! Use `batch_size` instead.
Returns:
Summary of the response from the REST API.
Expand All @@ -156,9 +158,10 @@ def log(
workspace=workspace,
tags=tags,
metadata=metadata,
chunk_size=chunk_size,
batch_size=batch_size,
verbose=verbose,
background=background,
chunk_size=chunk_size,
)


Expand All @@ -168,8 +171,9 @@ async def log_async(
workspace: Optional[str] = None,
tags: Optional[Dict[str, str]] = None,
metadata: Optional[Dict[str, Any]] = None,
chunk_size: int = 500,
batch_size: int = 500,
verbose: bool = True,
chunk_size: Optional[int] = None,
) -> BulkResponse:
"""Logs Records to argilla with asyncio.
Expand All @@ -180,8 +184,9 @@ async def log_async(
env variable ``ARGILLA_WORKSPACE`` is not set, it will default to the private user workspace.
tags: A dictionary of tags related to the dataset.
metadata: A dictionary of extra info for the dataset.
chunk_size: The chunk size for a data bulk.
batch_size: The batch size for a data bulk.
verbose: If True, shows a progress bar and prints out a quick summary at the end.
chunk_size: DEPRECATED! Use `batch_size` instead.
Returns:
Summary of the response from the REST API
Expand All @@ -202,8 +207,9 @@ async def log_async(
workspace=workspace,
tags=tags,
metadata=metadata,
chunk_size=chunk_size,
batch_size=batch_size,
verbose=verbose,
chunk_size=chunk_size,
)


Expand Down
40 changes: 26 additions & 14 deletions src/argilla/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ class Argilla:
"""

# Larger sizes will trigger a warning
_MAX_CHUNK_SIZE = 5000
_MAX_BATCH_SIZE = 5000

def __init__(
self,
Expand Down Expand Up @@ -267,9 +267,10 @@ def log(
workspace: Optional[str] = None,
tags: Optional[Dict[str, str]] = None,
metadata: Optional[Dict[str, Any]] = None,
chunk_size: int = 500,
batch_size: int = 500,
verbose: bool = True,
background: bool = False,
chunk_size: Optional[int] = None,
) -> Union[BulkResponse, Future]:
"""Logs Records to argilla.
Expand All @@ -280,11 +281,12 @@ def log(
name: The dataset name.
tags: A dictionary of tags related to the dataset.
metadata: A dictionary of extra info for the dataset.
chunk_size: The chunk size for a data bulk.
batch_size: The batch size for a data bulk.
verbose: If True, shows a progress bar and prints out a quick summary at the end.
background: If True, we will NOT wait for the logging process to finish and return
an ``asyncio.Future`` object. You probably want to set ``verbose`` to False
in that case.
chunk_size: DEPRECATED! Use `batch_size` instead.
Returns:
Summary of the response from the REST API.
Expand All @@ -300,8 +302,9 @@ def log(
name=name,
tags=tags,
metadata=metadata,
chunk_size=chunk_size,
batch_size=batch_size,
verbose=verbose,
chunk_size=chunk_size,
)
if background:
return future
Expand All @@ -318,8 +321,9 @@ async def log_async(
workspace: Optional[str] = None,
tags: Optional[Dict[str, str]] = None,
metadata: Optional[Dict[str, Any]] = None,
chunk_size: int = 500,
batch_size: int = 500,
verbose: bool = True,
chunk_size: Optional[int] = None,
) -> BulkResponse:
"""Logs Records to argilla with asyncio.
Expand All @@ -328,8 +332,9 @@ async def log_async(
name: The dataset name.
tags: A dictionary of tags related to the dataset.
metadata: A dictionary of extra info for the dataset.
chunk_size: The chunk size for a data bulk.
batch_size: The batch size for a data bulk.
verbose: If True, shows a progress bar and prints out a quick summary at the end.
chunk_size: DEPRECATED! Use `batch_size` instead.
Returns:
Summary of the response from the REST API
Expand All @@ -353,11 +358,18 @@ async def log_async(
" https://www.elastic.co/guide/en/elasticsearch/reference/current/indices-create-index.html"
)

if chunk_size > self._MAX_CHUNK_SIZE:
if chunk_size is not None:
warnings.warn(
"The argument `chunk_size` is deprecated and will be removed in a future"
" version. Please use `batch_size` instead.",
FutureWarning,
)
batch_size = chunk_size

if batch_size > self._MAX_BATCH_SIZE:
_LOGGER.warning(
"""The introduced chunk size is noticeably large, timeout errors may occur.
Consider a chunk size smaller than %s""",
self._MAX_CHUNK_SIZE,
"The requested batch size is noticeably large, timeout errors may occur. "
f"Consider a batch size smaller than {self._MAX_BATCH_SIZE}",
)

if isinstance(records, Record.__args__):
Expand Down Expand Up @@ -385,23 +397,23 @@ async def log_async(
with Progress() as progress_bar:
task = progress_bar.add_task("Logging...", total=len(records), visible=verbose)

for i in range(0, len(records), chunk_size):
chunk = records[i : i + chunk_size]
for i in range(0, len(records), batch_size):
batch = records[i : i + batch_size]

response = await async_bulk(
client=self._client,
name=name,
json_body=bulk_class(
tags=tags,
metadata=metadata,
records=[creation_class.from_client(r) for r in chunk],
records=[creation_class.from_client(r) for r in batch],
),
)

processed += response.parsed.processed
failed += response.parsed.failed

progress_bar.update(task, advance=len(chunk))
progress_bar.update(task, advance=len(batch))

# TODO: improve logging policy in library
if verbose:
Expand Down
17 changes: 17 additions & 0 deletions tests/client/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,23 @@ def test_log_passing_empty_records_list(mocked_client):
api.log(records=[], name="ds")


def test_log_deprecated_chunk_size(mocked_client):
    """Passing the deprecated `chunk_size` to ``api.log`` must raise a FutureWarning
    that points users at `batch_size` instead."""
    name = "test_log_deprecated_chunk_size"
    # Start from a clean slate so the log call creates the dataset fresh.
    mocked_client.delete(f"/api/datasets/{name}")
    with pytest.warns(FutureWarning, match="`chunk_size`.*`batch_size`"):
        api.log(
            records=[rg.TextClassificationRecord(text="My text")],
            name=name,
            chunk_size=100,
        )


def test_large_batch_size_warning(mocked_client, caplog: pytest.LogCaptureFixture):
    """Logging with a `batch_size` above the client maximum should emit exactly one
    "noticeably large" warning via the logger (not a Python warning)."""
    name = "test_large_batch_size_warning"
    # Remove any leftover dataset from a previous run.
    mocked_client.delete(f"/api/datasets/{name}")
    api.log(
        records=[rg.TextClassificationRecord(text="My text")],
        name=name,
        batch_size=10000,
    )
    # Exactly one log record, and its message is the oversized-batch warning.
    assert len(caplog.record_tuples) == 1
    _, _, message = caplog.record_tuples[0]
    assert "batch size is noticeably large" in message


def test_log_background(mocked_client):
"""Verify that logs can be delayed via the background parameter."""
dataset_name = "test_log_background"
Expand Down

0 comments on commit 3ebea76

Please sign in to comment.