Skip to content

Commit

Permalink
chore: refactor CompositeServer for better extendability (#5978)
Browse files Browse the repository at this point in the history
Signed-off-by: Joan Fontanals Martinez <[email protected]>
  • Loading branch information
JoanFM authored Jul 25, 2023
1 parent 99c964e commit 8141d99
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 29 deletions.
6 changes: 3 additions & 3 deletions jina/serve/runtimes/servers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,21 +122,21 @@ def port(self):
"""Gets the first port of the port list argument. To be used in the regular case where a Gateway exposes a single port
:return: The first port to be exposed
"""
return self.runtime_args.port[0]
return self.runtime_args.port[0] if isinstance(self.runtime_args.port, list) else self.runtime_args.port

@property
def ports(self):
"""Gets all the list of ports from the runtime_args as a list.
:return: The lists of ports to be exposed
"""
return self.runtime_args.port
return self.runtime_args.port if isinstance(self.runtime_args.port, list) else [self.runtime_args.port]

@property
def protocols(self):
"""Gets all the list of protocols from the runtime_args as a list.
:return: The lists of protocols to be exposed
"""
return self.runtime_args.protocol
return self.runtime_args.protocol if isinstance(self.runtime_args.protocol, list) else [self.runtime_args.protocol]

@property
def host(self):
Expand Down
80 changes: 54 additions & 26 deletions jina/serve/runtimes/servers/composite.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,17 @@
import asyncio
import copy
from typing import Any, List
from typing import Any, List, TYPE_CHECKING

from jina.serve.runtimes.servers import BaseServer

if TYPE_CHECKING:
from jina.logging.logger import JinaLogger

class CompositeServer(BaseServer):
"""Composite Server implementation"""

class CompositeBaseServer(BaseServer):
"""Composite Base Server implementation from which u can inherit a specific custom composite one"""
servers: List['BaseServer']
logger: 'JinaLogger'

def __init__(
self,
Expand All @@ -16,30 +21,32 @@ def __init__(
:param kwargs: keyword args
"""
super().__init__(**kwargs)
from jina.parsers.helper import _get_gateway_class
self._kwargs = kwargs

self.servers: List[BaseServer] = []
@property
def _server_kwargs(self):
ret = []
# ignore monitoring and tracing args since they are not copyable
ignored_attrs = [
'metrics_registry',
'tracer_provider',
'grpc_tracing_server_interceptors',
'aio_tracing_client_interceptors',
'tracing_client_interceptor',
]
for port, protocol in zip(self.ports, self.protocols):
server_cls = _get_gateway_class(protocol, works_as_load_balancer=self.works_as_load_balancer)
# ignore monitoring and tracing args since they are not copyable
ignored_attrs = [
'metrics_registry',
'tracer_provider',
'grpc_tracing_server_interceptors',
'aio_tracing_client_interceptors',
'tracing_client_interceptor',
]
runtime_args = self._deepcopy_with_ignore_attrs(
self.runtime_args, ignored_attrs
)
runtime_args.port = [port]
runtime_args.protocol = [protocol]
server_kwargs = {k: v for k, v in kwargs.items() if k != 'runtime_args'}
runtime_args.port = port
runtime_args.protocol = protocol
server_kwargs = {k: v for k, v in self._kwargs.items() if k != 'runtime_args'}
server_kwargs['runtime_args'] = dict(vars(runtime_args))
server_kwargs['req_handler'] = self._request_handler
server = server_cls(**server_kwargs)
self.servers.append(server)
self.gateways = self.servers # for backwards compatibility
ret.append(server_kwargs)

return ret

async def setup_server(self):
"""
Expand Down Expand Up @@ -72,6 +79,34 @@ async def run_server(self):

await asyncio.gather(*run_server_tasks)

@property
def _should_exit(self) -> bool:
should_exit_values = [
getattr(server, 'should_exit', True) for server in self.servers
]
return all(should_exit_values)


class CompositeServer(CompositeBaseServer):
"""Composite Server implementation"""

def __init__(
self,
**kwargs,
):
"""Initialize the gateway
:param kwargs: keyword args
"""
super().__init__(**kwargs)
from jina.parsers.helper import _get_gateway_class

self.servers: List[BaseServer] = []
for server_kwargs in self._server_kwargs:
server_cls = _get_gateway_class(server_kwargs['runtime_args']['protocol'], works_as_load_balancer=self.works_as_load_balancer)
server = server_cls(**server_kwargs)
self.servers.append(server)
self.gateways = self.servers # for backwards compatibility

@staticmethod
def _deepcopy_with_ignore_attrs(obj: Any, ignore_attrs: List[str]) -> Any:
"""Deep copy an object and ignore some attributes
Expand All @@ -87,10 +122,3 @@ def _deepcopy_with_ignore_attrs(obj: Any, ignore_attrs: List[str]) -> Any:
memo[id(getattr(obj, k))] = None # getattr(obj, k)

return copy.deepcopy(obj, memo)

@property
def _should_exit(self) -> bool:
should_exit_values = [
getattr(server, 'should_exit', True) for server in self.servers
]
return all(should_exit_values)

0 comments on commit 8141d99

Please sign in to comment.