Unit test coverage cleanup #81

Merged · 4 commits · May 17, 2024
2 changes: 1 addition & 1 deletion .github/workflows/unit_tests.yaml
@@ -80,4 +80,4 @@ jobs:
          coverage run -m unittest -v
      - name: Create test coverage report
        run: |
-         coverage report -m
+         coverage report -m --omit="jetstream/core/proto/*,jetstream/engine/tokenizer_pb2.py,jetstream/third_party/*" --fail-under=96
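For reference, the new flags map onto coverage.py's Python API roughly as follows — a minimal sketch, assuming the suite is driven in-process (the workflow itself stays on the CLI):

```python
import sys

import coverage

# Same exclusions as the workflow flags: generated protobuf modules and
# vendored third-party code should not count against the threshold.
cov = coverage.Coverage(
    omit=[
        "jetstream/core/proto/*",
        "jetstream/engine/tokenizer_pb2.py",
        "jetstream/third_party/*",
    ]
)
cov.start()
# ... run the unittest suite here ...
cov.stop()

total = cov.report()  # prints the table and returns the total percentage
if total < 96:  # mirrors --fail-under=96
    sys.exit(2)
```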
11 changes: 6 additions & 5 deletions README.md
@@ -1,4 +1,4 @@
-[![Unit Tests](https://github.com/google/JetStream/actions/workflows/unit_tests.yaml/badge.svg)](https://github.com/google/JetStream/actions/workflows/unit_tests.yaml)
+[![Unit Tests](https://github.com/google/JetStream/actions/workflows/unit_tests.yaml/badge.svg?branch=main)](https://github.com/google/JetStream/actions/workflows/unit_tests.yaml?query=branch:main)
[![PyPI version](https://badge.fury.io/py/google-jetstream.svg)](https://badge.fury.io/py/google-jetstream)
[![PyPi downloads](https://img.shields.io/pypi/dm/google-jetstream?style=flat-square&logo=pypi&logoColor=white)](https://pypi.org/project/google-jetstream/)
[![Contributions welcome](https://img.shields.io/badge/contributions-welcome-brightgreen.svg)](CONTRIBUTING.md)
@@ -57,15 +57,16 @@ python -m jetstream.tools.load_tester
### Test core modules
```
# Test JetStream core orchestrator
-python -m jetstream.tests.core.test_orchestrator
+python -m unittest -v jetstream.tests.core.test_orchestrator

# Test JetStream core server library
-python -m jetstream.tests.core.test_server
+python -m unittest -v jetstream.tests.core.test_server

# Test mock JetStream engine implementation
-python -m jetstream.tests.engine.test_mock_engine
+python -m unittest -v jetstream.tests.engine.test_mock_engine

# Test mock JetStream token utils
-python -m jetstream.tests.engine.test_utils
+python -m unittest -v jetstream.tests.engine.test_token_utils
+python -m unittest -v jetstream.tests.engine.test_utils

```
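The move to `python -m unittest -v` goes hand in hand with this PR deleting the per-file `if __name__ == "__main__":` blocks (visible in the test diffs below), which otherwise show up as uncovered lines. To drive the same suites from Python rather than the shell, a minimal sketch using only stock `unittest`:

```python
import unittest

# Module path taken from the commands above; swap in any other suite.
suite = unittest.defaultTestLoader.loadTestsFromName(
    "jetstream.tests.engine.test_token_utils"
)
unittest.TextTestRunner(verbosity=2).run(suite)
```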
6 changes: 0 additions & 6 deletions jetstream/core/config_lib.py
@@ -38,12 +38,6 @@ class ServerConfig:
  generate_engine_create_fns: Tuple[CreateEngineFn, ...] = ()
  interleaved_engine_create_fns: Tuple[CreateEngineFn, ...] = ()

-  def get_slices_to_launch(self: "ServerConfig") -> str:
-    """Used when launching this config via xm config."""
-    return ",".join(
-        self.prefill_slices + self.generate_slices + self.interleaved_slices
-    )


@dataclasses.dataclass
class InstantiatedEngines:
143 changes: 72 additions & 71 deletions jetstream/core/orchestrator.py
@@ -85,7 +85,7 @@
import threading
import time
import traceback
-from typing import Any, AsyncIterator, Optional, Tuple, Union, cast
+from typing import Any, AsyncIterator, Optional, Tuple, cast

import grpc
import jax
@@ -434,13 +434,6 @@ def place_request_on_prefill_queue(self, request: ActiveRequest):
    self._prefill_backlog.put(request, block=False)
    self._prefill_backlog_size_metric.set(self._prefill_backlog.qsize())

-  def _load_cache_history(self, path: str) -> Union[None, Any]:
-    """Loads previous kv cache for a longer conversation."""
-    if path:
-      raise NotImplementedError
-    else:
-      return None

  def _process_prefill_content(
      self,
      request: ActiveRequest,
@@ -744,6 +737,60 @@ def _get_prefill_content(
        True,
    )

+  def process_client_side_tokenization_response(self, response: Any):
+    samples = []
+    for sample in response:
+      samples.append(
+          jetstream_pb2.DecodeResponse.StreamContent.Sample(
+              token_ids=sample.token_ids,
+          )
+      )
+    return jetstream_pb2.DecodeResponse(
+        stream_content=jetstream_pb2.DecodeResponse.StreamContent(
+            samples=samples
+        )
+    )
+
+  def should_buffer_response(self, response: Any) -> bool:
+    for item in response:
+      if item.text and token_utils.is_byte_token(item.text[-1]):
+        # If any sample ends in bytes, this means we might still need to
+        # decode more bytes to compose the string.
+        return True
+
+  def process_server_side_tokenization_response(
+      self, response: Any, buffered_response_list
+  ):
+    # Flush the buffered responses to each sample of current response.
+    current_response_with_flushed_buffer = list(
+        zip(*buffered_response_list, response)
+    )
+    # Empty buffer: [[s0_cur], [s1_cur], ...]
+    # Has buffer:
+    # [[s0_b0, s0_b1, ..., s0_cur], [s1_b0, s1_b1, ..., s1_cur], ...]
+    current_response_with_flushed_buffer = cast(
+        list[list[ReturnSample]], current_response_with_flushed_buffer
+    )
+    # Form correct sample(s) and return as StreamContent for this iteration.
+    samples = []
+    for sample in current_response_with_flushed_buffer:
+      text = []
+      token_ids = []
+      for resp in sample:
+        text.extend(resp.text)
+        token_ids.extend(resp.token_ids)
+      samples.append(
+          jetstream_pb2.DecodeResponse.StreamContent.Sample(
+              text=token_utils.text_tokens_to_str(text),
+              token_ids=token_ids,
+          )
+      )
+    return jetstream_pb2.DecodeResponse(
+        stream_content=jetstream_pb2.DecodeResponse.StreamContent(
+            samples=samples
+        )
+    )

  async def Decode(  # pylint: disable=invalid-overridden-method
      self,
      request: jetstream_pb2.DecodeRequest,
@@ -795,70 +842,24 @@ async def Decode(  # pylint: disable=invalid-overridden-method
    # The DecodeResponse stream should consume all generated tokens in
    # return_channel when complete signal is received (AsyncMultifuture
    # promises this).
-    if is_client_side_tokenization:
-      # If is_client_side_tokenization, the client should request with token
-      # ids, and the JetStream server will return token ids as response.
-      # The client should take care of tokenization and detokenization.
-      async for response in active_request.return_channel:
-        response = cast(list[ReturnSample], response)
-        samples = []
-        for sample in response:
-          samples.append(
-              jetstream_pb2.DecodeResponse.StreamContent.Sample(
-                  token_ids=sample.token_ids,
-              )
-          )
-        yield jetstream_pb2.DecodeResponse(
-            stream_content=jetstream_pb2.DecodeResponse.StreamContent(
-                samples=samples
-            )
-        )
-    else:
-      # Buffer response mechanism is used to handle streaming
-      # detokenization with special character (For some edge cases with
-      # SentencePiece tokenizer, it requires to decode a complete sequence
-      # instead of a single token).
-      buffered_response_list = []
-      async for response in active_request.return_channel:
-        response = cast(list[ReturnSample], response)
-        buffered = False
-        for item in response:
-          if item.text and token_utils.is_byte_token(item.text[-1]):
-            # If any sample ends in bytes, this means we might still need to
-            # decode more bytes to compose the string.
-            buffered_response_list.append(response)
-            buffered = True
-            break
-        if buffered:
-          continue
+    buffered_response_list = []
+    async for response in active_request.return_channel:
+      response = cast(list[ReturnSample], response)
+      if is_client_side_tokenization:
+        # If is_client_side_tokenization, the client should request with token
+        # ids, and the JetStream server will return token ids as response.
+        # The client should take care of tokenization and detokenization.
+        yield self.process_client_side_tokenization_response(response)
+      else:
+        # Buffer response mechanism is used to handle streaming
+        # detokenization with special character (For some edge cases with
+        # SentencePiece tokenizer, it requires to decode a complete sequence
+        # instead of a single token).
+        if self.should_buffer_response(response):
+          buffered_response_list.append(response)
+          continue
-        # Flush the buffered responses to each sample of current response.
-        current_response_with_flushed_buffer = list(
-            zip(*buffered_response_list, response)
-        )
-        # Empty buffer: [[s0_cur], [s1_cur], ...]
-        # Has buffer:
-        # [[s0_b0, s0_b1, ..., s0_cur], [s1_b0, s1_b1, ..., s1_cur], ...]
-        current_response_with_flushed_buffer = cast(
-            list[list[ReturnSample]], current_response_with_flushed_buffer
-        )
+        yield self.process_server_side_tokenization_response(
+            response, buffered_response_list
+        )
        # Reset buffer after flushed.
        buffered_response_list = []
-        # Form correct sample(s) and return as StreamContent for this iteration.
-        samples = []
-        for sample in current_response_with_flushed_buffer:
-          text = []
-          token_ids = []
-          for resp in sample:
-            text.extend(resp.text)
-            token_ids.extend(resp.token_ids)
-          samples.append(
-              jetstream_pb2.DecodeResponse.StreamContent.Sample(
-                  text=token_utils.text_tokens_to_str(text),
-                  token_ids=token_ids,
-              )
-          )
-        yield jetstream_pb2.DecodeResponse(
-            stream_content=jetstream_pb2.DecodeResponse.StreamContent(
-                samples=samples
-            )
-        )
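Stripped of the gRPC plumbing, the buffering contract that `should_buffer_response` and `process_server_side_tokenization_response` implement reduces to the sketch below. The `ReturnSample` stub and `is_byte_token` check are simplified stand-ins for the real JetStream types, kept only so the sketch runs on its own:

```python
import dataclasses
from typing import List


@dataclasses.dataclass
class ReturnSample:
  """Stub of jetstream.core.utils.return_sample.ReturnSample."""
  text: List[str]
  token_ids: List[int]


def is_byte_token(piece: str) -> bool:
  # Assumed shape of SentencePiece byte-fallback pieces, e.g. "<0xAB>".
  return piece.startswith("<0x") and piece.endswith(">")


def should_buffer(response: List[ReturnSample]) -> bool:
  # If any sample ends in a raw byte piece, hold the whole response back
  # until enough bytes arrive to decode a complete string.
  return any(s.text and is_byte_token(s.text[-1]) for s in response)


def flush(buffered, response):
  # zip(*buffered, response) prepends each sample's buffered pieces to
  # its current pieces, exactly like the method in the diff above.
  merged = []
  for sample in zip(*buffered, response):
    text, token_ids = [], []
    for part in sample:
      text.extend(part.text)
      token_ids.extend(part.token_ids)
    merged.append(ReturnSample(text=text, token_ids=token_ids))
  return merged


# Two decode steps; the first sample is mid-byte-sequence in step 1.
step1 = [ReturnSample(["<0xE4>"], [100]), ReturnSample(["hi"], [7])]
step2 = [ReturnSample(["<0xBD>", "!"], [101, 5]), ReturnSample(["!"], [5])]

buffered = []
for response in (step1, step2):
  if should_buffer(response):
    buffered.append(response)
    continue
  print(flush(buffered, response))  # step 1's pieces flush with step 2's
  buffered = []
```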
18 changes: 7 additions & 11 deletions jetstream/tests/core/test_config_lib.py
@@ -14,22 +14,18 @@

"""Unit test for config_lib.py."""

-from absl.testing import absltest, parameterized
+import unittest
+from parameterized import parameterized
from jetstream.core import config_lib


-class TestConfigLib(parameterized.TestCase):
+class TestConfigLib(unittest.TestCase):

-  @parameterized.parameters(
-      ("tpu=8", 8),
-      ("v5e-8", 8),
-      ("v5e=4", 4),
-      ("v4-8", 4),
-  )
+  @parameterized.expand([("tpu=8", 8), ("v5e-8", 8), ("v5e=4", 4), ("v4-8", 4)])
  def test_slice_to_num_chips(self, accelerator_slice, expected_num_devices):
    got = config_lib.slice_to_num_chips(accelerator_slice)
    self.assertEqual(got, expected_num_devices)


-
-if __name__ == "__main__":
-  absltest.main()
+  def test_get_engines_invalid(self):
+    with self.assertRaises(ValueError):
+      config_lib.get_engines(config_lib.InterleavedCPUTestServer, [])
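For readers who haven't used the `parameterized` package: unlike absl's `parameterized.TestCase`, it decorates plain `unittest.TestCase` methods, which is what lets these tests run under `python -m unittest`. A toy, self-contained example (not JetStream code):

```python
import unittest

from parameterized import parameterized


def double(x):
  return 2 * x


class TestDouble(unittest.TestCase):

  # expand() generates one named test method per tuple at import time,
  # so a failure reports exactly which parameters broke.
  @parameterized.expand([(1, 2), (3, 6), (5, 10)])
  def test_double(self, value, expected):
    self.assertEqual(double(value), expected)
```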
14 changes: 11 additions & 3 deletions jetstream/tests/core/test_orchestrator.py
@@ -44,6 +44,7 @@
import unittest
from jetstream.core import orchestrator
from jetstream.core.proto import jetstream_pb2
+from jetstream.core.utils.return_sample import ReturnSample
from jetstream.engine import mock_engine


@@ -131,6 +132,13 @@ async def test_orchestrator_interleaved_mode_client_tokenization(self):
    driver.stop()
    print("Orchestrator driver stopped.")

-
-if __name__ == "__main__":
-  unittest.main()
+  def test_should_buffer_response(self):
+    driver = self._setup_driver_interleaved_mode()
+    client = orchestrator.LLMOrchestrator(driver=driver)
+    self.assertTrue(
+        client.should_buffer_response(
+            [ReturnSample(text=["<0xAB>"], token_ids=[13])]
+        )
+    )
+    driver.stop()
+    print("Orchestrator driver stopped.")
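The `<0xAB>` literal here is SentencePiece's byte-fallback piece format. As a rough illustration of the kind of check `token_utils.is_byte_token` performs (an assumption based on this test, not the actual implementation):

```python
import re

# Assumed byte-fallback shape, e.g. "<0xAB>"; the real predicate in
# token_utils may differ.
_BYTE_PIECE = re.compile(r"^<0x[0-9A-Fa-f]{2}>$")


def looks_like_byte_token(piece: str) -> bool:
  return bool(_BYTE_PIECE.match(piece))


assert looks_like_byte_token("<0xAB>")
assert not looks_like_byte_token("hello")
```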
5 changes: 2 additions & 3 deletions jetstream/tests/core/test_server.py
@@ -96,6 +96,5 @@ async def test_server(
      counter += 1
    server.stop()

-
-if __name__ == "__main__":
-  unittest.main()
+  def test_get_devices(self):
+    assert len(server_lib.get_devices()) == 1
8 changes: 2 additions & 6 deletions jetstream/tests/engine/test_mock_engine.py
@@ -30,15 +30,15 @@
I.e. ['Ċ', 'Ə', 'ɖ'] when converted back with chr()
"""

+import unittest
import jax.numpy as jnp
import numpy as np

from jetstream.engine import mock_engine
from jetstream.engine import token_utils
-from absl.testing import absltest


-class EngineTest(absltest.TestCase):
+class EngineTest(unittest.TestCase):

  def _setup(self):
    """Initialises a test engine."""
@@ -128,7 +128,3 @@ def test_generate(self, slot=1):
    token_data = sampled_tokens.get_result_at_slot(slot)
    tok = token_data.tokens
    assert tokenizer.IdToPiece(int(tok.item())) == "ɖ"
-
-
-if __name__ == "__main__":
-  absltest.main()
65 changes: 61 additions & 4 deletions jetstream/tests/engine/test_token_utils.py
@@ -136,6 +136,67 @@ def test_tokenize_and_pad_np(self):
    )
    self.assertEqual(true_length, expected_true_length)

+  def test_tokenize_and_pad(self):
+    jax.config.update("jax_platform_name", "cpu")
+    self.setup_sentencepiece()
+    s = "I believe the meaning of life is"
+    vocab = self.jt_tokenizer.vocab
+    max_prefill_length = 1024
+    padded_tokens, true_length = token_utils.tokenize_and_pad(
+        s,
+        vocab,
+        max_prefill_length=max_prefill_length,
+    )
+    expected_padded_tokens = jnp.array(
+        [1, 306, 4658, 278, 6593, 310, 2834, 338, 0, 0, 0, 0, 0, 0, 0, 0]
+    )
+    expected_true_length = 8
+    self.assertTrue(
+        jnp.allclose(padded_tokens, expected_padded_tokens, atol=1e-7)
+    )
+    self.assertEqual(true_length, expected_true_length)
+
+  def test_pad_token_padding_less_than_zero(self):
+    jax.config.update("jax_platform_name", "cpu")
+    self.setup_sentencepiece()
+    s = "I believe the meaning of life is having different experiences and "
+    s += "enjoy everyday of my life."
+    vocab = self.jt_tokenizer.vocab
+    max_prefill_length = 16
+    tokens = vocab.encode_tf(s)
+    padded_tokens, true_length = token_utils.pad_tokens(
+        tokens,
+        bos_id=vocab.bos_id,
+        pad_id=vocab.pad_id,
+        max_prefill_length=max_prefill_length,
+    )
+    # Take the last N tokens if we have too many.
+    expected_padded_tokens = jnp.array(
+        [
+            278,
+            6593,
+            310,
+            2834,
+            338,
+            2534,
+            1422,
+            27482,
+            322,
+            13389,
+            1432,
+            3250,
+            310,
+            590,
+            2834,
+            29889,
+        ]
+    )
+    expected_true_length = 19
+    self.assertTrue(
+        jnp.allclose(padded_tokens, expected_padded_tokens, atol=1e-7)
+    )
+    self.assertEqual(true_length, expected_true_length)

  def test_sentencepiece_tokenizer_encode(self):
    self.setup_sentencepiece()
    s = "I believe the meaning of life is"
@@ -559,7 +620,3 @@ def test_text_tokens_to_str(self):
        )
        == "你好�\n�hello"
    )


-if __name__ == "__main__":
-  unittest.main()
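The expected values in `test_pad_token_padding_less_than_zero` imply left-truncation with the pre-truncation length reported back: 19 token ids shrink to the last 16, yet `true_length` stays 19. A sketch of that overflow case as we read it from the test (the real `pad_tokens` also pads short prompts up to a bucketed length, which this sketch ignores):

```python
from typing import List, Tuple


def trim_for_prefill(
    ids: List[int], max_prefill_length: int
) -> Tuple[List[int], int]:
  # Keep the most recent context: drop ids from the front, never the back,
  # and report the original count so callers know truncation happened.
  true_length = len(ids)
  if true_length > max_prefill_length:
    ids = ids[-max_prefill_length:]
  return ids, true_length


trimmed, true_length = trim_for_prefill(list(range(19)), 16)
assert trimmed == list(range(3, 19)) and true_length == 19
```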