Skip to content

Commit

Permalink
Merge branch 'feat/add-basic-database-support' of https://github.com/…
Browse files Browse the repository at this point in the history
…argilla-io/argilla into feat/add-basic-database-support
  • Loading branch information
keithCuniah committed Mar 10, 2023
2 parents 6e30453 + 3ba05b8 commit 302dada
Show file tree
Hide file tree
Showing 19 changed files with 174 additions and 637 deletions.
19 changes: 19 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# Changelog

All notable changes to this project will be documented in this file.

The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [Unreleased]

### Added

- Add a `fields` parameter to select which fields to retrieve when loading data from Argilla. `rg.load` takes too long because of the vector field, even when users don't need it. Closes [#2398](https://github.com/argilla-io/argilla/issues/2398)


### Removed

- Removed some deprecated data-scan endpoints used by old clients. This change will break compatibility with clients `<v1.3.0`
- Stopped using the deprecated scan endpoints in the Python client. This change will break client compatibility with server versions `<1.3.0`

3 changes: 2 additions & 1 deletion docs/_source/guides/log_load_and_prepare_data.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,8 @@
" query=\"my AND query\",\n",
" limit=42\n",
" ids=[\"id1\", \"id2\", \"id3\"],\n",
" vectors=(\"vector1\", [0, 42, 1957]), \n",
" vectors=[\"vector1\", \"vector2\", \"vector3\"],\n",
" fields=[\"id\",\"inputs\",\"text\"]\n",
" sort=[(\"event_timestamp\", \"desc\")]\n",
")"
]
Expand Down
2 changes: 1 addition & 1 deletion docs/_source/reference/telemetry.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ We do not collect any piece of information related to the source data you store
## Information reported
The following usage and error information is reported:

