From eae2478d7486834b780b65144e4b6b6ba86428dc Mon Sep 17 00:00:00 2001
From: Sudeep Pillai
Date: Fri, 5 Jan 2024 16:57:17 -0800
Subject: [PATCH] [tutorial] Added multi-model tutorial and refactored
 `Llama2Chat` into a generic `LLM` model class

---
 .../04-multiple-models/summarize_audio.py  | 60 +++++++++++++++++++
 nos/models/__init__.py                     |  2 +-
 nos/models/{llama2_chat.py => llm.py}      | 40 +++++++------
 nos/models/whisper.py                      |  4 ++
 .../{test_llama2_chat.py => test_llm.py}   | 10 ++--
 5 files changed, 92 insertions(+), 24 deletions(-)
 create mode 100644 examples/tutorials/04-multiple-models/summarize_audio.py
 rename nos/models/{llama2_chat.py => llm.py} (87%)
 rename tests/models/{test_llama2_chat.py => test_llm.py} (60%)

diff --git a/examples/tutorials/04-multiple-models/summarize_audio.py b/examples/tutorials/04-multiple-models/summarize_audio.py
new file mode 100644
index 00000000..2c658324
--- /dev/null
+++ b/examples/tutorials/04-multiple-models/summarize_audio.py
@@ -0,0 +1,60 @@
+import sys
+from pathlib import Path
+
+import rich.console
+
+from nos.client import Client
+
+
+if __name__ == "__main__":
+    import argparse
+
+    parser = argparse.ArgumentParser(description="Summarize audio file.")
+    parser.add_argument("--filename", type=str, help="Audio file to summarize.")
+    args = parser.parse_args()
+
+    console = rich.console.Console()
+    path = Path(args.filename)
+
+    # Create a client
+    address = "[::]:50051"
+    print(f"Connecting to client at {address} ...")
+    client = Client(address)
+    client.WaitForServer()
+
+    # Transcribe with Whisper
+    model_id = "distil-whisper/distil-small.en"
+    model = client.Module(model_id)
+    console.print()
+    console.print(f"[bold white]Transcribe with [yellow]{model_id}[/yellow].[/bold white]")
+
+    # Transcribe the audio file and print the text
+    transcription_text = ""
+    print(f"Transcribing audio file: {path}")
+    with client.UploadFile(path) as remote_path:
+        response = model.transcribe(path=remote_path, batch_size=8)
+        for item in response["chunks"]:
+            transcription_text += item["text"]
+            sys.stdout.write(item["text"])
+            sys.stdout.flush()
+    print()
+
+    # Summarize the transcription with LLMs
+    model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
+    llm = client.Module(model_id)
+    console.print()
+    console.print(f"[bold white]Summarize with [yellow]{model_id}[/yellow].[/bold white]")
+
+    prompt = f"""
+    You are a helpful transcription assistant.
+    Summarize the following text concisely with key points.
+    Keep the sentences short, and highlight key concepts in each bullet starting with a hyphen.
+
+    {transcription_text}
+    """
+    messages = [
+        {"role": "user", "content": prompt},
+    ]
+    for response in llm.chat(messages=messages, max_new_tokens=1024, _stream=True):
+        sys.stdout.write(response)
+        sys.stdout.flush()
diff --git a/nos/models/__init__.py b/nos/models/__init__.py
index b8ae530d..fd4a828a 100644
--- a/nos/models/__init__.py
+++ b/nos/models/__init__.py
@@ -12,7 +12,7 @@
 from .blip import BLIP  # noqa: F401
 from .clip import CLIP  # noqa: F401
 from .faster_rcnn import FasterRCNN  # noqa: F401
-from .llama2_chat import Llama2Chat  # noqa: F401
+from .llm import LLM  # noqa: F401
 from .monodepth import MonoDepth  # noqa: F401
 from .owlvit import OwlViT  # noqa: F401
 from .sam import SAM
diff --git a/nos/models/llama2_chat.py b/nos/models/llm.py
similarity index 87%
rename from nos/models/llama2_chat.py
rename to nos/models/llm.py
index 058f6218..1ed8d9df 100644
--- a/nos/models/llama2_chat.py
+++ b/nos/models/llm.py
@@ -19,7 +19,7 @@


 @dataclass(frozen=True)
-class Llama2ChatConfig(HuggingFaceHubConfig):
+class LLMConfig(HuggingFaceHubConfig):
     """Llama2 chat model configuration."""

     max_new_tokens: int = 2048
@@ -41,52 +41,56 @@ class Llama2ChatConfig(HuggingFaceHubConfig):
     """Chat template to use for the model."""


