Skip to content

Commit

Permalink
fix: batch transform update for sagemaker reranker integration (#6145)
Browse files Browse the repository at this point in the history
Co-authored-by: Joan Martinez <[email protected]>
  • Loading branch information
zac-li and JoanFM authored Feb 29, 2024
1 parent f008ab5 commit 8a58dfa
Show file tree
Hide file tree
Showing 12 changed files with 447 additions and 247 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/cd.yml
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ jobs:
pytest --suppress-no-test-exit-code --force-flaky --min-passes 1 --max-runs 5 --cov=jina --cov-report=xml --timeout=600 -v -s --ignore-glob='tests/integration/hub_usage/dummyhub*' tests/integration/docarray_v2/test_singleton.py
pytest --suppress-no-test-exit-code --force-flaky --min-passes 1 --max-runs 5 --cov=jina --cov-report=xml --timeout=600 -v -s --ignore-glob='tests/integration/hub_usage/dummyhub*' tests/integration/docarray_v2/test_parameters_as_pydantic.py
pytest --suppress-no-test-exit-code --force-flaky --min-passes 1 --max-runs 5 --cov=jina --cov-report=xml --timeout=600 -v -s --ignore-glob='tests/integration/hub_usage/dummyhub*' tests/integration/docarray_v2/test_streaming.py
pytest --suppress-no-test-exit-code --force-flaky --min-passes 1 --max-runs 5 --cov=jina --cov-report=xml --timeout=600 -v -s --ignore-glob='tests/integration/hub_usage/dummyhub*' tests/integration/docarray_v2/sagemaker/test_sagemaker.py
pytest --suppress-no-test-exit-code --force-flaky --min-passes 1 --max-runs 5 --cov=jina --cov-report=xml --timeout=600 -v -s --ignore-glob='tests/integration/hub_usage/dummyhub*' tests/integration/docarray_v2/sagemaker
pytest --suppress-no-test-exit-code --force-flaky --min-passes 1 --max-runs 5 --cov=jina --cov-report=xml --timeout=600 -v -s --ignore-glob='tests/integration/hub_usage/dummyhub*' tests/integration/docarray_v2/docker
echo "flag it as jina for codeoverage"
echo "codecov_flag=jina" >> $GITHUB_OUTPUT
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -485,7 +485,7 @@ jobs:
pytest --suppress-no-test-exit-code --force-flaky --min-passes 1 --max-runs 5 --cov=jina --cov-report=xml --timeout=600 -v -s --ignore-glob='tests/integration/hub_usage/dummyhub*' tests/integration/docarray_v2/test_singleton.py
pytest --suppress-no-test-exit-code --force-flaky --min-passes 1 --max-runs 5 --cov=jina --cov-report=xml --timeout=600 -v -s --ignore-glob='tests/integration/hub_usage/dummyhub*' tests/integration/docarray_v2/test_parameters_as_pydantic.py
pytest --suppress-no-test-exit-code --force-flaky --min-passes 1 --max-runs 5 --cov=jina --cov-report=xml --timeout=600 -v -s --ignore-glob='tests/integration/hub_usage/dummyhub*' tests/integration/docarray_v2/test_streaming.py
pytest --suppress-no-test-exit-code --force-flaky --min-passes 1 --max-runs 5 --cov=jina --cov-report=xml --timeout=600 -v -s --ignore-glob='tests/integration/hub_usage/dummyhub*' tests/integration/docarray_v2/sagemaker/test_sagemaker.py
pytest --suppress-no-test-exit-code --force-flaky --min-passes 1 --max-runs 5 --cov=jina --cov-report=xml --timeout=600 -v -s --ignore-glob='tests/integration/hub_usage/dummyhub*' tests/integration/docarray_v2/sagemaker
pytest --suppress-no-test-exit-code --force-flaky --min-passes 1 --max-runs 5 --cov=jina --cov-report=xml --timeout=600 -v -s --ignore-glob='tests/integration/hub_usage/dummyhub*' tests/integration/docarray_v2/docker
echo "flag it as jina for codeoverage"
echo "codecov_flag=jina" >> $GITHUB_OUTPUT
Expand Down
5 changes: 4 additions & 1 deletion jina/serve/executors/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,10 @@ def is_pydantic_model(annotation: Type) -> bool:
:param annotation: The annotation from which to extract PydantiModel.
:return: boolean indicating if a Pydantic model is inside the annotation
"""
from typing import get_args, get_origin
try:
from typing import get_args, get_origin
except ImportError:
from typing_extensions import get_args, get_origin

from pydantic import BaseModel

Expand Down
51 changes: 50 additions & 1 deletion jina/serve/runtimes/worker/http_sagemaker_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,15 @@ def add_post_route(
input_doc_list_model=None,
output_doc_list_model=None,
):
import json
from typing import List, Type, Union
try:
from typing import get_args, get_origin
except ImportError:
from typing_extensions import get_args, get_origin

from docarray.base_doc.docarray_response import DocArrayResponse
from pydantic import BaseModel, ValidationError, parse_obj_as

app_kwargs = dict(
path=f'/{endpoint_path.strip("/")}',
Expand Down Expand Up @@ -145,6 +153,47 @@ async def post(request: Request):
detail='Invalid CSV input. Please check your input.',
)

def construct_model_from_line(
model: Type[BaseModel], line: List[str]
) -> BaseModel:
parsed_fields = {}
model_fields = model.__fields__

for field_str, (field_name, field_info) in zip(
line, model_fields.items()
):
field_type = field_info.outer_type_

# Handle Union types by attempting to arse each potential type
if get_origin(field_type) is Union:
for possible_type in get_args(field_type):
if possible_type is str:
parsed_fields[field_name] = field_str
break
else:
try:
parsed_fields[field_name] = parse_obj_as(
possible_type, json.loads(field_str)
)
break
except (json.JSONDecodeError, ValidationError):
continue
# Handle list of nested models
elif get_origin(field_type) is list:
list_item_type = get_args(field_type)[0]
parsed_list = json.loads(field_str)
if issubclass(list_item_type, BaseModel):
parsed_fields[field_name] = parse_obj_as(
List[list_item_type], parsed_list
)
else:
parsed_fields[field_name] = parsed_list
# Handle direct assignment for basic types
else:
parsed_fields[field_name] = field_info.type_(field_str)

return model(**parsed_fields)

# NOTE: Sagemaker only supports csv files without header, so we enforce
# the header by getting the field names from the input model.
# This will also enforce the order of the fields in the csv file.
Expand All @@ -165,7 +214,7 @@ async def post(request: Request):
detail=f'Invalid CSV format. Line {line} doesn\'t match '
f'the expected field order {field_names}.',
)
data.append(input_doc_list_model(**dict(zip(field_names, line))))
data.append(construct_model_from_line(input_doc_list_model, line))

return await process(input_model(data=data))

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# SampleRerankerExecutor

Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
jtype: SampleRerankerExecutor
py_modules:
- executor.py
metas:
name: SampleRerankerExecutor
description:
url:
keywords: []
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
from docarray import BaseDoc, DocList
from pydantic import Field
from typing import Union, Optional, List
from jina import Executor, requests


class TextDoc(BaseDoc):
text: str = Field(description="The text of the document", default="")


class RerankerInput(BaseDoc):
query: Union[str, TextDoc]

documents: List[TextDoc]

top_n: Optional[int]


class RankedObjectOutput(BaseDoc):
index: int
document: Optional[TextDoc]

relevance_score: float


class RankedOutput(BaseDoc):
results: DocList[RankedObjectOutput]


class SampleRerankerExecutor(Executor):
@requests(on="/rerank")
def foo(self, docs: DocList[RerankerInput], **kwargs) -> DocList[RankedOutput]:
ret = []
for doc in docs:
ret.append(
RankedOutput(
results=[
RankedObjectOutput(
id=doc.id,
index=0,
document=TextDoc(text="first result"),
relevance_score=-1,
),
RankedObjectOutput(
id=doc.id,
index=1,
document=TextDoc(text="second result"),
relevance_score=-2,
),
]
)
)
return DocList[RankedOutput](ret)
Empty file.
Loading

0 comments on commit 8a58dfa

Please sign in to comment.