* The code of the raised error
* The code of the raised error and the entity type related to the error, if any (Dataset, Workspace,...)
* The `user-agent` and `accept-language` http headers
* Task name and number of records for bulk operations
* An anonymous generated user uuid
Expand Down
1 change: 1 addition & 0 deletions pull_request_template.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,3 +31,4 @@ Closes #<issue_number>
- [ ] I made corresponding changes to the documentation
- [ ] My changes generate no new warnings
- [ ] I have added tests that prove my fix is effective or that my feature works
- [ ] I have added relevant notes to the CHANGELOG.md file (See https://keepachangelog.com/)
3 changes: 3 additions & 0 deletions src/argilla/_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,6 @@
_OLD_WORKSPACE_HEADER_NAME = "X-Rubrix-Workspace"

ES_INDEX_REGEX_PATTERN = r"^(?!-|_)[a-z0-9-_]+$"


DEFAULT_TELEMETRY_KEY = "C6FkcaoCbt78rACAgvyBxGBcMB3dM3nn"
4 changes: 4 additions & 0 deletions src/argilla/client/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,7 @@ def load(
id_from: Optional[str] = None,
batch_size: int = 250,
as_pandas=None,
fields: Optional[List[str]] = None,
) -> Dataset:
"""Loads an Argilla dataset.
Expand All @@ -244,6 +245,8 @@ def load(
size may help avoid timeouts.
as_pandas: DEPRECATED! To get a pandas DataFrame do
``rg.load('my_dataset').to_pandas()``.
fields: A list of fields to retrieve. If not provided, all fields will be retrieved.
``rg.load('my_dataset', fields=['id', 'text'])``
Returns:
An Argilla dataset.
Expand Down Expand Up @@ -277,6 +280,7 @@ def load(
id_from=id_from,
batch_size=batch_size,
as_pandas=as_pandas,
fields=fields,
)


Expand Down
141 changes: 25 additions & 116 deletions src/argilla/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -475,6 +475,7 @@ def load(
id_from: Optional[str] = None,
batch_size: int = 250,
as_pandas=None,
fields: Optional[List[str]] = None,
) -> Dataset:
"""Loads an Argilla dataset.
Expand All @@ -491,6 +492,8 @@ def load(
can be used to load using batches.
as_pandas: DEPRECATED! To get a pandas DataFrame do
``rg.load('my_dataset').to_pandas()``.
fields: If provided, only the given fields will be retrieved.
``rg.load('my_dataset', fields=['text'])``
Returns:
An Argilla dataset.
Expand All @@ -512,48 +515,17 @@ def load(
"If you want a pandas DataFrame do `rg.load('my_dataset').to_pandas()`.",
)

try:
return self._load_records_new_fashion(
name=name,
query=query,
vector=vector,
ids=ids,
limit=limit,
sort=sort,
id_from=id_from,
batch_size=batch_size,
)
except ApiCompatibilityError as err: # Api backward compatibility
from argilla import __version__ as version

warnings.warn(
message=(
f"Using python client argilla=={version},"
f" however deployed server version is {err.api_version}."
" This might lead to compatibility issues.\n"
f" Preferably, update your server version to {version}"
" or downgrade your Python API at the loss"
" of functionality and robustness via\n"
f"`pip install argilla=={err.api_version}`"
),
category=UserWarning,
)
if batch_size is not None:
warnings.warn(
message="The `batch_size` parameter is not supported"
f" on server version {err.api_version}. Consider"
f" updating your server version to {version} to"
" take advantage of this functionality."
)

return self._load_records_old_fashion(
name=name,
query=query,
ids=ids,
limit=limit,
sort=sort,
id_from=id_from,
)
return self._load_records_internal(
name=name,
query=query,
vector=vector,
ids=ids,
limit=limit,
sort=sort,
id_from=id_from,
fields=fields,
batch_size=batch_size,
)

def dataset_metrics(self, name: str) -> List[MetricInfo]:
response = datasets_api.get_dataset(self._client, name)
Expand Down Expand Up @@ -647,63 +619,7 @@ def rule_metrics_for_dataset(self, dataset: str, rule: LabelingRule) -> Labeling

return LabelingRuleMetricsSummary.parse_obj(response.parsed)

def _load_records_old_fashion(
self,
name: str,
query: Optional[str] = None,
ids: Optional[List[Union[str, int]]] = None,
limit: Optional[int] = None,
id_from: Optional[str] = None,
) -> Dataset:
from argilla.client.sdk.text2text import api as text2text_api
from argilla.client.sdk.text2text.models import Text2TextQuery
from argilla.client.sdk.text_classification import (
api as text_classification_api,
)
from argilla.client.sdk.token_classification import (
api as token_classification_api,
)

response = datasets_api.get_dataset(client=self._client, name=name)
task = response.parsed.task

task_config = {
TaskType.text_classification: (
text_classification_api.data,
TextClassificationQuery,
DatasetForTextClassification,
),
TaskType.token_classification: (
token_classification_api.data,
TokenClassificationQuery,
DatasetForTokenClassification,
),
TaskType.text2text: (
text2text_api.data,
Text2TextQuery,
DatasetForText2Text,
),
}

try:
get_dataset_data, request_class, dataset_class = task_config[task]
except KeyError:
raise ValueError(
f"Load method not supported for the '{task}' task. Supported tasks: "
f"{[TaskType.text_classification, TaskType.token_classification, TaskType.text2text]}"
)
response = get_dataset_data(
client=self._client,
name=name,
request=request_class(ids=ids, query_text=query),
limit=limit,
id_from=id_from,
)

records = [sdk_record.to_client() for sdk_record in response.parsed]
return dataset_class(records)

def _load_records_new_fashion(
def _load_records_internal(
self,
name: str,
query: Optional[str] = None,
Expand All @@ -712,24 +628,16 @@ def _load_records_new_fashion(
limit: Optional[int] = None,
sort: Optional[List[Tuple[str, str]]] = None,
id_from: Optional[str] = None,
fields: Optional[List[str]] = None,
batch_size: int = 250,
) -> Dataset:
dataset = self.datasets.find_by_name(name=name)
task = dataset.task

task_config = {
TaskType.text_classification: (
SdkTextClassificationRecord,
DatasetForTextClassification,
),
TaskType.token_classification: (
SdkTokenClassificationRecord,
DatasetForTokenClassification,
),
TaskType.text2text: (
SdkText2TextRecord,
DatasetForText2Text,
),
TaskType.text_classification: (SdkTextClassificationRecord, DatasetForTextClassification),
TaskType.token_classification: (SdkTokenClassificationRecord, DatasetForTokenClassification),
TaskType.text2text: (SdkText2TextRecord, DatasetForText2Text),
}

try:
Expand All @@ -744,10 +652,7 @@ def _load_records_new_fashion(
if sort is not None:
_LOGGER.warning("Results are sorted by vector similarity, so 'sort' parameter is ignored.")

vector_search = VectorSearch(
name=vector[0],
value=vector[1],
)
vector_search = VectorSearch(name=vector[0], value=vector[1])
results = self.search.search_records(
name=name,
task=task,
Expand All @@ -756,11 +661,15 @@ def _load_records_new_fashion(
query_text=query,
vector=vector_search,
)

return dataset_class(results.records)

if fields:
fields.extend(["id", "text", "tokens", "inputs"])

records = self.datasets.scan(
name=name,
projection={"*"},
projection=set(fields or "*"),
limit=limit,
sort=sort,
id_from=id_from,
Expand Down
45 changes: 0 additions & 45 deletions src/argilla/client/sdk/text2text/api.py

This file was deleted.

19 changes: 0 additions & 19 deletions src/argilla/client/sdk/text_classification/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,25 +36,6 @@
)


def data(
client: AuthenticatedClient,
name: str,
request: Optional[TextClassificationQuery] = None,
limit: Optional[int] = None,
id_from: Optional[str] = None,
) -> Response[Union[List[TextClassificationRecord], HTTPValidationError, ErrorMessage]]:
path = f"/api/datasets/{name}/TextClassification/data"
params = build_param_dict(id_from, limit)

with client.stream(
method="POST",
path=path,
params=params if params else None,
json=request.dict() if request else {},
) as response:
return build_data_response(response=response, data_type=TextClassificationRecord)


def add_dataset_labeling_rule(
client: AuthenticatedClient,
name: str,
Expand Down
49 changes: 0 additions & 49 deletions src/argilla/client/sdk/token_classification/api.py

This file was deleted.

Loading

0 comments on commit 302dada

Please sign in to comment.