diff --git a/.github/workflows/unit_tests.yaml b/.github/workflows/unit_tests.yaml index f72ef3ac..7b230dde 100644 --- a/.github/workflows/unit_tests.yaml +++ b/.github/workflows/unit_tests.yaml @@ -80,4 +80,4 @@ jobs: coverage run -m unittest -v - name: Create test coverage report run: | - coverage report -m \ No newline at end of file + coverage report -m --omit="jetstream/core/proto/*,jetstream/engine/tokenizer_pb2.py,jetstream/third_party/*" --fail-under=96 \ No newline at end of file diff --git a/README.md b/README.md index 1e2b051c..758c9640 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,4 @@ -[![Unit Tests](https://github.com/google/JetStream/actions/workflows/unit_tests.yaml/badge.svg)](https://github.com/google/JetStream/actions/workflows/unit_tests.yaml) +[![Unit Tests](https://github.com/google/JetStream/actions/workflows/unit_tests.yaml/badge.svg?branch=main)](https://github.com/google/JetStream/actions/workflows/unit_tests.yaml?query=branch:main) [![PyPI version](https://badge.fury.io/py/google-jetstream.svg)](https://badge.fury.io/py/google-jetstream) [![PyPi downloads](https://img.shields.io/pypi/dm/google-jetstream?style=flat-square&logo=pypi&logoColor=white)](https://pypi.org/project/google-jetstream/) [![Contributions welcome](https://img.shields.io/badge/contributions-welcome-brightgreen.svg)](CONTRIBUTING.md) @@ -57,15 +57,16 @@ python -m jetstream.tools.load_tester ### Test core modules ``` # Test JetStream core orchestrator -python -m jetstream.tests.core.test_orchestrator +python -m unittest -v jetstream.tests.core.test_orchestrator # Test JetStream core server library -python -m jetstream.tests.core.test_server +python -m unittest -v jetstream.tests.core.test_server # Test mock JetStream engine implementation -python -m jetstream.tests.engine.test_mock_engine +python -m unittest -v jetstream.tests.engine.test_mock_engine # Test mock JetStream token utils -python -m jetstream.tests.engine.test_utils +python -m unittest -v jetstream.tests.engine.test_token_utils +python -m unittest -v jetstream.tests.engine.test_utils ``` diff --git a/jetstream/core/config_lib.py b/jetstream/core/config_lib.py index fa33e45d..76035bfa 100644 --- a/jetstream/core/config_lib.py +++ b/jetstream/core/config_lib.py @@ -38,12 +38,6 @@ class ServerConfig: generate_engine_create_fns: Tuple[CreateEngineFn, ...] = () interleaved_engine_create_fns: Tuple[CreateEngineFn, ...] = () - def get_slices_to_launch(self: "ServerConfig") -> str: - """Used when launching this config via xm config.""" - return ",".join( - self.prefill_slices + self.generate_slices + self.interleaved_slices - ) - @dataclasses.dataclass class InstantiatedEngines: diff --git a/jetstream/core/orchestrator.py b/jetstream/core/orchestrator.py index db4e7c98..aaa45549 100644 --- a/jetstream/core/orchestrator.py +++ b/jetstream/core/orchestrator.py @@ -85,7 +85,7 @@ import threading import time import traceback -from typing import Any, AsyncIterator, Optional, Tuple, Union, cast +from typing import Any, AsyncIterator, Optional, Tuple, cast import grpc import jax @@ -434,13 +434,6 @@ def place_request_on_prefill_queue(self, request: ActiveRequest): self._prefill_backlog.put(request, block=False) self._prefill_backlog_size_metric.set(self._prefill_backlog.qsize()) - def _load_cache_history(self, path: str) -> Union[None, Any]: - """Loads previous kv cache for a longer conversation.""" - if path: - raise NotImplementedError - else: - return None - def _process_prefill_content( self, request: ActiveRequest, @@ -744,6 +737,60 @@ def _get_prefill_content( True, ) + def process_client_side_tokenization_response(self, response: Any): + samples = [] + for sample in response: + samples.append( + jetstream_pb2.DecodeResponse.StreamContent.Sample( + token_ids=sample.token_ids, + ) + ) + return jetstream_pb2.DecodeResponse( + stream_content=jetstream_pb2.DecodeResponse.StreamContent( + samples=samples + ) + ) + + def should_buffer_response(self, response: Any) -> bool: + for item in response: + if item.text and token_utils.is_byte_token(item.text[-1]): + # If any sample ends in bytes, this means we might still need to + # decode more bytes to compose the string. + return True + + def process_server_side_tokenization_response( + self, response: Any, buffered_response_list + ): + # Flush the buffered responses to each sample of current response. + current_response_with_flushed_buffer = list( + zip(*buffered_response_list, response) + ) + # Empty buffer: [[s0_cur], [s1_cur], ...] + # Has buffer: + # [[s0_b0, s0_b1, ..., s0_cur], [s1_b0, s1_b1, ..., s1_cur], ...] + current_response_with_flushed_buffer = cast( + list[list[ReturnSample]], current_response_with_flushed_buffer + ) + # Form correct sample(s) and return as StreamContent for this iteration. + samples = [] + for sample in current_response_with_flushed_buffer: + text = [] + token_ids = [] + for resp in sample: + text.extend(resp.text) + token_ids.extend(resp.token_ids) + samples.append( + jetstream_pb2.DecodeResponse.StreamContent.Sample( + text=token_utils.text_tokens_to_str(text), + token_ids=token_ids, + ) + ) + return jetstream_pb2.DecodeResponse( + stream_content=jetstream_pb2.DecodeResponse.StreamContent( + samples=samples + ) + ) + async def Decode( # pylint: disable=invalid-overridden-method self, request: jetstream_pb2.DecodeRequest, @@ -795,70 +842,24 @@ async def Decode( # pylint: disable=invalid-overridden-method # The DecodeResponse stream should consume all generated tokens in # return_channel when complete signal is received (AsyncMultifuture # promises this). - if is_client_side_tokenization: - # If is_client_side_tokenization, the client should request with token - # ids, and the JetStream server will return token ids as response. - # The client should take care of tokenization and detokenization. - async for response in active_request.return_channel: - response = cast(list[ReturnSample], response) - samples = [] - for sample in response: - samples.append( - jetstream_pb2.DecodeResponse.StreamContent.Sample( - token_ids=sample.token_ids, - ) - ) - yield jetstream_pb2.DecodeResponse( - stream_content=jetstream_pb2.DecodeResponse.StreamContent( - samples=samples - ) - ) - else: - # Buffer response mechanism is used to handle streaming - # detokenization with special character (For some edge cases with - # SentencePiece tokenizer, it requires to decode a complete sequence - # instead of a single token). - buffered_response_list = [] - async for response in active_request.return_channel: - response = cast(list[ReturnSample], response) - buffered = False - for item in response: - if item.text and token_utils.is_byte_token(item.text[-1]): - # If any sample ends in bytes, this means we might still need to - # decode more bytes to compose the string. - buffered_response_list.append(response) - buffered = True - break - if buffered: + buffered_response_list = [] + async for response in active_request.return_channel: + response = cast(list[ReturnSample], response) + if is_client_side_tokenization: + # If is_client_side_tokenization, the client should request with token + # ids, and the JetStream server will return token ids as response. + # The client should take care of tokenization and detokenization. + yield self.process_client_side_tokenization_response(response) + else: + # Buffer response mechanism is used to handle streaming + # detokenization with special character (For some edge cases with + # SentencePiece tokenizer, it requires to decode a complete sequence + # instead of a single token). + if self.should_buffer_response(response): + buffered_response_list.append(response) continue - # Flush the buffered responses to each sample of current response. - current_response_with_flushed_buffer = list( - zip(*buffered_response_list, response) - ) - # Empty buffer: [[s0_cur], [s1_cur], ...] - # Has buffer: - # [[s0_b0, s0_b1, ..., s0_cur], [s1_b0, s1_b1, ..., s1_cur], ...] - current_response_with_flushed_buffer = cast( - list[list[ReturnSample]], current_response_with_flushed_buffer + yield self.process_server_side_tokenization_response( + response, buffered_response_list ) # Reset buffer after flushed. buffered_response_list = [] - # Form correct sample(s) and return as StreamContent for this iteration. - samples = [] - for sample in current_response_with_flushed_buffer: - text = [] - token_ids = [] - for resp in sample: - text.extend(resp.text) - token_ids.extend(resp.token_ids) - samples.append( - jetstream_pb2.DecodeResponse.StreamContent.Sample( - text=token_utils.text_tokens_to_str(text), - token_ids=token_ids, - ) - ) - yield jetstream_pb2.DecodeResponse( - stream_content=jetstream_pb2.DecodeResponse.StreamContent( - samples=samples - ) - ) diff --git a/jetstream/tests/core/test_config_lib.py b/jetstream/tests/core/test_config_lib.py index 17b4dda9..5cd815b9 100644 --- a/jetstream/tests/core/test_config_lib.py +++ b/jetstream/tests/core/test_config_lib.py @@ -14,22 +14,18 @@ """Unit test for config_lib.py.""" -from absl.testing import absltest, parameterized +import unittest +from parameterized import parameterized from jetstream.core import config_lib -class TestConfigLib(parameterized.TestCase): +class TestConfigLib(unittest.TestCase): - @parameterized.parameters( - ("tpu=8", 8), - ("v5e-8", 8), - ("v5e=4", 4), - ("v4-8", 4), - ) + @parameterized.expand([("tpu=8", 8), ("v5e-8", 8), ("v5e=4", 4), ("v4-8", 4)]) def test_slice_to_num_chips(self, accelerator_slice, expected_num_devices): got = config_lib.slice_to_num_chips(accelerator_slice) self.assertEqual(got, expected_num_devices) - -if __name__ == "__main__": - absltest.main() + def test_get_engines_invalid(self): + with self.assertRaises(ValueError): + config_lib.get_engines(config_lib.InterleavedCPUTestServer, []) diff --git a/jetstream/tests/core/test_orchestrator.py b/jetstream/tests/core/test_orchestrator.py index bb13f872..49494bef 100644 --- a/jetstream/tests/core/test_orchestrator.py +++ b/jetstream/tests/core/test_orchestrator.py @@ -44,6 +44,7 @@ import unittest from jetstream.core import orchestrator from jetstream.core.proto import jetstream_pb2 +from jetstream.core.utils.return_sample import ReturnSample from jetstream.engine import mock_engine @@ -131,6 +132,13 @@ async def test_orchestrator_interleaved_mode_client_tokenization(self): driver.stop() print("Orchestrator driver stopped.") - -if __name__ == "__main__": - unittest.main() + def test_should_buffer_response(self): + driver = self._setup_driver_interleaved_mode() + client = orchestrator.LLMOrchestrator(driver=driver) + self.assertTrue( + client.should_buffer_response( + [ReturnSample(text=["<0xAB>"], token_ids=[13])] + ) + ) + driver.stop() + print("Orchestrator driver stopped.") diff --git a/jetstream/tests/core/test_server.py b/jetstream/tests/core/test_server.py index 9dda982a..5d2c08dc 100644 --- a/jetstream/tests/core/test_server.py +++ b/jetstream/tests/core/test_server.py @@ -96,6 +96,5 @@ async def test_server( counter += 1 server.stop() - -if __name__ == "__main__": - unittest.main() + def test_get_devices(self): + assert len(server_lib.get_devices()) == 1 diff --git a/jetstream/tests/engine/test_mock_engine.py b/jetstream/tests/engine/test_mock_engine.py index 2a1081e8..3f112067 100644 --- a/jetstream/tests/engine/test_mock_engine.py +++ b/jetstream/tests/engine/test_mock_engine.py @@ -30,15 +30,15 @@ I.e. ['Ċ', 'Ə', 'ɖ'] when converted back with chr() """ +import unittest import jax.numpy as jnp import numpy as np from jetstream.engine import mock_engine from jetstream.engine import token_utils -from absl.testing import absltest -class EngineTest(absltest.TestCase): +class EngineTest(unittest.TestCase): def _setup(self): """Initialises a test engine.""" @@ -128,7 +128,3 @@ def test_generate(self, slot=1): token_data = sampled_tokens.get_result_at_slot(slot) tok = token_data.tokens assert tokenizer.IdToPiece(int(tok.item())) == "ɖ" - - -if __name__ == "__main__": - absltest.main() diff --git a/jetstream/tests/engine/test_token_utils.py b/jetstream/tests/engine/test_token_utils.py index 63a2db77..41fff3d3 100644 --- a/jetstream/tests/engine/test_token_utils.py +++ b/jetstream/tests/engine/test_token_utils.py @@ -136,6 +136,67 @@ def test_tokenize_and_pad_np(self): ) self.assertEqual(true_length, expected_true_length) + def test_tokenize_and_pad(self): + jax.config.update("jax_platform_name", "cpu") + self.setup_sentencepiece() + s = "I believe the meaning of life is" + vocab = self.jt_tokenizer.vocab + max_prefill_length = 1024 + padded_tokens, true_length = token_utils.tokenize_and_pad( + s, + vocab, + max_prefill_length=max_prefill_length, + ) + expected_padded_tokens = jnp.array( + [1, 306, 4658, 278, 6593, 310, 2834, 338, 0, 0, 0, 0, 0, 0, 0, 0] + ) + expected_true_length = 8 + self.assertTrue( + jnp.allclose(padded_tokens, expected_padded_tokens, atol=1e-7) + ) + self.assertEqual(true_length, expected_true_length) + + def test_pad_token_padding_less_than_zero(self): + jax.config.update("jax_platform_name", "cpu") + self.setup_sentencepiece() + s = "I believe the meaning of life is having different experiences and " + s += "enjoy everyday of my life." + vocab = self.jt_tokenizer.vocab + max_prefill_length = 16 + tokens = vocab.encode_tf(s) + padded_tokens, true_length = token_utils.pad_tokens( + tokens, + bos_id=vocab.bos_id, + pad_id=vocab.pad_id, + max_prefill_length=max_prefill_length, + ) + # Take the last N tokens if we have too many. + expected_padded_tokens = jnp.array( + [ + 278, + 6593, + 310, + 2834, + 338, + 2534, + 1422, + 27482, + 322, + 13389, + 1432, + 3250, + 310, + 590, + 2834, + 29889, + ] + ) + expected_true_length = 19 + self.assertTrue( + jnp.allclose(padded_tokens, expected_padded_tokens, atol=1e-7) + ) + self.assertEqual(true_length, expected_true_length) + def test_sentencepiece_tokenizer_encode(self): self.setup_sentencepiece() s = "I believe the meaning of life is" @@ -559,7 +620,3 @@ def test_text_tokens_to_str(self): ) == "你好�\n�hello" ) - - -if __name__ == "__main__": - unittest.main() diff --git a/jetstream/tests/engine/test_utils.py b/jetstream/tests/engine/test_utils.py index ca84ed5d..32a32542 100644 --- a/jetstream/tests/engine/test_utils.py +++ b/jetstream/tests/engine/test_utils.py @@ -15,13 +15,13 @@ """Tests functionality of the token processing utils using mock engine vocab.""" import numpy as np +import unittest from jetstream.engine import engine_api from jetstream.engine import mock_utils from jetstream.engine import token_utils -from absl.testing import absltest -class UtilsTest(absltest.TestCase): +class UtilsTest(unittest.TestCase): def test_speculations_with_multi_sample_slots(self, samples_per_slot=2): # [4, 1] @@ -89,6 +89,21 @@ def test_speculations_with_multi_sample_slots(self, samples_per_slot=2): assert text_output[1] == "A" # second token is padded. np.testing.assert_equal(complete, np.array([0, 1])) - -if __name__ == "__main__": - absltest.main() + def test_mock_utils(self): + vocab = mock_utils.TestVocab() + # test encode() + with self.assertRaises(NotImplementedError): + vocab.encode("AB") + # test encode_tf() + token_ids = vocab.encode_tf("AB") + np.testing.assert_equal(token_ids, np.array([65, 66])) + # test decode() + ids = np.array([ord("A")]) + expected = "A" + result = vocab.decode(ids) + self.assertEqual(result, expected) + # test decode_tf() + ids = np.array([[ord("A")]]) + expected = ["A"] + result_tf = vocab.decode_tf(ids) + self.assertEqual(result_tf, expected)