
[ENHANCEMENT] [REFACTOR] optimise and refactor SDK ingestion methods #5107

Merged
merged 37 commits on Jul 4, 2024

Commits
10965d3
test: update tests for refactored mapping method
burtenshaw Jun 24, 2024
a416a2f
refactor: introduce independent mapping method and move logic to befo…
burtenshaw Jun 24, 2024
35db9f6
docs: update all doc strings in dataset records
burtenshaw Jun 25, 2024
eae088b
chore: improve typing and docs on type
burtenshaw Jun 25, 2024
4490d11
docs: wrong method in records api reference
burtenshaw Jun 25, 2024
b5b3396
feat: add exception for record ingestion
burtenshaw Jun 26, 2024
ffeb0b0
refactor: improve explainabilitity and readability in ingestion code …
burtenshaw Jun 26, 2024
594283e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 26, 2024
16f14d1
enhancement: move mapping out of record loop
burtenshaw Jun 26, 2024
db08548
[REFACTOR] Avoid autofetch when accessing settings (#5112)
frascuchon Jun 27, 2024
5f06e20
Merge branch 'spike/mapping-to-tuple' of https://github.com/argilla-i…
burtenshaw Jun 27, 2024
05df51a
enhancement: use just one progress bar
burtenshaw Jun 27, 2024
863dde2
chore: update typing of mapping
burtenshaw Jun 27, 2024
bf9e864
fix: move render mapping into infer record method
burtenshaw Jun 27, 2024
07aa249
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 27, 2024
8a6d484
fix: align add records parameters with render function
burtenshaw Jun 27, 2024
0b623fd
feat: implement ingestion mapping as class
burtenshaw Jul 2, 2024
14faccf
feat: use ingestion mapping class in dataset records not dataset records
burtenshaw Jul 2, 2024
e2bfc88
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 2, 2024
8889c0a
chore: tidy imports
burtenshaw Jul 2, 2024
7ad5075
Merge branch 'spike/mapping-to-tuple' of https://github.com/argilla-i…
burtenshaw Jul 2, 2024
63e0f7b
docs: update mapping parameters in how to guides
burtenshaw Jul 2, 2024
ecbdd4e
test: broaden suggestion mapping in test
burtenshaw Jul 2, 2024
99235b2
feat: extract dot notation with regex not string splitting
burtenshaw Jul 3, 2024
3ca8932
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 3, 2024
3eb8c0d
docs: typo in docs
burtenshaw Jul 3, 2024
b10fbe8
feat: improve record switch in ingest method
burtenshaw Jul 3, 2024
db27e1b
feat: refactor id mapping away from dict to type
burtenshaw Jul 3, 2024
716dfa8
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 3, 2024
6a7e9eb
feat: use class methods for type and parameter values
burtenshaw Jul 3, 2024
ca1a394
Merge branch 'spike/mapping-to-tuple' of https://github.com/argilla-i…
burtenshaw Jul 3, 2024
40ff5b6
[REFACTOR] generate default mapping and extends it with custom mappin…
frascuchon Jul 4, 2024
667ad54
refactor: migrate mapper into module from file
burtenshaw Jul 4, 2024
94988eb
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 4, 2024
2769793
fix: raise error not warn when mapped attribute is unknown
burtenshaw Jul 4, 2024
d83cf7a
Merge branch 'spike/mapping-to-tuple' of https://github.com/argilla-i…
burtenshaw Jul 4, 2024
0724166
Merge branch 'develop' into spike/mapping-to-tuple
frascuchon Jul 4, 2024
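Commit 99235b2 replaces string splitting with a regex when extracting dot-notation routes such as `my_label.suggestion.score`. A minimal, self-contained sketch of that idea — the pattern, function name, and return shape are illustrative assumptions, not the PR's exact code:

```python
import re

# Parse a mapped key like "my_label.suggestion.score" into up to three parts:
# the attribute (question/field name), the type ("suggestion"/"response"),
# and a sub-attribute ("score"/"agent"). Keys with more than three dot-separated
# parts fail to match and are rejected.
ROUTE_PATTERN = re.compile(r"^(?P<attribute>[^.]+)(?:\.(?P<type>[^.]+))?(?:\.(?P<sub>[^.]+))?$")


def parse_route(key: str):
    match = ROUTE_PATTERN.match(key)
    if match is None:
        raise ValueError(f"Invalid mapping key: {key!r}")
    return match.group("attribute"), match.group("type"), match.group("sub")
```

Compared with `key.split(".")`, the anchored regex rejects malformed keys outright instead of silently producing extra parts.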
35 changes: 21 additions & 14 deletions argilla/docs/how_to_guides/record.md
@@ -318,24 +318,31 @@ Suggestions refer to suggested responses (e.g. model predictions) that you can a
You can add suggestions as a dictionary, where the keys correspond to the `name`s of the labels that were configured for your dataset. Remember that you can also use the `mapping` parameter to specify the data structure.

```diff
 # Add records to the dataset with the label 'my_label'
 data = [
     {
         "question": "Do you need oxygen to breathe?",
         "answer": "Yes",
-        "my_label.suggestion": "positive",
-        "my_label.suggestion.score": 0.9,
-        "my_label.suggestion.agent": "model_name"
+        "label": "positive",
+        "score": 0.9,
+        "agent": "model_name",
     },
     {
         "question": "What is the boiling point of water?",
         "answer": "100 degrees Celsius",
-        "my_label.suggestion": "negative",
-        "my_label.suggestion.score": 0.9,
-        "my_label.suggestion.agent": "model_name"
+        "label": "negative",
+        "score": 0.9,
+        "agent": "model_name",
     },
 ]
-dataset.records.log(data)
+dataset.records.log(
+    data=data,
+    mapping={
+        "label": "my_label",
+        "score": "my_label.suggestion.score",
+        "agent": "my_label.suggestion.agent",
+    },
+)
```
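Conceptually, the `mapping` argument renames incoming keys to dataset attribute routes before ingestion. A self-contained sketch of that renaming step using plain dictionaries — this is an illustration of the idea, not Argilla's actual implementation:

```python
# Rename each row's source keys to their mapped destination routes;
# keys absent from the mapping pass through unchanged.
def apply_mapping(rows, mapping):
    return [{mapping.get(key, key): value for key, value in row.items()} for row in rows]


rows = [{"question": "Do you need oxygen to breathe?", "label": "positive", "score": 0.9}]
mapped = apply_mapping(rows, {"label": "my_label", "score": "my_label.suggestion.score"})
```

After this step, the routed keys can be matched against the dataset schema as fields, suggestions, or responses.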

### Responses
@@ -385,15 +392,15 @@ If your dataset includes some annotations, you can add those to the records as y

```diff
     {
         "question": "Do you need oxygen to breathe?",
         "answer": "Yes",
-        "my_label.response": "positive",
+        "label": "positive",
     },
     {
         "question": "What is the boiling point of water?",
         "answer": "100 degrees Celsius",
-        "my_label.response": "negative",
+        "label": "negative",
     },
 ]
-dataset.records.log(data, user_id=user.id)
+dataset.records.log(data, user_id=user.id, mapping={"label": "my_label.response"})
```

## List records
@@ -415,7 +422,7 @@ for record in dataset.records(

```diff
 # Access the responses of the record
 for response in record.responses:
-    print(record.["<question_name>"].value)
+    print(record["<question_name>"].value)
```

## Update records
Expand Down Expand Up @@ -460,8 +467,8 @@ dataset.records.log(records=updated_data)

for record in dataset.records():

record.vectors["new_vector"] = [...]
record.vector["v"] = [...]
record.vectors["new_vector"] = [ 0, 1, 2, 3, 4, 5 ]
record.vector["v"] = [ 0.1, 0.2, 0.3 ]

updated_records.append(record)

Expand Down
2 changes: 1 addition & 1 deletion argilla/docs/reference/argilla/records/records.md
@@ -12,7 +12,7 @@ The `Record` object is used to represent a single record in Argilla. It contains
To create records, you can use the `Record` class and pass it to the `Dataset.records.log` method. The `Record` class requires a `fields` parameter, which is a dictionary of field names and values. The field names must match the field names in the dataset's `Settings` object to be accepted.

```diff
-dataset.records.add(
+dataset.records.log(
     records=[
         rg.Record(
             fields={"text": "Hello World, how are you?"},
```
1 change: 1 addition & 0 deletions argilla/src/argilla/_exceptions/__init__.py
@@ -16,3 +16,4 @@

```diff
 from argilla._exceptions._metadata import * # noqa: F403
 from argilla._exceptions._serialization import * # noqa: F403
 from argilla._exceptions._settings import * # noqa: F403
+from argilla._exceptions._records import * # noqa: F403
```
19 changes: 19 additions & 0 deletions argilla/src/argilla/_exceptions/_records.py
@@ -0,0 +1,19 @@

```python
# Copyright 2024-present, Argilla, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from argilla._exceptions._base import ArgillaErrorBase


class RecordsIngestionError(ArgillaErrorBase):
    pass
```
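The new `RecordsIngestionError` lets callers catch ingestion failures specifically rather than trapping bare exceptions. A self-contained sketch of the intended error path — the class names come from the PR, but the mapper below is a stand-in, not Argilla's `IngestedRecordMapper`:

```python
class ArgillaErrorBase(Exception):
    pass


class RecordsIngestionError(ArgillaErrorBase):
    pass


def ingest(records, mapper):
    # Wrap any per-record failure in RecordsIngestionError, mirroring what
    # DatasetRecords._ingest_records does after this PR.
    ingested = []
    for record in records:
        try:
            ingested.append(mapper(record))
        except Exception as e:
            raise RecordsIngestionError(f"Failed to ingest record from dict {record}: {e}") from e
    return ingested


def strict_mapper(record):
    # Stand-in mapper: raises KeyError when the expected key is missing.
    return {"question": record["question"]}
```

Callers can then write `except RecordsIngestionError:` around `dataset.records.log(...)` and still see the offending record in the message.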
2 changes: 1 addition & 1 deletion argilla/src/argilla/client.py
@@ -271,7 +271,7 @@ def __call__(self, name: str, workspace: Optional[Union["Workspace", str]] = Non

```diff
     for dataset in workspace.datasets:
         if dataset.name == name:
-            return dataset
+            return dataset.get()
     warnings.warn(f"Dataset {name} not found. Creating a new dataset. Do `dataset.create()` to create the dataset.")
     return Dataset(name=name, workspace=workspace, client=self._client, **kwargs)
```

9 changes: 6 additions & 3 deletions argilla/src/argilla/datasets/_resource.py
@@ -101,8 +101,6 @@ def records(self) -> "DatasetRecords":

```diff
     @property
     def settings(self) -> Settings:
-        if self._is_published() and self._settings.is_outdated:
-            self._settings.get()
         return self._settings

     @settings.setter
```
@@ -142,6 +140,11 @@ def schema(self) -> dict:

```diff
     # Core methods #
     #####################

+    def get(self) -> "Dataset":
+        super().get()
+        self.settings.get()
+        return self
+
     def exists(self) -> bool:
         """Checks if the dataset exists on the server
```
@@ -185,7 +188,7 @@ def _publish(self) -> "Dataset":

```diff
         self._settings.create()
         self._api.publish(dataset_id=self._model.id)

-        return self.get()  # type: ignore
+        return self.get()

     def _workspace_id_from_name(self, workspace: Optional[Union["Workspace", str]]) -> UUID:
         if workspace is None:
```
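Together with #5112, this changes `dataset.settings` from an auto-fetching property into a plain accessor, with `Dataset.get()` now refreshing settings explicitly. A toy model of the before/after behavior — the classes below are illustrative stand-ins, not the SDK's:

```python
class Settings:
    def __init__(self):
        self.fetch_count = 0

    def get(self):
        self.fetch_count += 1  # stands in for a network round-trip
        return self


class Dataset:
    def __init__(self):
        self._settings = Settings()

    @property
    def settings(self):
        # After the refactor: no hidden fetch on attribute access.
        return self._settings

    def get(self):
        # Explicit refresh pulls settings exactly once.
        self.settings.get()
        return self


dataset = Dataset()
_ = dataset.settings  # no fetch triggered
dataset.get()         # one explicit fetch
```

The design choice moves the network cost to one well-known call site instead of scattering it across every property read.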
161 changes: 37 additions & 124 deletions argilla/src/argilla/records/_dataset_records.py
@@ -11,8 +11,7 @@

```diff
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import warnings
-from collections import defaultdict
+
 from pathlib import Path
 from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Sequence, Union
 from uuid import UUID
```
@@ -21,16 +20,13 @@

```diff
 from argilla._api import RecordsAPI
 from argilla._helpers import LoggingMixin
-from argilla._models import RecordModel, MetadataValue, VectorValue, FieldValue
+from argilla._models import RecordModel
+from argilla._exceptions import RecordsIngestionError
 from argilla.client import Argilla
 from argilla.records._io import GenericIO, HFDataset, HFDatasetsIO, JsonIO
+from argilla.records._mapping import IngestedRecordMapper
 from argilla.records._resource import Record
 from argilla.records._search import Query
-from argilla.responses import Response
-from argilla.settings import TextField, VectorField
-from argilla.settings._metadata import MetadataPropertyBase
-from argilla.settings._question import QuestionPropertyBase
-from argilla.suggestions import Suggestion

 if TYPE_CHECKING:
     from argilla.datasets import Dataset
```
@@ -188,8 +184,8 @@ def __call__(

```diff
         self._validate_vector_names(vector_names=with_vectors)

         return DatasetRecordsIterator(
-            self.__dataset,
-            self.__client,
+            dataset=self.__dataset,
+            client=self.__client,
             query=query,
             batch_size=batch_size,
             start_offset=start_offset,
```
@@ -208,7 +204,7 @@ def __repr__(self) -> str:

```diff
     def log(
         self,
         records: Union[List[dict], List[Record], HFDataset],
-        mapping: Optional[Dict[str, str]] = None,
+        mapping: Optional[Dict[str, Union[str, Sequence[str]]]] = None,
         user_id: Optional[UUID] = None,
         batch_size: int = DEFAULT_BATCH_SIZE,
     ) -> "DatasetRecords":
```
@@ -222,12 +218,12 @@ def log(

```diff
                 If records are defined as a dictionaries or a dataset, the keys/ column names should correspond to the
                 fields in the Argilla dataset's fields and questions. `id` should be provided to identify the records when updating.
             mapping: A dictionary that maps the keys/ column names in the records to the fields or questions in the Argilla dataset.
+                To assign an incoming key or column to multiple fields or questions, provide a list or tuple of field or question names.
             user_id: The user id to be associated with the records' response. If not provided, the current user id is used.
             batch_size: The number of records to send in each batch. The default is 256.

         Returns:
             A list of Record objects representing the updated records.
-
         """
         record_models = self._ingest_records(records=records, mapping=mapping, user_id=user_id or self.__client.me.id)
         batch_size = self._normalize_batch_size(
```
@@ -238,8 +234,8 @@

```diff
         created_or_updated = []
         records_updated = 0
+
         for batch in tqdm(
-            iterable=range(0, len(records), batch_size), desc="Adding and updating records", unit="batch"
+            iterable=range(0, len(records), batch_size),
+            desc="Sending records...",
+            total=len(records) // batch_size,
+            unit="batch",
         ):
             self._log_message(message=f"Sending records from {batch} to {batch + batch_size}.")
             batch_records = record_models[batch : batch + batch_size]
```
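The progress-bar change above keeps the same batching arithmetic: `range(0, len(records), batch_size)` paired with slicing. A self-contained sketch of that slicing pattern, without tqdm or the API call:

```python
def iter_batches(items, batch_size):
    # Yield successive slices of at most batch_size items, mirroring the
    # record_models[batch : batch + batch_size] slicing in log().
    for start in range(0, len(items), batch_size):
        yield items[start : start + batch_size]


batches = list(iter_batches(list(range(10)), batch_size=4))
# → [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]
```

Note that `total=len(records) // batch_size` in the diff rounds down, so a trailing partial batch is not counted in the bar's total; the slices themselves still cover every record.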
@@ -357,26 +357,36 @@ def to_datasets(self) -> HFDataset:

```diff
     def _ingest_records(
         self,
-        records: Union[List[Dict[str, Any]], Dict[str, Any], List[Record], Record, HFDataset],
-        mapping: Optional[Dict[str, str]] = None,
+        records: Union[List[Dict[str, Any]], List[Record], HFDataset],
+        mapping: Optional[Dict[str, Union[str, Sequence[str]]]] = None,
         user_id: Optional[UUID] = None,
     ) -> List[RecordModel]:
         """Ingests records from a list of dictionaries, a Hugging Face Dataset, or a list of Record objects."""
+
+        if len(records) == 0:
+            raise ValueError("No records provided to ingest.")
+
         if HFDatasetsIO._is_hf_dataset(dataset=records):
             records = HFDatasetsIO._record_dicts_from_datasets(dataset=records)
-        if all(map(lambda r: isinstance(r, dict), records)):
-            # Records as flat dicts of values to be matched to questions as suggestion or response
-            records = [self._infer_record_from_mapping(data=r, mapping=mapping, user_id=user_id) for r in records]  # type: ignore
-        elif all(map(lambda r: isinstance(r, Record), records)):
-            for record in records:
-                record.dataset = self.__dataset
-        else:
-            raise ValueError(
-                "Records should be a a list Record instances, "
-                "a Hugging Face Dataset, or a list of dictionaries representing the records."
-            )
-        return [record.api_model() for record in records]
+
+        ingested_records = []
+        record_mapper = IngestedRecordMapper(mapping=mapping, dataset=self.__dataset, user_id=user_id)
+        for record in records:
+            try:
+                if not isinstance(record, Record):
+                    record = record_mapper(data=record)
+                elif isinstance(record, Record):
+                    record.dataset = self.__dataset
+                else:
+                    raise ValueError(
+                        "Records should be a a list Record instances, "
+                        "a Hugging Face Dataset, or a list of dictionaries representing the records."
+                        f"Found a record of type {type(record)}: {record}."
+                    )
+            except Exception as e:
+                raise RecordsIngestionError(f"Failed to ingest record from dict {record}: {e}")
+            ingested_records.append(record.api_model())
+        return ingested_records

     def _normalize_batch_size(self, batch_size: int, records_length, max_value: int):
         norm_batch_size = min(batch_size, records_length, max_value)
```
@@ -397,100 +407,3 @@ def _validate_vector_names(self, vector_names: Union[List[str], str]) -> None:

```diff
                 continue
             if vector_name not in self.__dataset.schema:
                 raise ValueError(f"Vector field {vector_name} not found in dataset schema.")
-
-    def _infer_record_from_mapping(
-        self,
-        data: dict,
-        mapping: Optional[Dict[str, str]] = None,
-        user_id: Optional[UUID] = None,
-    ) -> "Record":
-        """Converts a mapped record dictionary to a Record object for use by the add or update methods.
-        Args:
-            dataset: The dataset object to which the record belongs.
-            data: A dictionary representing the record.
-            mapping: A dictionary mapping source data keys to Argilla fields, questions, and ids.
-            user_id: The user id to associate with the record responses.
-        Returns:
-            A Record object.
-        """
-        record_id: Optional[str] = None
-
-        fields: Dict[str, FieldValue] = {}
-        vectors: Dict[str, VectorValue] = {}
-        metadata: Dict[str, MetadataValue] = {}
-
-        responses: List[Response] = []
-        suggestion_values: Dict[str, dict] = defaultdict(dict)
-
-        schema = self.__dataset.schema
-
-        for attribute, value in data.items():
-            schema_item = schema.get(attribute)
-            attribute_type = None
-            sub_attribute = None
-
-            # Map source data keys using the mapping
-            if mapping and attribute in mapping:
-                attribute_mapping = mapping.get(attribute)
-                attribute_mapping = attribute_mapping.split(".")
-                attribute = attribute_mapping[0]
-                schema_item = schema.get(attribute)
-                if len(attribute_mapping) > 1:
-                    attribute_type = attribute_mapping[1]
-                if len(attribute_mapping) > 2:
-                    sub_attribute = attribute_mapping[2]
-            elif schema_item is mapping is None and attribute != "id":
-                warnings.warn(
-                    message=f"""Record attribute {attribute} is not in the schema so skipping.
-                    Define a mapping to map source data fields to Argilla Fields, Questions, and ids
-                    """
-                )
-                continue
-
-            if attribute == "id":
-                record_id = value
-                continue
-
-            # Add suggestion values to the suggestions
-            if attribute_type == "suggestion":
-                if sub_attribute in ["score", "agent"]:
-                    suggestion_values[attribute][sub_attribute] = value
-
-                elif sub_attribute is None:
-                    suggestion_values[attribute].update(
-                        {"value": value, "question_name": attribute, "question_id": schema_item.id}
-                    )
-                else:
-                    warnings.warn(
-                        message=f"Record attribute {sub_attribute} is not a valid suggestion sub_attribute so skipping."
-                    )
-                continue
-
-            # Assign the value to question, field, or response based on schema item
-            if isinstance(schema_item, TextField):
-                fields[attribute] = value
-            elif isinstance(schema_item, QuestionPropertyBase) and attribute_type == "response":
-                responses.append(Response(question_name=attribute, value=value, user_id=user_id))
-            elif isinstance(schema_item, QuestionPropertyBase) and attribute_type is None:
-                suggestion_values[attribute].update(
-                    {"value": value, "question_name": attribute, "question_id": schema_item.id}
-                )
-            elif isinstance(schema_item, VectorField):
-                vectors[attribute] = value
-            elif isinstance(schema_item, MetadataPropertyBase):
-                metadata[attribute] = value
-            else:
-                warnings.warn(message=f"Record attribute {attribute} is not in the schema or mapping so skipping.")
-                continue
-
-        suggestions = [Suggestion(**suggestion_dict) for suggestion_dict in suggestion_values.values()]
-
-        return Record(
-            id=record_id,
-            fields=fields,
-            vectors=vectors,
-            metadata=metadata,
-            suggestions=suggestions,
-            responses=responses,
-            _dataset=self.__dataset,
-        )
```
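With the widened `mapping: Optional[Dict[str, Union[str, Sequence[str]]]]` signature, one incoming column can now feed several fields or questions. A sketch of how such a mapping might be normalized into (source, destination) pairs — this is illustrative, not the `IngestedRecordMapper` internals:

```python
from typing import Dict, List, Sequence, Tuple, Union


def expand_mapping(mapping: Dict[str, Union[str, Sequence[str]]]) -> List[Tuple[str, str]]:
    # Normalize each entry to one (source, destination) pair per destination,
    # treating a bare string as a single-destination sequence.
    pairs = []
    for source, destinations in mapping.items():
        if isinstance(destinations, str):
            destinations = [destinations]
        pairs.extend((source, destination) for destination in destinations)
    return pairs


pairs = expand_mapping({"text": ("question", "context"), "label": "my_label"})
# → [("text", "question"), ("text", "context"), ("label", "my_label")]
```

Checking `isinstance(destinations, str)` before iterating matters here: strings are themselves sequences, so without it each character would become a destination.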