Deprecate chunk_size in favor of batch_size for rg.log #2455

Merged 2 commits on Mar 6, 2023
Changes from 1 commit
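In short, this PR renames the `chunk_size` argument of `rg.log` and `rg.log_async` to `batch_size`, keeping `chunk_size` as a deprecated alias that raises a `FutureWarning`. A minimal before/after sketch of the call site (the dataset name is illustrative, and this assumes a running Argilla instance):

```python
import argilla as rg

record = rg.TextClassificationRecord(text="My text")

# New spelling: batch_size controls how many records are sent per bulk request.
rg.log(records=[record], name="my-dataset", batch_size=500)

# Deprecated spelling: still accepted for now, but emits a FutureWarning
# and is internally mapped onto batch_size.
rg.log(records=[record], name="my-dataset", chunk_size=500)
```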
src/argilla/client/api.py (18 changes: 12 additions & 6 deletions)
@@ -111,9 +111,10 @@ def log(
     name: str,
     tags: Optional[Dict[str, str]] = None,
     metadata: Optional[Dict[str, Any]] = None,
-    chunk_size: int = 500,
+    batch_size: int = 500,
     verbose: bool = True,
     background: bool = False,
+    chunk_size: Optional[int] = None,
 ) -> Union[BulkResponse, Future]:
     """Logs Records to argilla.

@@ -124,10 +125,11 @@ def log(
         name: The dataset name.
         tags: A dictionary of tags related to the dataset.
         metadata: A dictionary of extra info for the dataset.
-        chunk_size: The chunk size for a data bulk.
+        batch_size: The batch size for a data bulk.
         verbose: If True, shows a progress bar and prints out a quick summary at the end.
         background: If True, we will NOT wait for the logging process to finish and return an ``asyncio.Future``
             object. You probably want to set ``verbose`` to False in that case.
+        chunk_size: DEPRECATED! Use `batch_size` instead.

     Returns:
         Summary of the response from the REST API.
@@ -152,9 +154,10 @@ def log(
         name=name,
         tags=tags,
         metadata=metadata,
-        chunk_size=chunk_size,
+        batch_size=batch_size,
         verbose=verbose,
         background=background,
+        chunk_size=chunk_size,
     )


@@ -163,8 +166,9 @@ async def log_async(
     name: str,
     tags: Optional[Dict[str, str]] = None,
     metadata: Optional[Dict[str, Any]] = None,
-    chunk_size: int = 500,
+    batch_size: int = 500,
     verbose: bool = True,
+    chunk_size: Optional[int] = None,
 ) -> BulkResponse:
     """Logs Records to argilla with asyncio.

@@ -173,8 +177,9 @@ async def log_async(
         name: The dataset name.
         tags: A dictionary of tags related to the dataset.
         metadata: A dictionary of extra info for the dataset.
-        chunk_size: The chunk size for a data bulk.
+        batch_size: The batch size for a data bulk.
         verbose: If True, shows a progress bar and prints out a quick summary at the end.
+        chunk_size: DEPRECATED! Use `batch_size` instead.

     Returns:
         Summary of the response from the REST API
@@ -194,8 +199,9 @@ async def log_async(
         name=name,
         tags=tags,
         metadata=metadata,
-        chunk_size=chunk_size,
+        batch_size=batch_size,
         verbose=verbose,
+        chunk_size=chunk_size,
     )


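Note that the module-level `log` and `log_async` above do not warn themselves: they forward both `batch_size` and the deprecated `chunk_size` unchanged, so the warning is emitted in exactly one place, the client method. A toy sketch of that pass-through design (the `_Client` class and its return value are stand-ins for illustration, not Argilla's API):

```python
from typing import List, Optional


class _Client:
    def log(self, records: List[str], name: str, batch_size: int = 500,
            chunk_size: Optional[int] = None) -> int:
        # In the real client this is where the FutureWarning is raised;
        # here we just resolve which size wins.
        return chunk_size if chunk_size is not None else batch_size


_CLIENT = _Client()


def log(records: List[str], name: str, batch_size: int = 500,
        chunk_size: Optional[int] = None) -> int:
    # Forward the deprecated argument untouched so the deprecation logic
    # lives in a single place (the client method).
    return _CLIENT.log(records=records, name=name,
                       batch_size=batch_size, chunk_size=chunk_size)


assert log(["r1"], name="ds") == 500
assert log(["r1"], name="ds", chunk_size=100) == 100
```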
src/argilla/client/client.py (40 changes: 26 additions & 14 deletions)
@@ -115,7 +115,7 @@ class Argilla:
     """

     # Larger sizes will trigger a warning
-    _MAX_CHUNK_SIZE = 5000
+    _MAX_BATCH_SIZE = 5000

     def __init__(
         self,
@@ -256,9 +256,10 @@ def log(
         name: str,
         tags: Optional[Dict[str, str]] = None,
         metadata: Optional[Dict[str, Any]] = None,
-        chunk_size: int = 500,
+        batch_size: int = 500,
         verbose: bool = True,
         background: bool = False,
+        chunk_size: Optional[int] = None,
     ) -> Union[BulkResponse, Future]:
         """Logs Records to argilla.

@@ -269,11 +270,12 @@ def log(
             name: The dataset name.
             tags: A dictionary of tags related to the dataset.
             metadata: A dictionary of extra info for the dataset.
-            chunk_size: The chunk size for a data bulk.
+            batch_size: The batch size for a data bulk.
             verbose: If True, shows a progress bar and prints out a quick summary at the end.
             background: If True, we will NOT wait for the logging process to finish and return
                 an ``asyncio.Future`` object. You probably want to set ``verbose`` to False
                 in that case.
+            chunk_size: DEPRECATED! Use `batch_size` instead.

         Returns:
             Summary of the response from the REST API.
@@ -286,8 +288,9 @@ def log(
             name=name,
             tags=tags,
             metadata=metadata,
-            chunk_size=chunk_size,
+            batch_size=batch_size,
             verbose=verbose,
+            chunk_size=chunk_size,
         )
         if background:
             return future
@@ -303,8 +306,9 @@ async def log_async(
         name: str,
         tags: Optional[Dict[str, str]] = None,
         metadata: Optional[Dict[str, Any]] = None,
-        chunk_size: int = 500,
+        batch_size: int = 500,
         verbose: bool = True,
+        chunk_size: Optional[int] = None,
     ) -> BulkResponse:
         """Logs Records to argilla with asyncio.

@@ -313,8 +317,9 @@ async def log_async(
             name: The dataset name.
             tags: A dictionary of tags related to the dataset.
             metadata: A dictionary of extra info for the dataset.
-            chunk_size: The chunk size for a data bulk.
+            batch_size: The batch size for a data bulk.
             verbose: If True, shows a progress bar and prints out a quick summary at the end.
+            chunk_size: DEPRECATED! Use `batch_size` instead.

         Returns:
             Summary of the response from the REST API
@@ -335,11 +340,18 @@ async def log_async(
                 " https://www.elastic.co/guide/en/elasticsearch/reference/current/indices-create-index.html"
             )

-        if chunk_size > self._MAX_CHUNK_SIZE:
+        if chunk_size is not None:
+            warnings.warn(
+                "The argument `chunk_size` is deprecated and will be removed in a future"
+                " version. Please use `batch_size` instead.",
+                FutureWarning,
+            )
+            batch_size = chunk_size
+
+        if batch_size > self._MAX_BATCH_SIZE:
             _LOGGER.warning(
-                """The introduced chunk size is noticeably large, timeout errors may occur.
-                Consider a chunk size smaller than %s""",
-                self._MAX_CHUNK_SIZE,
+                "The requested batch size is noticeably large, timeout errors may occur. "
+                f"Consider a batch size smaller than {self._MAX_BATCH_SIZE}",
             )

         if isinstance(records, Record.__args__):
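This hunk is the core of the change, and it follows the standard keyword-deprecation idiom: the old argument's default becomes `None`, and any explicit value triggers a `FutureWarning` before being copied into the new argument. A self-contained sketch of the idiom, independent of Argilla:

```python
import warnings
from typing import Optional


def process(batch_size: int = 500, chunk_size: Optional[int] = None) -> int:
    """Toy function demonstrating the keyword-deprecation idiom above."""
    if chunk_size is not None:
        warnings.warn(
            "The argument `chunk_size` is deprecated and will be removed in a"
            " future version. Please use `batch_size` instead.",
            FutureWarning,
        )
        # The explicitly passed old value wins, so existing callers keep
        # their current behavior while they migrate.
        batch_size = chunk_size
    return batch_size


assert process() == 500
with warnings.catch_warnings():
    warnings.simplefilter("ignore", FutureWarning)
    assert process(chunk_size=100) == 100
```

One consequence of using `None` as the sentinel is that passing `chunk_size=None` explicitly is indistinguishable from omitting it; that is acceptable here because `None` was never a valid size.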
@@ -367,23 +379,23 @@ async def log_async(
         with Progress() as progress_bar:
             task = progress_bar.add_task("Logging...", total=len(records), visible=verbose)

-            for i in range(0, len(records), chunk_size):
-                chunk = records[i : i + chunk_size]
+            for i in range(0, len(records), batch_size):
+                batch = records[i : i + batch_size]

                 response = await async_bulk(
                     client=self._client,
                     name=name,
                     json_body=bulk_class(
                         tags=tags,
                         metadata=metadata,
-                        records=[creation_class.from_client(r) for r in chunk],
+                        records=[creation_class.from_client(r) for r in batch],
                     ),
                 )

                 processed += response.parsed.processed
                 failed += response.parsed.failed

-                progress_bar.update(task, advance=len(chunk))
+                progress_bar.update(task, advance=len(batch))

             # TODO: improve logging policy in library
         if verbose:
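The renamed loop is plain fixed-size slicing: a slice past the end of a Python sequence is simply truncated, so the last batch may hold fewer than `batch_size` items with no extra bounds handling. A minimal sketch of the same batching logic:

```python
from typing import Iterator, Sequence, TypeVar

T = TypeVar("T")


def iter_batches(records: Sequence[T], batch_size: int = 500) -> Iterator[Sequence[T]]:
    # Slicing never raises IndexError, so the final slice just contains
    # whatever remains (possibly fewer than batch_size items).
    for i in range(0, len(records), batch_size):
        yield records[i : i + batch_size]


assert [list(b) for b in iter_batches([1, 2, 3, 4, 5], batch_size=2)] == [[1, 2], [3, 4], [5]]
```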
tests/client/test_api.py (8 changes: 8 additions & 0 deletions)
@@ -228,6 +228,14 @@ def test_log_passing_empty_records_list(mocked_client):
     api.log(records=[], name="ds")


+def test_log_deprecated_chunk_size(mocked_client):
dataset_name = "test_log_deprecated_chunk_size"
mocked_client.delete(f"/api/datasets/{dataset_name}")
record = rg.TextClassificationRecord(text="My text")
with pytest.warns(FutureWarning, match="`chunk_size`.*`batch_size`"):
api.log(records=[record], name=dataset_name, chunk_size=100)


def test_log_background(mocked_client):
"""Verify that logs can be delayed via the background parameter."""
dataset_name = "test_log_background"