From 176f953e00f95ed15251b707cf2cb9794d712006 Mon Sep 17 00:00:00 2001 From: Zac Li Date: Fri, 8 Mar 2024 16:01:04 +0800 Subject: [PATCH] feat: support provider endpoint in jina executor (#6149) Co-authored-by: Joan Martinez Co-authored-by: Jina Dev Bot --- jina/orchestrate/deployments/__init__.py | 2 + jina/orchestrate/flow/base.py | 9 ++ jina/parsers/orchestrate/pod.py | 7 ++ jina/serve/executors/__init__.py | 22 +++-- .../serve/runtimes/worker/request_handling.py | 1 + jina_cli/autocomplete.py | 4 + .../sagemaker/SampleColbertExecutor/README.md | 2 + .../SampleColbertExecutor/config.yml | 8 ++ .../SampleColbertExecutor/executor.py | 72 ++++++++++++++++ .../SampleColbertExecutor/requirements.txt | 0 .../docarray_v2/sagemaker/test_colbert.py | 86 +++++++++++++++++++ 11 files changed, 207 insertions(+), 6 deletions(-) create mode 100644 tests/integration/docarray_v2/sagemaker/SampleColbertExecutor/README.md create mode 100644 tests/integration/docarray_v2/sagemaker/SampleColbertExecutor/config.yml create mode 100644 tests/integration/docarray_v2/sagemaker/SampleColbertExecutor/executor.py create mode 100644 tests/integration/docarray_v2/sagemaker/SampleColbertExecutor/requirements.txt create mode 100644 tests/integration/docarray_v2/sagemaker/test_colbert.py diff --git a/jina/orchestrate/deployments/__init__.py b/jina/orchestrate/deployments/__init__.py index 7bbfa82468abb..d2e89476a82b6 100644 --- a/jina/orchestrate/deployments/__init__.py +++ b/jina/orchestrate/deployments/__init__.py @@ -287,6 +287,7 @@ def __init__( prefer_platform: Optional[str] = None, protocol: Optional[Union[str, List[str]]] = ['GRPC'], provider: Optional[str] = ['NONE'], + provider_endpoint: Optional[str] = None, py_modules: Optional[List[str]] = None, quiet: Optional[bool] = False, quiet_error: Optional[bool] = False, @@ -387,6 +388,7 @@ def __init__( :param prefer_platform: The preferred target Docker platform. (e.g. "linux/amd64", "linux/arm64") :param protocol: Communication protocol of the server exposed by the Executor. This can be a single value or a list of protocols, depending on your chosen Gateway. Choose the convenient protocols from: ['GRPC', 'HTTP', 'WEBSOCKET']. :param provider: If set, Executor is translated to a custom container compatible with the chosen provider. Choose the convenient providers from: ['NONE', 'SAGEMAKER']. + :param provider_endpoint: If set, Executor endpoint will be explicitly chosen and used in the custom container operated by the provider. :param py_modules: The customized python modules need to be imported before loading the executor Note that the recommended way is to only import a single module - a simple python file, if your diff --git a/jina/orchestrate/flow/base.py b/jina/orchestrate/flow/base.py index 5c3622224ff78..4c8bb9b777191 100644 --- a/jina/orchestrate/flow/base.py +++ b/jina/orchestrate/flow/base.py @@ -203,6 +203,7 @@ def __init__( prefetch: Optional[int] = 1000, protocol: Optional[Union[str, List[str]]] = ['GRPC'], provider: Optional[str] = ['NONE'], + provider_endpoint: Optional[str] = None, proxy: Optional[bool] = False, py_modules: Optional[List[str]] = None, quiet: Optional[bool] = False, @@ -274,6 +275,7 @@ def __init__( Used to control the speed of data input into a Flow. 0 disables prefetch (1000 requests is the default) :param protocol: Communication protocol of the server exposed by the Gateway. This can be a single value or a list of protocols, depending on your chosen Gateway. Choose the convenient protocols from: ['GRPC', 'HTTP', 'WEBSOCKET']. :param provider: If set, Executor is translated to a custom container compatible with the chosen provider. Choose the convenient providers from: ['NONE', 'SAGEMAKER']. + :param provider_endpoint: If set, Executor endpoint will be explicitly chosen and used in the custom container operated by the provider. :param proxy: If set, respect the http_proxy and https_proxy environment variables. otherwise, it will unset these proxy variables before start. gRPC seems to prefer no proxy :param py_modules: The customized python modules need to be imported before loading the gateway @@ -465,6 +467,7 @@ def __init__( Used to control the speed of data input into a Flow. 0 disables prefetch (1000 requests is the default) :param protocol: Communication protocol of the server exposed by the Gateway. This can be a single value or a list of protocols, depending on your chosen Gateway. Choose the convenient protocols from: ['GRPC', 'HTTP', 'WEBSOCKET']. :param provider: If set, Executor is translated to a custom container compatible with the chosen provider. Choose the convenient providers from: ['NONE', 'SAGEMAKER']. + :param provider_endpoint: If set, Executor endpoint will be explicitly chosen and used in the custom container operated by the provider. :param proxy: If set, respect the http_proxy and https_proxy environment variables. otherwise, it will unset these proxy variables before start. gRPC seems to prefer no proxy :param py_modules: The customized python modules need to be imported before loading the gateway @@ -872,6 +875,7 @@ def add( prefer_platform: Optional[str] = None, protocol: Optional[Union[str, List[str]]] = ['GRPC'], provider: Optional[str] = ['NONE'], + provider_endpoint: Optional[str] = None, py_modules: Optional[List[str]] = None, quiet: Optional[bool] = False, quiet_error: Optional[bool] = False, @@ -972,6 +976,7 @@ def add( :param prefer_platform: The preferred target Docker platform. (e.g. "linux/amd64", "linux/arm64") :param protocol: Communication protocol of the server exposed by the Executor. This can be a single value or a list of protocols, depending on your chosen Gateway. Choose the convenient protocols from: ['GRPC', 'HTTP', 'WEBSOCKET']. :param provider: If set, Executor is translated to a custom container compatible with the chosen provider. Choose the convenient providers from: ['NONE', 'SAGEMAKER']. + :param provider_endpoint: If set, Executor endpoint will be explicitly chosen and used in the custom container operated by the provider. :param py_modules: The customized python modules need to be imported before loading the executor Note that the recommended way is to only import a single module - a simple python file, if your @@ -1135,6 +1140,7 @@ def add( :param prefer_platform: The preferred target Docker platform. (e.g. "linux/amd64", "linux/arm64") :param protocol: Communication protocol of the server exposed by the Executor. This can be a single value or a list of protocols, depending on your chosen Gateway. Choose the convenient protocols from: ['GRPC', 'HTTP', 'WEBSOCKET']. :param provider: If set, Executor is translated to a custom container compatible with the chosen provider. Choose the convenient providers from: ['NONE', 'SAGEMAKER']. + :param provider_endpoint: If set, Executor endpoint will be explicitly chosen and used in the custom container operated by the provider. :param py_modules: The customized python modules need to be imported before loading the executor Note that the recommended way is to only import a single module - a simple python file, if your @@ -1330,6 +1336,7 @@ def config_gateway( prefetch: Optional[int] = 1000, protocol: Optional[Union[str, List[str]]] = ['GRPC'], provider: Optional[str] = ['NONE'], + provider_endpoint: Optional[str] = None, proxy: Optional[bool] = False, py_modules: Optional[List[str]] = None, quiet: Optional[bool] = False, @@ -1401,6 +1408,7 @@ def config_gateway( Used to control the speed of data input into a Flow. 0 disables prefetch (1000 requests is the default) :param protocol: Communication protocol of the server exposed by the Gateway. This can be a single value or a list of protocols, depending on your chosen Gateway. Choose the convenient protocols from: ['GRPC', 'HTTP', 'WEBSOCKET']. :param provider: If set, Executor is translated to a custom container compatible with the chosen provider. Choose the convenient providers from: ['NONE', 'SAGEMAKER']. + :param provider_endpoint: If set, Executor endpoint will be explicitly chosen and used in the custom container operated by the provider. :param proxy: If set, respect the http_proxy and https_proxy environment variables. otherwise, it will unset these proxy variables before start. gRPC seems to prefer no proxy :param py_modules: The customized python modules need to be imported before loading the gateway @@ -1501,6 +1509,7 @@ def config_gateway( Used to control the speed of data input into a Flow. 0 disables prefetch (1000 requests is the default) :param protocol: Communication protocol of the server exposed by the Gateway. This can be a single value or a list of protocols, depending on your chosen Gateway. Choose the convenient protocols from: ['GRPC', 'HTTP', 'WEBSOCKET']. :param provider: If set, Executor is translated to a custom container compatible with the chosen provider. Choose the convenient providers from: ['NONE', 'SAGEMAKER']. + :param provider_endpoint: If set, Executor endpoint will be explicitly chosen and used in the custom container operated by the provider. :param proxy: If set, respect the http_proxy and https_proxy environment variables. otherwise, it will unset these proxy variables before start. gRPC seems to prefer no proxy :param py_modules: The customized python modules need to be imported before loading the gateway diff --git a/jina/parsers/orchestrate/pod.py b/jina/parsers/orchestrate/pod.py index fe908bfa463de..0780d13dd81f0 100644 --- a/jina/parsers/orchestrate/pod.py +++ b/jina/parsers/orchestrate/pod.py @@ -217,6 +217,13 @@ def mixin_pod_runtime_args_parser(arg_group, pod_type='worker'): help=f'If set, Executor is translated to a custom container compatible with the chosen provider. Choose the convenient providers from: {[provider.to_string() for provider in list(ProviderType)]}.', ) + arg_group.add_argument( + '--provider-endpoint', + type=str, + default=None, + help=f'If set, Executor endpoint will be explicitly chosen and used in the custom container operated by the provider.', + ) + arg_group.add_argument( '--monitoring', action='store_true', diff --git a/jina/serve/executors/__init__.py b/jina/serve/executors/__init__.py index c5c8f72a8e6c1..b33099f61e90e 100644 --- a/jina/serve/executors/__init__.py +++ b/jina/serve/executors/__init__.py @@ -622,17 +622,25 @@ def _validate_sagemaker(self): if '/invocations' in self.requests: return + if ( + hasattr(self.runtime_args, 'provider_endpoint') + and self.runtime_args.provider_endpoint + ): + endpoint_to_use = ('/' + self.runtime_args.provider_endpoint).lower() + if endpoint_to_use in list(self.requests.keys()): + self.logger.warning( + f'Using "{endpoint_to_use}" as "/invocations" route' + ) + self.requests['/invocations'] = self.requests[endpoint_to_use] + return + if len(self.requests) == 1: route = list(self.requests.keys())[0] - self.logger.warning( - f'No "/invocations" route found. Using "{route}" as "/invocations" route' - ) + self.logger.warning(f'Using "{route}" as "/invocations" route') self.requests['/invocations'] = self.requests[route] return - raise ValueError( - 'No "/invocations" route found. Please define a "/invocations" route' - ) + raise ValueError('Cannot identify the endpoint to use for "/invocations"') def _add_dynamic_batching(self, _dynamic_batching: Optional[Dict]): if _dynamic_batching: @@ -994,6 +1002,7 @@ def serve( prefer_platform: Optional[str] = None, protocol: Optional[Union[str, List[str]]] = ['GRPC'], provider: Optional[str] = ['NONE'], + provider_endpoint: Optional[str] = None, py_modules: Optional[List[str]] = None, quiet: Optional[bool] = False, quiet_error: Optional[bool] = False, @@ -1094,6 +1103,7 @@ def serve( :param prefer_platform: The preferred target Docker platform. (e.g. "linux/amd64", "linux/arm64") :param protocol: Communication protocol of the server exposed by the Executor. This can be a single value or a list of protocols, depending on your chosen Gateway. Choose the convenient protocols from: ['GRPC', 'HTTP', 'WEBSOCKET']. :param provider: If set, Executor is translated to a custom container compatible with the chosen provider. Choose the convenient providers from: ['NONE', 'SAGEMAKER']. + :param provider_endpoint: If set, Executor endpoint will be explicitly chosen and used in the custom container operated by the provider. :param py_modules: The customized python modules need to be imported before loading the executor Note that the recommended way is to only import a single module - a simple python file, if your diff --git a/jina/serve/runtimes/worker/request_handling.py b/jina/serve/runtimes/worker/request_handling.py index 2e095cb26da50..a145121e3eec4 100644 --- a/jina/serve/runtimes/worker/request_handling.py +++ b/jina/serve/runtimes/worker/request_handling.py @@ -386,6 +386,7 @@ def _load_executor( 'replicas': self.args.replicas, 'name': self.args.name, 'provider': self.args.provider, + 'provider_endpoint': self.args.provider_endpoint, 'metrics_registry': metrics_registry, 'tracer_provider': tracer_provider, 'meter_provider': meter_provider, diff --git a/jina_cli/autocomplete.py b/jina_cli/autocomplete.py index 4fda9f54eaa66..e3f85ff9fc5d3 100644 --- a/jina_cli/autocomplete.py +++ b/jina_cli/autocomplete.py @@ -72,6 +72,7 @@ '--protocol', '--protocols', '--provider', + '--provider-endpoint', '--monitoring', '--port-monitoring', '--retries', @@ -180,6 +181,7 @@ '--protocol', '--protocols', '--provider', + '--provider-endpoint', '--monitoring', '--port-monitoring', '--retries', @@ -438,6 +440,7 @@ '--protocol', '--protocols', '--provider', + '--provider-endpoint', '--monitoring', '--port-monitoring', '--retries', @@ -512,6 +515,7 @@ '--protocol', '--protocols', '--provider', + '--provider-endpoint', '--monitoring', '--port-monitoring', '--retries', diff --git a/tests/integration/docarray_v2/sagemaker/SampleColbertExecutor/README.md b/tests/integration/docarray_v2/sagemaker/SampleColbertExecutor/README.md new file mode 100644 index 0000000000000..98a2b6384793f --- /dev/null +++ b/tests/integration/docarray_v2/sagemaker/SampleColbertExecutor/README.md @@ -0,0 +1,2 @@ +# SampleColbertExecutor + diff --git a/tests/integration/docarray_v2/sagemaker/SampleColbertExecutor/config.yml b/tests/integration/docarray_v2/sagemaker/SampleColbertExecutor/config.yml new file mode 100644 index 0000000000000..7c43d5d456ed0 --- /dev/null +++ b/tests/integration/docarray_v2/sagemaker/SampleColbertExecutor/config.yml @@ -0,0 +1,8 @@ +jtype: SampleColbertExecutor +py_modules: + - executor.py +metas: + name: SampleColbertExecutor + description: + url: + keywords: [] diff --git a/tests/integration/docarray_v2/sagemaker/SampleColbertExecutor/executor.py b/tests/integration/docarray_v2/sagemaker/SampleColbertExecutor/executor.py new file mode 100644 index 0000000000000..0fba5c6c04e9b --- /dev/null +++ b/tests/integration/docarray_v2/sagemaker/SampleColbertExecutor/executor.py @@ -0,0 +1,72 @@ +import numpy as np +from docarray import BaseDoc, DocList +from docarray.typing import NdArray +from pydantic import Field +from typing import Union, Optional, List +from jina import Executor, requests + + +class TextDoc(BaseDoc): + text: str = Field(description="The text of the document", default="") + + +class RerankerInput(BaseDoc): + query: Union[str, TextDoc] + + documents: List[TextDoc] + + top_n: Optional[int] + + +class RankedObjectOutput(BaseDoc): + index: int + document: Optional[TextDoc] + + relevance_score: float + + +class EmbeddingResponseModel(TextDoc): + embeddings: NdArray + + +class RankedOutput(BaseDoc): + results: DocList[RankedObjectOutput] + + +class SampleColbertExecutor(Executor): + @requests(on="/rank") + def foo(self, docs: DocList[RerankerInput], **kwargs) -> DocList[RankedOutput]: + ret = [] + for doc in docs: + ret.append( + RankedOutput( + results=[ + RankedObjectOutput( + id=doc.id, + index=0, + document=TextDoc(text="first result"), + relevance_score=-1, + ), + RankedObjectOutput( + id=doc.id, + index=1, + document=TextDoc(text="second result"), + relevance_score=-2, + ), + ] + ) + ) + return DocList[RankedOutput](ret) + + @requests(on="/encode") + def bar(self, docs: DocList[TextDoc], **kwargs) -> DocList[EmbeddingResponseModel]: + ret = [] + for doc in docs: + ret.append( + EmbeddingResponseModel( + id=doc.id, + text=doc.text, + embeddings=np.random.random((1, 64)), + ) + ) + return DocList[EmbeddingResponseModel](ret) diff --git a/tests/integration/docarray_v2/sagemaker/SampleColbertExecutor/requirements.txt b/tests/integration/docarray_v2/sagemaker/SampleColbertExecutor/requirements.txt new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/integration/docarray_v2/sagemaker/test_colbert.py b/tests/integration/docarray_v2/sagemaker/test_colbert.py new file mode 100644 index 0000000000000..bc4988c6efed0 --- /dev/null +++ b/tests/integration/docarray_v2/sagemaker/test_colbert.py @@ -0,0 +1,86 @@ +import csv +import io +import os + +import requests +from jina.orchestrate.pods import Pod +from jina.parsers import set_pod_parser + +sagemaker_port = 8080 + + +def test_provider_sagemaker_pod_rank(): + args, _ = set_pod_parser().parse_known_args( + [ + "--uses", + os.path.join( + os.path.dirname(__file__), "SampleColbertExecutor", "config.yml" + ), + "--provider", + "sagemaker", + "--provider-endpoint", + "rank", + "serve", # This is added by sagemaker + ] + ) + with Pod(args): + # Test the `GET /ping` endpoint (added by jina for sagemaker) + resp = requests.get(f"http://localhost:{sagemaker_port}/ping") + assert resp.status_code == 200 + assert resp.json() == {} + + # Test the `POST /invocations` endpoint for inference + # Note: this endpoint is not implemented in the sample executor + resp = requests.post( + f"http://localhost:{sagemaker_port}/invocations", + json={ + "data": { + "documents": [ + {"text": "the dog is in the house"}, + {"text": "hey Peter"}, + ], + "query": "where is the dog", + "top_n": 2, + } + }, + ) + assert resp.status_code == 200 + resp_json = resp.json() + assert len(resp_json["data"]) == 1 + assert resp_json["data"][0]["results"][0]["document"]["text"] == "first result" + + +def test_provider_sagemaker_pod_encode(): + args, _ = set_pod_parser().parse_known_args( + [ + "--uses", + os.path.join( + os.path.dirname(__file__), "SampleColbertExecutor", "config.yml" + ), + "--provider", + "sagemaker", + "--provider-endpoint", + "encode", + "serve", # This is added by sagemaker + ] + ) + with Pod(args): + # Test the `GET /ping` endpoint (added by jina for sagemaker) + resp = requests.get(f"http://localhost:{sagemaker_port}/ping") + assert resp.status_code == 200 + assert resp.json() == {} + + # Test the `POST /invocations` endpoint for inference + # Note: this endpoint is not implemented in the sample executor + resp = requests.post( + f"http://localhost:{sagemaker_port}/invocations", + json={ + "data": [ + {"text": "hello world"}, + ] + }, + ) + assert resp.status_code == 200 + resp_json = resp.json() + assert len(resp_json["data"]) == 1 + assert len(resp_json["data"][0]["embeddings"][0]) == 64