[tutorial] Added multi-model tutorial and refactored LLM model class (#511)


## Summary

Adds a multi-model tutorial (`examples/tutorials/04-multiple-models/summarize_audio.py`) that transcribes an audio file with Whisper and then summarizes the transcript with a chat LLM, and generalizes the `Llama2Chat` model class into a reusable `LLM` class with a new `TinyLlama/TinyLlama-1.1B-Chat-v1.0` configuration. Also registers `distil-whisper/distil-small.en` in the Whisper configs and renames the corresponding tests.

## Related issues


## Checks

- [ ] `make lint`: I've run `make lint` to lint the changes in this PR.
- [ ] `make test`: I've made sure the tests (`make test-cpu` or `make test`) are passing.
- Additional tests:
   - [ ] Benchmark tests (when contributing new models)
   - [ ] GPU/HW tests
spillai authored Jan 6, 2024
1 parent a75a13c commit 4eb3a3d
Showing 5 changed files with 92 additions and 24 deletions.
60 changes: 60 additions & 0 deletions examples/tutorials/04-multiple-models/summarize_audio.py
@@ -0,0 +1,60 @@
import sys
from pathlib import Path

import rich.console

from nos.client import Client


if __name__ == "__main__":
import argparse

parser = argparse.ArgumentParser(description="Summarize audio file.")
parser.add_argument("--filename", type=str, help="Audio file to summarize.")
args = parser.parse_args()

console = rich.console.Console()
path = Path(args.filename)

# Create a client
address = "[::]:50051"
print(f"Connecting to client at {address} ...")
client = Client(address)
client.WaitForServer()

# Transcribe with Whisper
model_id = "distil-whisper/distil-small.en"
model = client.Module(model_id)
console.print()
console.print(f"[bold white]Transcribe with [yellow]{model_id}[/yellow].[/bold white]")

# Transcribe the audio file and print the text
transcription_text = ""
print(f"Transcribing audio file: {path}")
with client.UploadFile(path) as remote_path:
response = model.transcribe(path=remote_path, batch_size=8)
for item in response["chunks"]:
transcription_text += item["text"]
sys.stdout.write(item["text"])
sys.stdout.flush()
print()

# Summarize the transcription with LLMs
model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
llm = client.Module(model_id)
console.print()
console.print("[bold white]Summarize with [yellow]TinyLlama/TinyLlama-1.1B-Chat-v1.0[/yellow].[/bold white]")

prompt = f"""
You are a useful transcribing assistant.
Summarize the following text concisely with key points.
Keep the sentences short, highlight key concepts in each bullet starting with a hyphen.
{transcription_text}
"""
messages = [
{"role": "user", "content": prompt},
]
for response in llm.chat(messages=messages, max_new_tokens=1024, _stream=True):
sys.stdout.write(response)
sys.stdout.flush()
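
For reference, this tutorial script is meant to be run against a locally running NOS server (the client connects to the default gRPC address `[::]:50051` above). Assuming that setup, a typical invocation would be `python examples/tutorials/04-multiple-models/summarize_audio.py --filename <path-to-audio-file>`, where the audio path is a placeholder rather than part of the commit.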
2 changes: 1 addition & 1 deletion nos/models/__init__.py
@@ -12,7 +12,7 @@
 from .blip import BLIP  # noqa: F401
 from .clip import CLIP  # noqa: F401
 from .faster_rcnn import FasterRCNN  # noqa: F401
-from .llama2_chat import Llama2Chat  # noqa: F401
+from .llm import LLM  # noqa: F401
 from .monodepth import MonoDepth  # noqa: F401
 from .owlvit import OwlViT  # noqa: F401
 from .sam import SAM
40 changes: 22 additions & 18 deletions nos/models/llama2_chat.py → nos/models/llm.py
@@ -19,7 +19,7 @@
 
 
 @dataclass(frozen=True)
-class Llama2ChatConfig(HuggingFaceHubConfig):
+class LLMConfig(HuggingFaceHubConfig):
     """Llama2 chat model configuration."""
 
     max_new_tokens: int = 2048
@@ -41,52 +41,56 @@ class Llama2ChatConfig(HuggingFaceHubConfig):
     """Chat template to use for the model."""
 
 
-class Llama2Chat:
+class LLM:
     configs = {
-        "meta-llama/Llama-2-7b-chat-hf": Llama2ChatConfig(
+        "meta-llama/Llama-2-7b-chat-hf": LLMConfig(
             model_name="meta-llama/Llama-2-7b-chat-hf",
             compute_dtype="float16",
             needs_auth=True,
             chat_template=LLAMA2_CHAT_TEMPLATE,
         ),
-        "meta-llama/Llama-2-13b-chat-hf": Llama2ChatConfig(
+        "meta-llama/Llama-2-13b-chat-hf": LLMConfig(
             model_name="meta-llama/Llama-2-13b-chat-hf",
             compute_dtype="float16",
             needs_auth=True,
             chat_template=LLAMA2_CHAT_TEMPLATE,
         ),
-        "meta-llama/Llama-2-70b-chat-hf": Llama2ChatConfig(
+        "meta-llama/Llama-2-70b-chat-hf": LLMConfig(
            model_name="meta-llama/Llama-2-70b-chat-hf",
            compute_dtype="float16",
            needs_auth=True,
            chat_template=LLAMA2_CHAT_TEMPLATE,
        ),
-        "HuggingFaceH4/zephyr-7b-beta": Llama2ChatConfig(
+        "HuggingFaceH4/zephyr-7b-beta": LLMConfig(
             model_name="HuggingFaceH4/zephyr-7b-beta",
             compute_dtype="float16",
         ),
-        "HuggingFaceH4/tiny-random-LlamaForCausalLM": Llama2ChatConfig(
+        "HuggingFaceH4/tiny-random-LlamaForCausalLM": LLMConfig(
             model_name="HuggingFaceH4/tiny-random-LlamaForCausalLM",
             compute_dtype="float16",
         ),
-        "NousResearch/Yarn-Mistral-7b-128k": Llama2ChatConfig(
+        "NousResearch/Yarn-Mistral-7b-128k": LLMConfig(
             model_name="NousResearch/Yarn-Mistral-7b-128k",
             compute_dtype="float16",
             additional_kwargs={"use_flashattention_2": True, "trust_remote_code": True},
         ),
-        "mistralai/Mistral-7B-Instruct-v0.2": Llama2ChatConfig(
+        "mistralai/Mistral-7B-Instruct-v0.2": LLMConfig(
             model_name="mistralai/Mistral-7B-Instruct-v0.2",
             compute_dtype="float16",
         ),
+        "TinyLlama/TinyLlama-1.1B-Chat-v1.0": LLMConfig(
+            model_name="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
+            compute_dtype="float16",
+        ),
     }
 
     def __init__(self, model_name: str = "HuggingFaceH4/zephyr-7b-beta"):
         from nos.logging import logger
 
         try:
-            self.cfg = Llama2Chat.configs[model_name]
+            self.cfg = LLM.configs[model_name]
         except KeyError:
-            raise ValueError(f"Invalid model_name: {model_name}, available models: {Llama2ChatConfig.configs.keys()}")
+            raise ValueError(f"Invalid model_name: {model_name}, available models: {LLMConfig.configs.keys()}")
 
         self.device_str = "cuda" if torch.cuda.is_available() else "cpu"
         self.device = torch.device(self.device_str)
@@ -150,19 +154,19 @@ def chat(
             start_t = time.perf_counter()
             if idx > 0:
                 self.logger.debug(
-                    f"""tok/s={idx / (time.perf_counter() - start_t):.2f}, """
-                    f"""memory={torch.cuda.memory_allocated(device=self.model.device) / 1024 ** 2:.2f} MB, """
-                    f"""allocated={torch.cuda.max_memory_allocated(device=self.model.device) / 1024 ** 2:.2f} MB, """
-                    f"""peak={torch.cuda.max_memory_reserved(device=self.model.device) / 1024 ** 2:.2f} MB, """
+                    f"""tok/s={idx / (time.perf_counter() - start_t):.1f}, """
+                    f"""memory={torch.cuda.memory_allocated(device=self.model.device) / 1024 ** 2:.1f} MB, """
+                    f"""allocated={torch.cuda.max_memory_allocated(device=self.model.device) / 1024 ** 2:.1f} MB, """
+                    f"""peak={torch.cuda.max_memory_reserved(device=self.model.device) / 1024 ** 2:.1f} MB, """
                 )
 
 
-for model_name in Llama2Chat.configs:
-    cfg = Llama2Chat.configs[model_name]
+for model_name in LLM.configs:
+    cfg = LLM.configs[model_name]
     hub.register(
         model_name,
         TaskType.TEXT_GENERATION,
-        Llama2Chat,
+        LLM,
         init_args=(model_name,),
         method="chat",
     )
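
Beyond the gRPC client path used in the tutorial, the renamed class can in principle be used in-process. The sketch below is illustrative only: it assumes `LLM.chat()` yields text chunks (as the streaming loop in the tutorial suggests) and uses a model id and message format taken from the configs and tests in this commit; it is not part of the change itself.

```python
# Hypothetical in-process usage of the renamed LLM class (not part of this commit).
from nos.models import LLM

# TinyLlama config added above; chat() is assumed here to yield text chunks.
llm = LLM(model_name="TinyLlama/TinyLlama-1.1B-Chat-v1.0")
messages = [
    {"role": "user", "content": "Summarize this change in two short bullets."},
]
for chunk in llm.chat(messages=messages, max_new_tokens=128):
    print(chunk, end="", flush=True)
```

The `hub.register(...)` loop at the bottom of the file exposes this same `chat` method as a `TaskType.TEXT_GENERATION` endpoint, which is how the tutorial's `client.Module("TinyLlama/TinyLlama-1.1B-Chat-v1.0")` call reaches it.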
4 changes: 4 additions & 0 deletions nos/models/whisper.py
@@ -39,6 +39,10 @@ class Whisper:
         "openai/whisper-large-v3": WhisperConfig(
             model_name="openai/whisper-large-v3",
         ),
+        "distil-whisper/distil-small.en": WhisperConfig(
+            model_name="distil-whisper/distil-small.en",
+            torch_dtype="float16",
+        ),
         "distil-whisper/distil-medium.en": WhisperConfig(
             model_name="distil-whisper/distil-medium.en",
             torch_dtype="float16",
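
This new `distil-whisper/distil-small.en` entry is what allows the tutorial above to request that model via `client.Module(...)`; the surrounding Whisper configurations are unchanged.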
10 changes: 5 additions & 5 deletions tests/models/test_llama2_chat.py → tests/models/test_llm.py
@@ -1,21 +1,21 @@
 """Llama2 Chat model tests and benchmarks."""
 import pytest
 
-from nos.models import Llama2Chat
+from nos.models import LLM
 from nos.test.utils import skip_if_no_torch_cuda
 
 
-SYSTEM_PROMPT = "You are NOS chat, a Llama 2 large language model (LLM) agent hosted by Autonomi AI."
+SYSTEM_PROMPT = "You are NOS chat, a large language model (LLM) agent hosted by Autonomi AI."
 
 
 @pytest.fixture(scope="module")
 def model():
-    MODEL_NAME = "HuggingFaceH4/tiny-random-LlamaForCausalLM"
-    yield Llama2Chat(model_name=MODEL_NAME)
+    MODEL_NAME = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
+    yield LLM(model_name=MODEL_NAME)
 
 
 @skip_if_no_torch_cuda
-def test_llama2_chat(model):
+def test_streaming_chat(model):
     messages = [
         {"role": "system", "content": SYSTEM_PROMPT},
         {"role": "user", "content": "What is the meaning of life?"},
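
Assuming the standard project layout, the renamed tests can be run with something like `pytest tests/models/test_llm.py` (shown for illustration, not part of the commit); note that the test is gated by `skip_if_no_torch_cuda`, so it only executes on a CUDA-capable machine.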
