From abc4ca24f4a36347018370050d3d5cd42190ee93 Mon Sep 17 00:00:00 2001 From: Joan Fontanals Date: Thu, 5 Sep 2024 11:43:06 +0200 Subject: [PATCH] refactor: slight change in dyn batch queue (#6193) --- jina/serve/runtimes/worker/batch_queue.py | 193 +++++++++--------- .../serve/runtimes/worker/request_handling.py | 1 + .../dynamic_batching/test_dynamic_batching.py | 27 ++- .../dynamic_batching/test_batch_queue.py | 48 +++-- 4 files changed, 154 insertions(+), 115 deletions(-) diff --git a/jina/serve/runtimes/worker/batch_queue.py b/jina/serve/runtimes/worker/batch_queue.py index 9368e81b7d8fd..0419e35414a46 100644 --- a/jina/serve/runtimes/worker/batch_queue.py +++ b/jina/serve/runtimes/worker/batch_queue.py @@ -3,7 +3,7 @@ from asyncio import Event, Task from typing import Callable, Dict, List, Optional, TYPE_CHECKING from jina._docarray import docarray_v2 - +import contextlib if not docarray_v2: from docarray import DocumentArray else: @@ -24,11 +24,16 @@ def __init__( response_docarray_cls, output_array_type: Optional[str] = None, params: Optional[Dict] = None, + allow_concurrent: bool = False, flush_all: bool = False, preferred_batch_size: int = 4, timeout: int = 10_000, ) -> None: - self._data_lock = asyncio.Lock() + # To keep old user behavior, we use data lock when flush_all is true and no allow_concurrent + if allow_concurrent and flush_all: + self._data_lock = contextlib.AsyncExitStack() + else: + self._data_lock = asyncio.Lock() self.func = func if params is None: params = dict() @@ -104,19 +109,20 @@ async def push(self, request: DataRequest, http = False) -> asyncio.Queue: # this push requests the data lock. The order of accessing the data lock guarantees that this request will be put in the `big_doc` # before the `flush` task processes it. 
self._start_timer() - if not self._flush_task: - self._flush_task = asyncio.create_task(self._await_then_flush(http)) - - self._big_doc.extend(docs) - next_req_idx = len(self._requests) - num_docs = len(docs) - self._request_idxs.extend([next_req_idx] * num_docs) - self._request_lens.append(len(docs)) - self._requests.append(request) - queue = asyncio.Queue() - self._requests_completed.append(queue) - if len(self._big_doc) >= self._preferred_batch_size: - self._flush_trigger.set() + async with self._data_lock: + if not self._flush_task: + self._flush_task = asyncio.create_task(self._await_then_flush(http)) + + self._big_doc.extend(docs) + next_req_idx = len(self._requests) + num_docs = len(docs) + self._request_idxs.extend([next_req_idx] * num_docs) + self._request_lens.append(len(docs)) + self._requests.append(request) + queue = asyncio.Queue() + self._requests_completed.append(queue) + if len(self._big_doc) >= self._preferred_batch_size: + self._flush_trigger.set() return queue @@ -236,74 +242,94 @@ def batch(iterable_1, iterable_2, n:Optional[int] = 1): await self._flush_trigger.wait() # writes to shared data between tasks need to be mutually exclusive - big_doc_in_batch = copy.copy(self._big_doc) - requests_idxs_in_batch = copy.copy(self._request_idxs) - requests_lens_in_batch = copy.copy(self._request_lens) - requests_in_batch = copy.copy(self._requests) - requests_completed_in_batch = copy.copy(self._requests_completed) + async with self._data_lock: + big_doc_in_batch = copy.copy(self._big_doc) + requests_idxs_in_batch = copy.copy(self._request_idxs) + requests_lens_in_batch = copy.copy(self._request_lens) + requests_in_batch = copy.copy(self._requests) + requests_completed_in_batch = copy.copy(self._requests_completed) - self._reset() + self._reset() - # At this moment, we have documents concatenated in big_doc_in_batch corresponding to requests in - # requests_idxs_in_batch with its lengths stored in requests_lens_in_batch. For each requests, there is a queue to - # communicate that the request has been processed properly. + # At this moment, we have documents concatenated in big_doc_in_batch corresponding to requests in + # requests_idxs_in_batch with its lengths stored in requests_lens_in_batch. For each requests, there is a queue to + # communicate that the request has been processed properly. 
- if not docarray_v2: - non_assigned_to_response_docs: DocumentArray = DocumentArray.empty() - else: - non_assigned_to_response_docs = self._response_docarray_cls() + if not docarray_v2: + non_assigned_to_response_docs: DocumentArray = DocumentArray.empty() + else: + non_assigned_to_response_docs = self._response_docarray_cls() - non_assigned_to_response_request_idxs = [] - sum_from_previous_first_req_idx = 0 - for docs_inner_batch, req_idxs in batch( - big_doc_in_batch, requests_idxs_in_batch, self._preferred_batch_size if not self._flush_all else None - ): - involved_requests_min_indx = req_idxs[0] - involved_requests_max_indx = req_idxs[-1] - input_len_before_call: int = len(docs_inner_batch) - batch_res_docs = None - try: - batch_res_docs = await self.func( - docs=docs_inner_batch, - parameters=self.params, - docs_matrix=None, # joining manually with batch queue is not supported right now - tracing_context=None, - ) - # Output validation - if (docarray_v2 and isinstance(batch_res_docs, DocList)) or ( - not docarray_v2 - and isinstance(batch_res_docs, DocumentArray) - ): - if not len(batch_res_docs) == input_len_before_call: - raise ValueError( - f'Dynamic Batching requires input size to equal output size. Expected output size {input_len_before_call}, but got {len(batch_res_docs)}' + non_assigned_to_response_request_idxs = [] + sum_from_previous_first_req_idx = 0 + for docs_inner_batch, req_idxs in batch( + big_doc_in_batch, requests_idxs_in_batch, self._preferred_batch_size if not self._flush_all else None + ): + involved_requests_min_indx = req_idxs[0] + involved_requests_max_indx = req_idxs[-1] + input_len_before_call: int = len(docs_inner_batch) + batch_res_docs = None + try: + batch_res_docs = await self.func( + docs=docs_inner_batch, + parameters=self.params, + docs_matrix=None, # joining manually with batch queue is not supported right now + tracing_context=None, + ) + # Output validation + if (docarray_v2 and isinstance(batch_res_docs, DocList)) or ( + not docarray_v2 + and isinstance(batch_res_docs, DocumentArray) + ): + if not len(batch_res_docs) == input_len_before_call: + raise ValueError( + f'Dynamic Batching requires input size to equal output size. Expected output size {input_len_before_call}, but got {len(batch_res_docs)}' + ) + elif batch_res_docs is None: + if not len(docs_inner_batch) == input_len_before_call: + raise ValueError( + f'Dynamic Batching requires input size to equal output size. Expected output size {input_len_before_call}, but got {len(docs_inner_batch)}' + ) + else: + array_name = ( + 'DocumentArray' if not docarray_v2 else 'DocList' ) - elif batch_res_docs is None: - if not len(docs_inner_batch) == input_len_before_call: - raise ValueError( - f'Dynamic Batching requires input size to equal output size. 
Expected output size {input_len_before_call}, but got {len(docs_inner_batch)}' + raise TypeError( + f'The return type must be {array_name} / `None` when using dynamic batching, ' + f'but getting {batch_res_docs!r}' ) + except Exception as exc: + # All the requests containing docs in this Exception should be raising it + for request_full in requests_completed_in_batch[ + involved_requests_min_indx : involved_requests_max_indx + 1 + ]: + await request_full.put(exc) else: - array_name = ( - 'DocumentArray' if not docarray_v2 else 'DocList' + # We need to attribute the docs to their requests + non_assigned_to_response_docs.extend( + batch_res_docs or docs_inner_batch ) - raise TypeError( - f'The return type must be {array_name} / `None` when using dynamic batching, ' - f'but getting {batch_res_docs!r}' + non_assigned_to_response_request_idxs.extend(req_idxs) + num_assigned_docs = await _assign_results( + non_assigned_to_response_docs, + non_assigned_to_response_request_idxs, + sum_from_previous_first_req_idx, + requests_lens_in_batch, + requests_in_batch, + requests_completed_in_batch, ) - except Exception as exc: - # All the requests containing docs in this Exception should be raising it - for request_full in requests_completed_in_batch[ - involved_requests_min_indx : involved_requests_max_indx + 1 - ]: - await request_full.put(exc) - else: - # We need to attribute the docs to their requests - non_assigned_to_response_docs.extend( - batch_res_docs or docs_inner_batch - ) - non_assigned_to_response_request_idxs.extend(req_idxs) - num_assigned_docs = await _assign_results( + + sum_from_previous_first_req_idx = ( + len(non_assigned_to_response_docs) - num_assigned_docs + ) + non_assigned_to_response_docs = non_assigned_to_response_docs[ + num_assigned_docs: + ] + non_assigned_to_response_request_idxs = ( + non_assigned_to_response_request_idxs[num_assigned_docs:] + ) + if len(non_assigned_to_response_request_idxs) > 0: + _ = await _assign_results( non_assigned_to_response_docs, non_assigned_to_response_request_idxs, sum_from_previous_first_req_idx, @@ -312,25 +338,6 @@ def batch(iterable_1, iterable_2, n:Optional[int] = 1): requests_completed_in_batch, ) - sum_from_previous_first_req_idx = ( - len(non_assigned_to_response_docs) - num_assigned_docs - ) - non_assigned_to_response_docs = non_assigned_to_response_docs[ - num_assigned_docs: - ] - non_assigned_to_response_request_idxs = ( - non_assigned_to_response_request_idxs[num_assigned_docs:] - ) - if len(non_assigned_to_response_request_idxs) > 0: - _ = await _assign_results( - non_assigned_to_response_docs, - non_assigned_to_response_request_idxs, - sum_from_previous_first_req_idx, - requests_lens_in_batch, - requests_in_batch, - requests_completed_in_batch, - ) - async def close(self): """Closes the batch queue by flushing pending requests.""" if not self._is_closed: diff --git a/jina/serve/runtimes/worker/request_handling.py b/jina/serve/runtimes/worker/request_handling.py index 456c94a7bdf41..52a5070ea83e4 100644 --- a/jina/serve/runtimes/worker/request_handling.py +++ b/jina/serve/runtimes/worker/request_handling.py @@ -702,6 +702,7 @@ async def handle( ].response_schema, output_array_type=self.args.output_array_type, params=params, + allow_concurrent=self.args.allow_concurrent, **self._batchqueue_config[exec_endpoint], ) # This is necessary because push might need to await for the queue to be emptied diff --git a/tests/integration/dynamic_batching/test_dynamic_batching.py b/tests/integration/dynamic_batching/test_dynamic_batching.py index 
483f247db7892..87e98455317bb 100644 --- a/tests/integration/dynamic_batching/test_dynamic_batching.py +++ b/tests/integration/dynamic_batching/test_dynamic_batching.py @@ -218,7 +218,9 @@ def call_api_with_params(req: RequestStructParams): ], ) @pytest.mark.parametrize('use_stream', [False, True]) -def test_timeout(add_parameters, use_stream): +@pytest.mark.parametrize('allow_concurrent', [False, True]) +def test_timeout(add_parameters, use_stream, allow_concurrent): + add_parameters['allow_concurrent'] = allow_concurrent f = Flow().add(**add_parameters) with f: start_time = time.time() @@ -265,7 +267,9 @@ def test_timeout(add_parameters, use_stream): ], ) @pytest.mark.parametrize('use_stream', [False, True]) -def test_preferred_batch_size(add_parameters, use_stream): +@pytest.mark.parametrize('allow_concurrent', [False, True]) +def test_preferred_batch_size(add_parameters, use_stream, allow_concurrent): + add_parameters['allow_concurrent'] = allow_concurrent f = Flow().add(**add_parameters) with f: with mp.Pool(2) as p: @@ -315,8 +319,9 @@ def test_preferred_batch_size(add_parameters, use_stream): @pytest.mark.repeat(10) @pytest.mark.parametrize('use_stream', [False, True]) -def test_correctness(use_stream): - f = Flow().add(uses=PlaceholderExecutor) +@pytest.mark.parametrize('allow_concurrent', [False, True]) +def test_correctness(use_stream, allow_concurrent): + f = Flow().add(uses=PlaceholderExecutor, allow_concurrent=allow_concurrent) with f: with mp.Pool(2) as p: results = list( @@ -686,7 +691,14 @@ def foo(self, docs, **kwargs): True ], ) -async def test_num_docs_processed_in_exec(flush_all): +@pytest.mark.parametrize( + 'allow_concurrent', + [ + False, + True + ], +) +async def test_num_docs_processed_in_exec(flush_all, allow_concurrent): class DynBatchProcessor(Executor): @dynamic_batching(preferred_batch_size=5, timeout=5000, flush_all=flush_all) @@ -695,7 +707,7 @@ def foo(self, docs, **kwargs): for doc in docs: doc.text = f"{len(docs)}" - depl = Deployment(uses=DynBatchProcessor, protocol='http') + depl = Deployment(uses=DynBatchProcessor, protocol='http', allow_concurrent=allow_concurrent) with depl: da = DocumentArray([Document(text='good') for _ in range(50)]) @@ -721,5 +733,6 @@ def foo(self, docs, **kwargs): larger_than_5 += 1 if int(d.text) < 5: smaller_than_5 += 1 - assert smaller_than_5 == 1 + + assert smaller_than_5 == (1 if allow_concurrent else 0) assert larger_than_5 > 0 diff --git a/tests/unit/serve/dynamic_batching/test_batch_queue.py b/tests/unit/serve/dynamic_batching/test_batch_queue.py index 9db1958b86e05..40622b478322d 100644 --- a/tests/unit/serve/dynamic_batching/test_batch_queue.py +++ b/tests/unit/serve/dynamic_batching/test_batch_queue.py @@ -10,7 +10,8 @@ @pytest.mark.asyncio @pytest.mark.parametrize('flush_all', [False, True]) -async def test_batch_queue_timeout(flush_all): +@pytest.mark.parametrize('allow_concurrent', [False, True]) +async def test_batch_queue_timeout(flush_all, allow_concurrent): async def foo(docs, **kwargs): await asyncio.sleep(0.1) return DocumentArray([Document(text='Done') for _ in docs]) @@ -22,6 +23,7 @@ async def foo(docs, **kwargs): preferred_batch_size=4, timeout=2000, flush_all=flush_all, + allow_concurrent=allow_concurrent, ) three_data_requests = [DataRequest() for _ in range(3)] @@ -62,15 +64,15 @@ async def process_request(req): @pytest.mark.asyncio @pytest.mark.parametrize('flush_all', [False, True]) -async def test_batch_queue_timeout_does_not_wait_previous_batch(flush_all): 
+@pytest.mark.parametrize('allow_concurrent', [False, True]) +async def test_batch_queue_timeout_does_not_wait_previous_batch(flush_all, allow_concurrent): batches_lengths_computed = [] lock = asyncio.Lock() async def foo(docs, **kwargs): - async with lock: - await asyncio.sleep(4) - batches_lengths_computed.append(len(docs)) - return DocumentArray([Document(text='Done') for _ in docs]) + await asyncio.sleep(4) + batches_lengths_computed.append(len(docs)) + return DocumentArray([Document(text='Done') for _ in docs]) bq: BatchQueue = BatchQueue( foo, @@ -78,7 +80,8 @@ async def foo(docs, **kwargs): response_docarray_cls=DocumentArray, preferred_batch_size=5, timeout=3000, - flush_all=flush_all + flush_all=flush_all, + allow_concurrent=allow_concurrent ) data_requests = [DataRequest() for _ in range(3)] @@ -108,10 +111,13 @@ async def process_request(req, sleep=0): assert time_spent >= 12000 assert time_spent <= 12500 else: - assert time_spent >= 8000 - assert time_spent <= 8500 + if not allow_concurrent: + assert time_spent >= 8000 + assert time_spent <= 8500 + else: + assert time_spent < 8000 if flush_all is False: - assert batches_lengths_computed == [5, 2, 1] + assert batches_lengths_computed == [5, 1, 2] else: assert batches_lengths_computed == [6, 2] @@ -120,7 +126,8 @@ async def process_request(req, sleep=0): @pytest.mark.asyncio @pytest.mark.parametrize('flush_all', [False, True]) -async def test_batch_queue_req_length_larger_than_preferred(flush_all): +@pytest.mark.parametrize('allow_concurrent', [False, True]) +async def test_batch_queue_req_length_larger_than_preferred(flush_all, allow_concurrent): async def foo(docs, **kwargs): await asyncio.sleep(0.1) return DocumentArray([Document(text='Done') for _ in docs]) @@ -132,6 +139,7 @@ async def foo(docs, **kwargs): preferred_batch_size=4, timeout=2000, flush_all=flush_all, + allow_concurrent=allow_concurrent, ) data_requests = [DataRequest() for _ in range(3)] @@ -158,7 +166,8 @@ async def process_request(req): @pytest.mark.asyncio -async def test_exception(): +@pytest.mark.parametrize('allow_concurrent', [False, True]) +async def test_exception(allow_concurrent): BAD_REQUEST_IDX = [2, 6] async def foo(docs, **kwargs): @@ -175,6 +184,8 @@ async def foo(docs, **kwargs): response_docarray_cls=DocumentArray, preferred_batch_size=1, timeout=500, + flush_all=False, + allow_concurrent=allow_concurrent ) data_requests = [DataRequest() for _ in range(35)] @@ -204,7 +215,8 @@ async def process_request(req): @pytest.mark.asyncio -async def test_exception_more_complex(): +@pytest.mark.parametrize('allow_concurrent', [False, True]) +async def test_exception_more_complex(allow_concurrent): TRIGGER_BAD_REQUEST_IDX = [2, 6] EXPECTED_BAD_REQUESTS = [2, 3, 6, 7] @@ -225,6 +237,8 @@ async def foo(docs, **kwargs): response_docarray_cls=DocumentArray, preferred_batch_size=2, timeout=500, + flush_all=False, + allow_concurrent=allow_concurrent ) data_requests = [DataRequest() for _ in range(35)] @@ -257,7 +271,8 @@ async def process_request(req): @pytest.mark.asyncio @pytest.mark.parametrize('flush_all', [False, True]) -async def test_exception_all(flush_all): +@pytest.mark.parametrize('allow_concurrent', [False, True]) +async def test_exception_all(flush_all, allow_concurrent): async def foo(docs, **kwargs): raise AssertionError @@ -268,6 +283,7 @@ async def foo(docs, **kwargs): preferred_batch_size=2, flush_all=flush_all, timeout=500, + allow_concurrent=allow_concurrent ) data_requests = [DataRequest() for _ in range(10)] @@ -306,8 +322,9 @@ async def 
foo(docs, **kwargs): @pytest.mark.parametrize('preferred_batch_size', [7, 61, 100]) @pytest.mark.parametrize('timeout', [0.3, 500]) @pytest.mark.parametrize('flush_all', [False, True]) +@pytest.mark.parametrize('allow_concurrent', [False, True]) @pytest.mark.asyncio -async def test_return_proper_assignment(num_requests, preferred_batch_size, timeout, flush_all): +async def test_return_proper_assignment(num_requests, preferred_batch_size, timeout, flush_all, allow_concurrent): import random async def foo(docs, **kwargs): @@ -326,6 +343,7 @@ async def foo(docs, **kwargs): preferred_batch_size=preferred_batch_size, flush_all=flush_all, timeout=timeout, + allow_concurrent=allow_concurrent ) data_requests = [DataRequest() for _ in range(num_requests)]
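The mechanism at the heart of this patch is the lock selection in BatchQueue.__init__: the queue keeps serializing push and flush through a real asyncio.Lock unless both allow_concurrent and flush_all are set, in which case an empty contextlib.AsyncExitStack stands in as a no-op async context manager, so the `async with self._data_lock:` blocks in push() and _await_then_flush() effectively disappear. Below is a minimal sketch of that pattern on its own, not Jina's actual class: ToyBatchQueue, push, flush and main are hypothetical names used only for illustration.

import asyncio
import contextlib


class ToyBatchQueue:
    """Minimal sketch of the lock-selection pattern; not Jina's BatchQueue."""

    def __init__(self, allow_concurrent: bool = False, flush_all: bool = False):
        # An empty AsyncExitStack is a valid async context manager that does
        # nothing on enter/exit, so `async with` works in both configurations.
        if allow_concurrent and flush_all:
            self._data_lock = contextlib.AsyncExitStack()
        else:
            self._data_lock = asyncio.Lock()
        self._big_doc = []

    async def push(self, doc):
        async with self._data_lock:  # no-op when concurrency is allowed
            self._big_doc.append(doc)

    async def flush(self):
        async with self._data_lock:
            batch, self._big_doc = self._big_doc, []
        return batch


async def main():
    serialized = ToyBatchQueue()                                       # real asyncio.Lock
    concurrent = ToyBatchQueue(allow_concurrent=True, flush_all=True)  # no-op "lock"
    for q in (serialized, concurrent):
        await q.push('doc')
        print(type(q._data_lock).__name__, await q.flush())


asyncio.run(main())

Because the same AsyncExitStack instance holds no registered callbacks, entering and exiting it repeatedly is harmless, which is why the patched code can store a single instance in self._data_lock and reuse it for every push and flush.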
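On the usage side, the tests in this patch surface the new flag as an ordinary deployment argument (Flow().add(..., allow_concurrent=...), Deployment(..., allow_concurrent=...)), and request_handling.py forwards self.args.allow_concurrent into every BatchQueue it creates. A hedged usage sketch along the lines of the integration tests above, assuming the docarray v1 style DocumentArray those tests use; MyExecutor, the '/bar' endpoint and the dynamic_batching parameters are illustrative, not taken verbatim from the patch:

from docarray import Document, DocumentArray
from jina import Executor, Flow, dynamic_batching, requests


class MyExecutor(Executor):
    @requests(on='/bar')
    @dynamic_batching(preferred_batch_size=5, timeout=5000, flush_all=True)
    def foo(self, docs: DocumentArray, **kwargs):
        # annotate each doc with the size of the batch it was processed in
        for doc in docs:
            doc.text = f'{len(docs)}'


# allow_concurrent=True drops the shared data lock around push/flush; per the
# adjusted assertions above, it only changes observable batching behaviour
# when combined with flush_all=True.
f = Flow().add(uses=MyExecutor, allow_concurrent=True)
with f:
    res = f.post(
        '/bar',
        inputs=DocumentArray([Document(text='x') for _ in range(50)]),
        request_size=7,
    )
    print([doc.text for doc in res])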