Skip to content

Commit

Permalink
Merge branch 'develop' into feature/search-engine-delete-suggestion
Browse files Browse the repository at this point in the history
  • Loading branch information
jfcalvo committed Nov 28, 2023
2 parents e81900f + 968131c commit 8721334
Show file tree
Hide file tree
Showing 15 changed files with 130 additions and 19 deletions.
2 changes: 1 addition & 1 deletion .github/ISSUE_TEMPLATE/🆕-feature-request.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
name: "\U0001F195 Feature request"
about: Cool new ideas for the project
title: "[FEATURE]"
labels: enhancement
labels: ''
assignees: ''

---
Expand Down
3 changes: 2 additions & 1 deletion .github/ISSUE_TEMPLATE/🐞-bug-ui-ux.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@
name: "\U0001F41E Bug report: UI/UX "
about: UI or UX bugs and unexpected behavior
title: "[BUG-UI/UX]"
labels: "bug \U0001FAB2, ui, ux"
labels: ''
assignees: ''

---

<!-- Ask David for help contributing: https://calendly.com/argilla-office-hours/30min, or feel free to submit a pull request straight away: https://github.com/argilla-io/argilla/pulls -->
Expand Down
3 changes: 2 additions & 1 deletion .github/ISSUE_TEMPLATE/📚-add-a-documentation-report.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@
name: "\U0001F4DA Add a documentation report"
about: Have you spotted a typo or mistake in our docs?
title: "[DOCS]"
labels: documentation
labels: ''
assignees: ''

---

<!-- Ask David for help contributing: https://calendly.com/argilla-office-hours/30min, or feel free to submit a pull request straight away: https://github.com/argilla-io/argilla/pulls -->
Expand Down
3 changes: 2 additions & 1 deletion .github/ISSUE_TEMPLATE/🪲-bug-python-deployment.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@
name: "\U0001FAB2 Bug report: Python/Deployment "
about: Python or Deployment bugs and unexpected behavior
title: "[BUG-python/deployment]"
labels: "api, bug \U0001FAB2"
labels: ''
assignees: ''

---