-class Llama2Chat:
+class LLM:
     configs = {
-        "meta-llama/Llama-2-7b-chat-hf": Llama2ChatConfig(
+        "meta-llama/Llama-2-7b-chat-hf": LLMConfig(
             model_name="meta-llama/Llama-2-7b-chat-hf",
             compute_dtype="float16",
             needs_auth=True,
             chat_template=LLAMA2_CHAT_TEMPLATE,
         ),
-        "meta-llama/Llama-2-13b-chat-hf": Llama2ChatConfig(
+        "meta-llama/Llama-2-13b-chat-hf": LLMConfig(
             model_name="meta-llama/Llama-2-13b-chat-hf",
             compute_dtype="float16",
             needs_auth=True,
             chat_template=LLAMA2_CHAT_TEMPLATE,
         ),
-        "meta-llama/Llama-2-70b-chat-hf": Llama2ChatConfig(
+        "meta-llama/Llama-2-70b-chat-hf": LLMConfig(
             model_name="meta-llama/Llama-2-70b-chat-hf",
             compute_dtype="float16",
             needs_auth=True,
             chat_template=LLAMA2_CHAT_TEMPLATE,
         ),
-        "HuggingFaceH4/zephyr-7b-beta": Llama2ChatConfig(
+        "HuggingFaceH4/zephyr-7b-beta": LLMConfig(
             model_name="HuggingFaceH4/zephyr-7b-beta",
             compute_dtype="float16",
         ),
-        "HuggingFaceH4/tiny-random-LlamaForCausalLM": Llama2ChatConfig(
+        "HuggingFaceH4/tiny-random-LlamaForCausalLM": LLMConfig(
             model_name="HuggingFaceH4/tiny-random-LlamaForCausalLM",
             compute_dtype="float16",
         ),
-        "NousResearch/Yarn-Mistral-7b-128k": Llama2ChatConfig(
+        "NousResearch/Yarn-Mistral-7b-128k": LLMConfig(
             model_name="NousResearch/Yarn-Mistral-7b-128k",
             compute_dtype="float16",
             additional_kwargs={"use_flashattention_2": True, "trust_remote_code": True},
         ),
-        "mistralai/Mistral-7B-Instruct-v0.2": Llama2ChatConfig(
+        "mistralai/Mistral-7B-Instruct-v0.2": LLMConfig(
             model_name="mistralai/Mistral-7B-Instruct-v0.2",
             compute_dtype="float16",
         ),
+        "TinyLlama/TinyLlama-1.1B-Chat-v1.0": LLMConfig(
+            model_name="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
+            compute_dtype="float16",
+        ),
     }

     def __init__(self, model_name: str = "HuggingFaceH4/zephyr-7b-beta"):
         from nos.logging import logger

         try:
-            self.cfg = Llama2Chat.configs[model_name]
+            self.cfg = LLM.configs[model_name]
         except KeyError:
-            raise ValueError(f"Invalid model_name: {model_name}, available models: {Llama2ChatConfig.configs.keys()}")
+            raise ValueError(f"Invalid model_name: {model_name}, available models: {LLM.configs.keys()}")

         self.device_str = "cuda" if torch.cuda.is_available() else "cpu"
         self.device = torch.device(self.device_str)
@@ -150,19 +154,19 @@ def chat(
     start_t = time.perf_counter()
     if idx > 0:
         self.logger.debug(
-            f"""tok/s={idx / (time.perf_counter() - start_t):.2f}, """
-            f"""memory={torch.cuda.memory_allocated(device=self.model.device) / 1024 ** 2:.2f} MB, """
-            f"""allocated={torch.cuda.max_memory_allocated(device=self.model.device) / 1024 ** 2:.2f} MB, """
-            f"""peak={torch.cuda.max_memory_reserved(device=self.model.device) / 1024 ** 2:.2f} MB, """
+            f"""tok/s={idx / (time.perf_counter() - start_t):.1f}, """
+            f"""memory={torch.cuda.memory_allocated(device=self.model.device) / 1024 ** 2:.1f} MB, """
+            f"""allocated={torch.cuda.max_memory_allocated(device=self.model.device) / 1024 ** 2:.1f} MB, """
+            f"""peak={torch.cuda.max_memory_reserved(device=self.model.device) / 1024 ** 2:.1f} MB, """
         )


-for model_name in Llama2Chat.configs:
-    cfg = Llama2Chat.configs[model_name]
+for model_name in LLM.configs:
+    cfg = LLM.configs[model_name]
     hub.register(
         model_name,
         TaskType.TEXT_GENERATION,
-        Llama2Chat,
+        LLM,
         init_args=(model_name,),
         method="chat",
     )
diff --git a/nos/models/whisper.py b/nos/models/whisper.py
index e1904325..ebe29049 100644
--- a/nos/models/whisper.py
+++ b/nos/models/whisper.py
@@ -39,6 +39,10 @@ class Whisper:
         "openai/whisper-large-v3": WhisperConfig(
             model_name="openai/whisper-large-v3",
         ),
+        "distil-whisper/distil-small.en": WhisperConfig(
+            model_name="distil-whisper/distil-small.en",
+            torch_dtype="float16",
+        ),
         "distil-whisper/distil-medium.en": WhisperConfig(
             model_name="distil-whisper/distil-medium.en",
             torch_dtype="float16",
diff --git a/tests/models/test_llama2_chat.py b/tests/models/test_llm.py
similarity index 60%
rename from tests/models/test_llama2_chat.py
rename to tests/models/test_llm.py
index c9caafd1..06b40eeb 100644
--- a/tests/models/test_llama2_chat.py
+++ b/tests/models/test_llm.py
@@ -1,21 +1,21 @@
 """Llama2 Chat model tests and benchmarks."""
 import pytest

-from nos.models import Llama2Chat
+from nos.models import LLM
 from nos.test.utils import skip_if_no_torch_cuda


-SYSTEM_PROMPT = "You are NOS chat, a Llama 2 large language model (LLM) agent hosted by Autonomi AI."
+SYSTEM_PROMPT = "You are NOS chat, a large language model (LLM) agent hosted by Autonomi AI."


 @pytest.fixture(scope="module")
 def model():
-    MODEL_NAME = "HuggingFaceH4/tiny-random-LlamaForCausalLM"
-    yield Llama2Chat(model_name=MODEL_NAME)
+    MODEL_NAME = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
+    yield LLM(model_name=MODEL_NAME)


 @skip_if_no_torch_cuda
-def test_llama2_chat(model):
+def test_streaming_chat(model):
     messages = [
         {"role": "system", "content": SYSTEM_PROMPT},
         {"role": "user", "content": "What is the meaning of life?"},