-
Notifications
You must be signed in to change notification settings - Fork 2.2k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: support provider endpoint in jina executor (#6149)
Co-authored-by: Joan Martinez <[email protected]> Co-authored-by: Jina Dev Bot <[email protected]>
- Loading branch information
1 parent
69bf8f6
commit 176f953
Showing
11 changed files
with
207 additions
and
6 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
2 changes: 2 additions & 0 deletions
2
tests/integration/docarray_v2/sagemaker/SampleColbertExecutor/README.md
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
# SampleColbertExecutor | ||
|
8 changes: 8 additions & 0 deletions
8
tests/integration/docarray_v2/sagemaker/SampleColbertExecutor/config.yml
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
jtype: SampleColbertExecutor | ||
py_modules: | ||
- executor.py | ||
metas: | ||
name: SampleColbertExecutor | ||
description: | ||
url: | ||
keywords: [] |
72 changes: 72 additions & 0 deletions
72
tests/integration/docarray_v2/sagemaker/SampleColbertExecutor/executor.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
import numpy as np | ||
from docarray import BaseDoc, DocList | ||
from docarray.typing import NdArray | ||
from pydantic import Field | ||
from typing import Union, Optional, List | ||
from jina import Executor, requests | ||
|
||
|
||
class TextDoc(BaseDoc): | ||
text: str = Field(description="The text of the document", default="") | ||
|
||
|
||
class RerankerInput(BaseDoc): | ||
query: Union[str, TextDoc] | ||
|
||
documents: List[TextDoc] | ||
|
||
top_n: Optional[int] | ||
|
||
|
||
class RankedObjectOutput(BaseDoc): | ||
index: int | ||
document: Optional[TextDoc] | ||
|
||
relevance_score: float | ||
|
||
|
||
class EmbeddingResponseModel(TextDoc): | ||
embeddings: NdArray | ||
|
||
|
||
class RankedOutput(BaseDoc): | ||
results: DocList[RankedObjectOutput] | ||
|
||
|
||
class SampleColbertExecutor(Executor): | ||
@requests(on="/rank") | ||
def foo(self, docs: DocList[RerankerInput], **kwargs) -> DocList[RankedOutput]: | ||
ret = [] | ||
for doc in docs: | ||
ret.append( | ||
RankedOutput( | ||
results=[ | ||
RankedObjectOutput( | ||
id=doc.id, | ||
index=0, | ||
document=TextDoc(text="first result"), | ||
relevance_score=-1, | ||
), | ||
RankedObjectOutput( | ||
id=doc.id, | ||
index=1, | ||
document=TextDoc(text="second result"), | ||
relevance_score=-2, | ||
), | ||
] | ||
) | ||
) | ||
return DocList[RankedOutput](ret) | ||
|
||
@requests(on="/encode") | ||
def bar(self, docs: DocList[TextDoc], **kwargs) -> DocList[EmbeddingResponseModel]: | ||
ret = [] | ||
for doc in docs: | ||
ret.append( | ||
EmbeddingResponseModel( | ||
id=doc.id, | ||
text=doc.text, | ||
embeddings=np.random.random((1, 64)), | ||
) | ||
) | ||
return DocList[EmbeddingResponseModel](ret) |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,86 @@ | ||
import csv | ||
import io | ||
import os | ||
|
||
import requests | ||
from jina.orchestrate.pods import Pod | ||
from jina.parsers import set_pod_parser | ||
|
||
sagemaker_port = 8080 | ||
|
||
|
||
def test_provider_sagemaker_pod_rank(): | ||
args, _ = set_pod_parser().parse_known_args( | ||
[ | ||
"--uses", | ||
os.path.join( | ||
os.path.dirname(__file__), "SampleColbertExecutor", "config.yml" | ||
), | ||
"--provider", | ||
"sagemaker", | ||
"--provider-endpoint", | ||
"rank", | ||
"serve", # This is added by sagemaker | ||
] | ||
) | ||
with Pod(args): | ||
# Test the `GET /ping` endpoint (added by jina for sagemaker) | ||
resp = requests.get(f"http://localhost:{sagemaker_port}/ping") | ||
assert resp.status_code == 200 | ||
assert resp.json() == {} | ||
|
||
# Test the `POST /invocations` endpoint for inference | ||
# Note: this endpoint is not implemented in the sample executor | ||
resp = requests.post( | ||
f"http://localhost:{sagemaker_port}/invocations", | ||
json={ | ||
"data": { | ||
"documents": [ | ||
{"text": "the dog is in the house"}, | ||
{"text": "hey Peter"}, | ||
], | ||
"query": "where is the dog", | ||
"top_n": 2, | ||
} | ||
}, | ||
) | ||
assert resp.status_code == 200 | ||
resp_json = resp.json() | ||
assert len(resp_json["data"]) == 1 | ||
assert resp_json["data"][0]["results"][0]["document"]["text"] == "first result" | ||
|
||
|
||
def test_provider_sagemaker_pod_encode(): | ||
args, _ = set_pod_parser().parse_known_args( | ||
[ | ||
"--uses", | ||
os.path.join( | ||
os.path.dirname(__file__), "SampleColbertExecutor", "config.yml" | ||
), | ||
"--provider", | ||
"sagemaker", | ||
"--provider-endpoint", | ||
"encode", | ||
"serve", # This is added by sagemaker | ||
] | ||
) | ||
with Pod(args): | ||
# Test the `GET /ping` endpoint (added by jina for sagemaker) | ||
resp = requests.get(f"http://localhost:{sagemaker_port}/ping") | ||
assert resp.status_code == 200 | ||
assert resp.json() == {} | ||
|
||
# Test the `POST /invocations` endpoint for inference | ||
# Note: this endpoint is not implemented in the sample executor | ||
resp = requests.post( | ||
f"http://localhost:{sagemaker_port}/invocations", | ||
json={ | ||
"data": [ | ||
{"text": "hello world"}, | ||
] | ||
}, | ||
) | ||
assert resp.status_code == 200 | ||
resp_json = resp.json() | ||
assert len(resp_json["data"]) == 1 | ||
assert len(resp_json["data"][0]["embeddings"][0]) == 64 |