<!-- Ask David for help contributing: https://calendly.com/argilla-office-hours/30min, or feel free to submit a pull request straight away: https://github.com/argilla-io/argilla/pulls -->
Expand Down
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ These are the section headers that we use:
- Fixed error in `ArgillaTrainer`, now we can train for `extractive_question_answering` using a validation sample ([#4204](https://github.com/argilla-io/argilla/pull/4204))
- Fixed error in `ArgillaTrainer`, when training for `sentence-similarity` it didn't work with a list of values per record ([#4211](https://github.com/argilla-io/argilla/pull/4211))
- Fixed error in the unification strategy for `RankingQuestion` ([#4295](https://github.com/argilla-io/argilla/pull/4295))
- Fixed `TextClassificationSettings.label_schema` order not being preserved. Closes [#3828](https://github.com/argilla-io/argilla/issues/3828) ([#4332](https://github.com/argilla-io/argilla/pull/4332))
- Fixed error when requesting non-existing API endpoints. Closes [#4073](https://github.com/argilla-io/argilla/issues/4073) ([#4325](https://github.com/argilla-io/argilla/pull/4325))

### Changed
Expand Down
13 changes: 10 additions & 3 deletions src/argilla/client/apis/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,19 +57,26 @@ class LabelsSchemaSettings(_AbstractSettings):
"""

label_schema: Set[str]
label_schema: List[str]

def __post_init__(self):
if not isinstance(self.label_schema, (set, list, tuple)):
raise ValueError(
f"`label_schema` is of type={type(self.label_schema)}, but type=set is preferred, and also both type=list and type=tuple are allowed."
)
self.label_schema = set([str(label) for label in self.label_schema])
self.label_schema = self._get_unique_labels()

def _get_unique_labels(self) -> List[str]:
unique_labels = []
for label in self.label_schema:
if label not in unique_labels:
unique_labels.append(label)
return unique_labels

@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "LabelsSchemaSettings":
label_schema = data.get("label_schema", {})
labels = {label["name"] for label in label_schema.get("labels", [])}
labels = label_schema.get("labels", [])
return cls(label_schema=labels)

@property
Expand Down
4 changes: 2 additions & 2 deletions src/argilla/client/feedback/utils/assignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ def assign_workspaces(
except:
pass

wk_assignments[workspace_name] = [User.from_id(user).username for user in workspace.users]
wk_assignments[workspace_name] = [User.from_id(user.id).username for user in workspace.users]

continue

Expand All @@ -257,6 +257,6 @@ def assign_workspaces(
except:
pass

wk_assignments[workspace_name] = [User.from_id(user).username for user in workspace.users]
wk_assignments[workspace_name] = [User.from_id(user.id).username for user in workspace.users]

return wk_assignments
3 changes: 2 additions & 1 deletion src/argilla/server/apis/v1/handlers/records.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ async def get_record_suggestions(
async def upsert_suggestion(
*,
db: AsyncSession = Depends(get_async_db),
search_engine: SearchEngine = Depends(get_search_engine),
record_id: UUID,
suggestion_create: SuggestionCreate,
current_user: User = Security(auth.get_current_user),
Expand All @@ -161,7 +162,7 @@ async def upsert_suggestion(
# TODO: We should split API v1 into different FastAPI apps so we can customize error management.
# After mapping ValueError to 422 errors for API v1 then we can remove this try except.
try:
return await datasets.upsert_suggestion(db, record, question, suggestion_create)
return await datasets.upsert_suggestion(db, search_engine, record, question, suggestion_create)
except ValueError as err:
raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail=str(err))

Expand Down
35 changes: 29 additions & 6 deletions src/argilla/server/contexts/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -1092,16 +1092,39 @@ async def get_suggestion_by_record_id_and_question_id(
return result.scalar_one_or_none()


async def _preload_suggestion_relationships_before_index(db: "AsyncSession", suggestion: Suggestion) -> None:
    """Execute a query that eagerly loads the relationships of ``suggestion``.

    The query selects the ``Suggestion`` whose id equals ``suggestion.id`` with
    ``selectinload`` options for ``Suggestion.record`` (chained to
    ``Record.dataset``) and for ``Suggestion.question``. The query result
    itself is discarded.
    """
    # NOTE(review): presumably the session maps the loaded row onto the same
    # in-memory `suggestion`, leaving these relationships populated - confirm.
    eager_options = (
        selectinload(Suggestion.record).selectinload(Record.dataset),
        selectinload(Suggestion.question),
    )
    statement = select(Suggestion).filter_by(id=suggestion.id).options(*eager_options)
    await db.execute(statement)


async def upsert_suggestion(
db: "AsyncSession", record: Record, question: Question, suggestion_create: "SuggestionCreate"
db: "AsyncSession",
search_engine: SearchEngine,
record: Record,
question: Question,
suggestion_create: "SuggestionCreate",
) -> Suggestion:
question.parsed_settings.check_response(suggestion_create)

return await Suggestion.upsert(
db,
schema=SuggestionCreateWithRecordId(record_id=record.id, **suggestion_create.dict()),
constraints=[Suggestion.record_id, Suggestion.question_id],
)
async with db.begin_nested():
suggestion = await Suggestion.upsert(
db,
schema=SuggestionCreateWithRecordId(record_id=record.id, **suggestion_create.dict()),
constraints=[Suggestion.record_id, Suggestion.question_id],
autocommit=False,
)
await _preload_suggestion_relationships_before_index(db, suggestion)
await search_engine.update_record_suggestion(suggestion)

await db.commit()

return suggestion


async def delete_suggestions(
Expand Down
4 changes: 4 additions & 0 deletions src/argilla/server/search_engine/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,10 @@ async def update_record_response(self, response: Response):
async def delete_record_response(self, response: Response):
pass

@abstractmethod
async def update_record_suggestion(self, suggestion: Suggestion):
    """Abstract hook called with a ``Suggestion``; concrete search engines must override it."""

@abstractmethod
async def delete_record_suggestion(self, suggestion: Suggestion):
    """Abstract hook receiving a ``Suggestion``; concrete search engines must override it."""
Expand Down
11 changes: 11 additions & 0 deletions src/argilla/server/search_engine/commons.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,17 @@ async def delete_record_response(self, response: Response):
index_name, id=record.id, body={"script": f'ctx._source["responses"].remove("{response.user.username}")'}
)

async def update_record_suggestion(self, suggestion: Suggestion):
    """Send a partial document update carrying ``suggestion`` for its record.

    The target index is resolved with ``_get_index_or_raise`` from the dataset
    of the suggestion's record. The suggestion is converted with
    ``_map_record_suggestions_to_es`` and sent through
    ``_update_document_request`` under the ``"suggestions"`` key of the
    ``"doc"`` body, addressed by ``suggestion.record_id``.
    """
    # Resolve the index first, before any payload is built.
    target_index = await self._get_index_or_raise(suggestion.record.dataset)

    partial_doc = {"suggestions": self._map_record_suggestions_to_es([suggestion])}

    await self._update_document_request(
        target_index,
        id=suggestion.record_id,
        body={"doc": partial_doc},
    )

async def delete_record_suggestion(self, suggestion: Suggestion):
index_name = await self._get_index_or_raise(suggestion.record.dataset)

Expand Down
13 changes: 13 additions & 0 deletions tests/unit/client/apis/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Copyright 2021-present, the Recognai S.L. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
39 changes: 39 additions & 0 deletions tests/unit/client/apis/test_datasets.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# Copyright 2021-present, the Recognai S.L. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from argilla.client.apis.datasets import TextClassificationSettings


def test_text_classification_settings_preserve_labels_order() -> None:
    """Duplicated labels collapse to a single entry and the given order is kept."""
    expected_labels = [
        "1 (extremely positive/supportive)",
        "2 (positive/supportive)",
        "3 (neutral)",
        "4 (hateful/unsupportive)",
        "5 (extremely hateful/unsupportive)",
        "6 (can't say!)",
    ]
    # Append the last label two more times so the input contains duplicates.
    labels_with_duplicates = expected_labels + [expected_labels[-1]] * 2

    settings = TextClassificationSettings(label_schema=labels_with_duplicates)

    assert settings.label_schema == expected_labels
11 changes: 9 additions & 2 deletions tests/unit/client/feedback/utils/test_assignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,16 @@ def _factory(*args, **kwargs):
mock = Mock(spec=Workspace)
mock.users = []

def create_mock_user(user_id):
    """Build a ``Mock`` standing in for a user, exposing ``user_id`` as its ``id``."""
    # Keyword arguments passed to Mock are set as attributes on the new mock.
    return Mock(id=user_id)

def add_user(user_id):
if user_id not in mock.users:
mock.users.append(user_id)
# Check if a user with this ID already exists in the list
if not any(user.id == user_id for user in mock.users):
mock_user = create_mock_user(user_id)
mock.users.append(mock_user)

mock.add_user.side_effect = add_user

Expand Down
4 changes: 3 additions & 1 deletion tests/unit/server/api/v1/test_records.py
Original file line number Diff line number Diff line change
Expand Up @@ -1215,7 +1215,7 @@ async def test_create_record_suggestion(
assert (await db.execute(select(func.count(Suggestion.id)))).scalar() == 1

async def test_create_record_suggestion_update(
self, async_client: "AsyncClient", db: "AsyncSession", owner_auth_header: dict
self, async_client: "AsyncClient", db: "AsyncSession", mock_search_engine: SearchEngine, owner_auth_header: dict
):
dataset = await DatasetFactory.create()
question = await TextQuestionFactory.create(dataset=dataset)
Expand All @@ -1240,6 +1240,8 @@ async def test_create_record_suggestion_update(
}
assert (await db.execute(select(func.count(Suggestion.id)))).scalar() == 1

mock_search_engine.update_record_suggestion.assert_called_once_with(suggestion)

@pytest.mark.parametrize(
"payload",
[
Expand Down

0 comments on commit 8721334

Please sign in to comment.