Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: expose allow overlapping property #4697

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@ These are the section headers that we use:

## [Unreleased]()

### Added

- Added `allow_overlapping` parameter for span questions. ([#4697](https://github.com/argilla-io/argilla/pull/4697))

## [1.26.1](https://github.com/argilla-io/argilla/compare/v1.26.0...v1.26.1)

### Added
Expand Down
2 changes: 1 addition & 1 deletion environment_dev.yml
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,6 @@ dependencies:
- ipynbname>=2023.2.0.0
- httpx~=0.26.0
# For now we can just install argilla-server from the GitHub repo
- git+https://github.com/argilla-io/argilla-server.git
- git+https://github.com/argilla-io/argilla-server.git@feat/overlapped-span-questions
# install Argilla in editable mode
- -e .[listeners]
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,8 @@ dependencies = [
dynamic = ["version"]

[project.optional-dependencies]
server = ["argilla-server ~= 1.26.1"]
server-postgresql = ["argilla-server[postgresql] ~= 1.26.1"]
server = ["argilla-server ~= 1.27.0.dev0"]
server-postgresql = ["argilla-server[postgresql] ~= 1.27.0.dev0"]
listeners = ["schedule ~= 1.1.0", "prodict ~= 0.8.0"]
integrations = [
"PyYAML >= 5.4.1,< 6.1.0", # Required by `argilla.client.feedback.config` just used in `HuggingFaceDatasetMixin`
Expand Down
2 changes: 2 additions & 0 deletions src/argilla/client/feedback/schemas/questions.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,6 +358,7 @@ class SpanQuestion(QuestionSchema):
field: str = Field(..., description="The field in the input that the user will be asked to annotate.")
labels: Union[Dict[str, str], conlist(Union[str, SpanLabelOption], min_items=1, unique_items=True)]
visible_labels: Union[conint(ge=3), None] = _DEFAULT_MAX_VISIBLE_LABELS
allow_overlapping: bool = Field(False, description="Configure span to support overlap")

@validator("labels", pre=True)
def parse_labels_dict(cls, labels) -> List[SpanLabelOption]:
Expand Down Expand Up @@ -408,6 +409,7 @@ def server_settings(self) -> Dict[str, Any]:
"type": self.type,
"field": self.field,
"visible_options": self.visible_labels,
"allow_overlapping": self.allow_overlapping,
"options": [label.dict() for label in self.labels],
}

Expand Down
9 changes: 6 additions & 3 deletions src/argilla/client/feedback/schemas/remote/questions.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,7 @@ def to_local(self) -> SpanQuestion:
required=self.required,
labels=self.labels,
visible_labels=self.visible_labels,
allow_overlapping=self.allow_overlapping,
)

@classmethod
Expand All @@ -168,14 +169,16 @@ def _parse_options_from_api(cls, options: List[Dict[str, str]]) -> List[SpanLabe

@classmethod
def from_api(cls, payload: "FeedbackQuestionModel") -> "RemoteSpanQuestion":
question_settings = payload.settings
return RemoteSpanQuestion(
id=payload.id,
name=payload.name,
title=payload.title,
field=payload.settings["field"],
field=question_settings["field"],
required=payload.required,
visible_labels=payload.settings["visible_options"],
labels=cls._parse_options_from_api(payload.settings["options"]),
visible_labels=question_settings["visible_options"],
labels=cls._parse_options_from_api(question_settings["options"]),
allow_overlapping=question_settings["allow_overlapping"],
)


Expand Down
96 changes: 93 additions & 3 deletions tests/integration/client/feedback/dataset/local/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,9 +78,34 @@ def test_create_dataset_with_span_questions(argilla_user: "ServerUser") -> None:
rg_dataset = ds.push_to_argilla(name="new_dataset")

assert rg_dataset.id
assert rg_dataset.questions[0].name == "spans"
assert rg_dataset.questions[0].field == "text"
assert rg_dataset.questions[0].labels == [SpanLabelOption(value="label1"), SpanLabelOption(value="label2")]
question = rg_dataset.questions[0]
assert question.name == "spans"
assert question.field == "text"
assert question.labels == [SpanLabelOption(value="label1"), SpanLabelOption(value="label2")]
assert question.allow_overlapping is False


@pytest.mark.parametrize("allow_overlapping", [True, False])
def test_create_dataset_with_span_questions_allow_overlapping(
argilla_user: "ServerUser", allow_overlapping: bool
) -> None:
argilla.client.singleton.init(api_key=argilla_user.api_key)

ds = FeedbackDataset(
fields=[TextField(name="text")],
questions=[
SpanQuestion(name="spans", field="text", labels=["label1", "label2"], allow_overlapping=allow_overlapping)
],
)

rg_dataset = ds.push_to_argilla(name="new_dataset")

assert rg_dataset.id
question = rg_dataset.questions[0]
assert question.name == "spans"
assert question.field == "text"
assert question.labels == [SpanLabelOption(value="label1"), SpanLabelOption(value="label2")]
assert question.allow_overlapping is allow_overlapping


@pytest.mark.asyncio
Expand Down Expand Up @@ -277,6 +302,71 @@ def test_add_records_with_wrong_spans_suggestions(
)


def test_add_records_with_overlapped_spans(argilla_user: "ServerUser") -> None:
argilla.client.singleton.init(api_key=argilla_user.api_key)

dataset_cfg = FeedbackDataset(
fields=[TextField(name="text")],
questions=[SpanQuestion(name="spans", field="text", labels=["label1", "label2"], allow_overlapping=True)],
)

dataset = dataset_cfg.push_to_argilla(name="test-dataset")
question = dataset.question_by_name("spans")

dataset.add_records(
[
FeedbackRecord(
fields={"text": "this is a text"},
suggestions=[
question.suggestion(
value=[
SpanValueSchema(start=0, end=4, label="label1"),
SpanValueSchema(start=1, end=2, label="label2"),
]
)
],
)
]
)

assert len(dataset.records) == 1

record = dataset.records[0]
assert record.suggestions[0].value == [
SpanValueSchema(start=0, end=4, label="label1"),
SpanValueSchema(start=1, end=2, label="label2"),
]


def test_add_records_with_overlapped_spans_and_disabling_overlapping_span(argilla_user: "ServerUser") -> None:
argilla.client.singleton.init(api_key=argilla_user.api_key)

dataset_cfg = FeedbackDataset(
fields=[TextField(name="text")],
questions=[SpanQuestion(name="spans", field="text", labels=["label1", "label2"], allow_overlapping=False)],
)

dataset = dataset_cfg.push_to_argilla(name="test-dataset")
question = dataset.question_by_name("spans")

with pytest.raises(ValidationApiError, match="overlapping values found between spans at index idx=0 and idx=1"):
dataset.add_records(
[
FeedbackRecord(
fields={"text": "this is a text"},
suggestions=[
question.suggestion(
value=[
SpanValueSchema(start=0, end=4, label="label1"),
SpanValueSchema(start=1, end=2, label="label2"),
]
)
],
)
]
)


def test_add_records_with_vectors() -> None:
dataset = FeedbackDataset(
fields=[TextField(name="text", required=True)],
Expand Down
2 changes: 2 additions & 0 deletions tests/unit/client/feedback/schemas/remote/test_questions.py
Original file line number Diff line number Diff line change
Expand Up @@ -460,6 +460,7 @@ def test_span_questions_from_api():
"type": "span",
"field": "field",
"visible_options": None,
"allow_overlapping": False,
"options": [
{"text": "Span label a", "value": "a", "description": None},
{
Expand Down Expand Up @@ -490,6 +491,7 @@ def test_span_questions_from_api_with_visible_labels():
"type": "span",
"field": "field",
"visible_options": 3,
"allow_overlapping": False,
"options": [
{"text": "Span label a", "value": "a", "description": None},
{"text": "Span label b", "value": "b", "description": None},
Expand Down
4 changes: 4 additions & 0 deletions tests/unit/client/feedback/schemas/test_questions.py
Original file line number Diff line number Diff line change
Expand Up @@ -455,6 +455,7 @@ def test_span_question() -> None:
title="Question",
description="Description",
required=True,
allow_overlapping=True,
labels=["a", "b"],
)

Expand All @@ -463,6 +464,7 @@ def test_span_question() -> None:
"type": "span",
"field": "field",
"visible_options": None,
"allow_overlapping": True,
"options": [{"value": "a", "text": "a", "description": None}, {"value": "b", "text": "b", "description": None}],
}

Expand All @@ -481,6 +483,7 @@ def test_span_question_with_labels_dict() -> None:
"type": "span",
"field": "field",
"visible_options": None,
"allow_overlapping": False,
"options": [
{"value": "a", "text": "A text", "description": None},
{"value": "b", "text": "B text", "description": None},
Expand All @@ -503,6 +506,7 @@ def test_span_question_with_visible_labels() -> None:
"type": "span",
"field": "field",
"visible_options": 3,
"allow_overlapping": False,
"options": [
{"value": "a", "text": "a", "description": None},
{"value": "b", "text": "b", "description": None},
Expand Down
Loading