Merge pull request #1 from eustlb/open-api-fix
Open api fix
wuhongsheng authored Sep 20, 2024
2 parents d6b0941 + 506e61e commit e127cc7
Showing 12 changed files with 372 additions and 175 deletions.
13 changes: 13 additions & 0 deletions Dockerfile.arm64
@@ -0,0 +1,13 @@
FROM nvcr.io/nvidia/l4t-pytorch:r35.2.1-pth2.0-py3

ENV PYTHONUNBUFFERED 1

WORKDIR /usr/src/app

# Install packages
RUN apt-get update && apt-get install -y git && rm -rf /var/lib/apt/lists/*

COPY requirements.txt ./
RUN pip install --no-cache-dir -r requirements.txt

COPY . .
19 changes: 16 additions & 3 deletions LLM/mlx_language_model.py
@@ -9,6 +9,14 @@

console = Console()

WHISPER_LANGUAGE_TO_LLM_LANGUAGE = {
"en": "english",
"fr": "french",
"es": "spanish",
"zh": "chinese",
"ja": "japanese",
"ko": "korean",
}

class MLXLanguageModelHandler(BaseHandler):
"""
@@ -44,7 +52,7 @@ def setup(
def warmup(self):
logger.info(f"Warming up {self.__class__.__name__}")

dummy_input_text = "Write me a poem about Machine Learning."
dummy_input_text = "Repeat the word 'home'."
dummy_chat = [{"role": self.user_role, "content": dummy_input_text}]

n_steps = 2
@@ -61,6 +69,11 @@ def warmup(self):

def process(self, prompt):
logger.debug("infering language model...")
language_code = None

if isinstance(prompt, tuple):
prompt, language_code = prompt
prompt = f"Please reply to my message in {WHISPER_LANGUAGE_TO_LLM_LANGUAGE[language_code]}. " + prompt

self.chat.append({"role": self.user_role, "content": prompt})

@@ -86,9 +99,9 @@ def process(self, prompt):
output += t
curr_output += t
if curr_output.endswith((".", "?", "!", "<|end|>")):
yield curr_output.replace("<|end|>", "")
yield (curr_output.replace("<|end|>", ""), language_code)
curr_output = ""
generated_text = output.replace("<|end|>", "")
torch.mps.empty_cache()

self.chat.append({"role": "assistant", "content": generated_text})
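Taken together with the STT and TTS changes below, this threads a `(text, language_code)` tuple through the pipeline instead of a bare string. A minimal sketch of that hand-off, with hypothetical stand-in handlers (not the repository's classes):

```python
# Minimal sketch of the (text, language_code) tuple protocol this PR threads
# through the pipeline. Handler names are hypothetical stand-ins.
WHISPER_LANGUAGE_TO_LLM_LANGUAGE = {"en": "english", "fr": "french"}

def stt(audio):
    # The STT stage now yields a tuple so downstream stages know the language.
    yield ("bonjour", "fr")

def llm(prompt):
    language_code = None
    if isinstance(prompt, tuple):  # unwrap before touching the prompt text
        prompt, language_code = prompt
        prompt = (
            f"Please reply to my message in "
            f"{WHISPER_LANGUAGE_TO_LLM_LANGUAGE[language_code]}. " + prompt
        )
    yield (f"(reply to: {prompt})", language_code)

def tts(llm_sentence):
    # Handlers that ignore the language, like the Parler handler below,
    # simply unwrap the tuple.
    if isinstance(llm_sentence, tuple):
        llm_sentence, _ = llm_sentence
    print(llm_sentence)

for stt_output in stt(None):
    for llm_output in llm(stt_output):
        tts(llm_output)
```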
61 changes: 30 additions & 31 deletions LLM/openai_api_language_model.py
@@ -54,37 +54,36 @@ def warmup(self):
logger.info(
f"{self.__class__.__name__}: warmed up! time: {(end - start):.3f} s"
)
def process(self, prompt):
logger.debug("call api language model...")
self.chat.append({"role": self.user_role, "content": prompt})

language_code = None
if isinstance(prompt, tuple):
prompt, language_code = prompt

def process(self, prompt):
logger.debug("call api language model...")
self.chat.append({"role": self.user_role, "content": prompt})
response = self.client.chat.completions.create(
model=self.model_name,
messages=[
{"role": self.user_role, "content": prompt},
],
stream=self.stream
)
if self.stream:
generated_text, printable_text = "", ""
for chunk in response:
new_text = chunk.choices[0].delta.content or ""
generated_text += new_text
printable_text += new_text
sentences = sent_tokenize(printable_text)
if len(sentences) > 1:
yield sentences[0], language_code
printable_text = new_text
self.chat.append({"role": "assistant", "content": generated_text})
# don't forget last sentence
yield printable_text, language_code
else:
generated_text = response.choices[0].message.content
self.chat.append({"role": "assistant", "content": generated_text})
yield generated_text, language_code

language_code = None
if isinstance(prompt, tuple):
prompt, language_code = prompt

response = self.client.chat.completions.create(
model=self.model_name,
messages=[
{"role": self.user_role, "content": prompt},
],
stream=self.stream
)
if self.stream:
generated_text, printable_text = "", ""
for chunk in response:
new_text = chunk.choices[0].delta.content or ""
generated_text += new_text
printable_text += new_text
sentences = sent_tokenize(printable_text)
if len(sentences) > 1:
yield sentences[0], language_code
printable_text = new_text
self.chat.append({"role": "assistant", "content": generated_text})
# don't forget last sentence
yield printable_text, language_code
else:
generated_text = response.choices[0].message.content
self.chat.append({"role": "assistant", "content": generated_text})
yield generated_text, language_code
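The streaming branch above turns a token stream into sentence-sized chunks for the TTS stage. A standalone sketch of that chunking logic (assumes NLTK is installed and `nltk.download("punkt")` has been run):

```python
# Standalone sketch of the sentence-chunking used in the streaming branch.
# Assumes `pip install nltk` plus `nltk.download("punkt")` beforehand.
from nltk.tokenize import sent_tokenize

def stream_sentences(chunks):
    printable_text = ""
    for new_text in chunks:
        printable_text += new_text
        sentences = sent_tokenize(printable_text)
        if len(sentences) > 1:
            # The first sentence is complete; restart the buffer from the
            # newest chunk, as the handler does (assumes token-sized chunks).
            yield sentences[0]
            printable_text = new_text
    yield printable_text  # don't forget the last sentence

print(list(stream_sentences(["I", " like", " tea", ".", " Do", " you", "?"])))
# -> ['I like tea.', ' Do you?']
```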
90 changes: 71 additions & 19 deletions README.md
@@ -14,7 +14,7 @@
* [Usage](#usage)
- [Docker Server approach](#docker-server)
- [Server/Client approach](#serverclient-approach)
- [Local approach](#local-approach)
- [Local approach](#local-approach-mac)
* [Command-line usage](#command-line-usage)
- [Model parameters](#model-parameters)
- [Generation parameters](#generation-parameters)
@@ -79,27 +79,28 @@ https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install

### Server/Client Approach

To run the pipeline on the server:
```bash
python s2s_pipeline.py --recv_host 0.0.0.0 --send_host 0.0.0.0
```
1. Run the pipeline on the server:
```bash
python s2s_pipeline.py --recv_host 0.0.0.0 --send_host 0.0.0.0
```

Then run the client locally to handle sending microphone input and receiving generated audio:
```bash
python listen_and_play.py --host <IP address of your server>
```
2. Run the client locally to handle microphone input and receive generated audio:
```bash
python listen_and_play.py --host <IP address of your server>
```

### Running on Mac
To run on mac, we recommend setting the flag `--local_mac_optimal_settings`:
```bash
python s2s_pipeline.py --local_mac_optimal_settings
```
### Local Approach (Mac)

1. For optimal settings on Mac:
```bash
python s2s_pipeline.py --local_mac_optimal_settings
```

You can also pass `--device mps` to have all the models set to device mps.
The local mac optimal settings set the mode to be local as explained above and change the models to:
- LightningWhisperMLX
- MLX LM
- MeloTTS
This setting:
- Adds `--device mps` to use MPS for all models.
- Sets LightningWhisperMLX for STT
- Sets MLX LM for the language model
- Sets MeloTTS for TTS

### Recommended usage with Cuda

@@ -117,6 +118,57 @@ python s2s_pipeline.py \

For the moment, modes capturing CUDA Graphs are not compatible with streaming Parler-TTS (`reduce-overhead`, `max-autotune`).
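As background, the difference is torch.compile's `mode` argument; a toy sketch with a generic function (not the pipeline's models):

```python
# Toy illustration of the constraint above: "reduce-overhead" and
# "max-autotune" let torch.compile capture CUDA Graphs, which streaming
# Parler-TTS cannot use yet, so the TTS handler falls back to "default".
import torch

def step(x):
    return torch.relu(x) + 1

step_default = torch.compile(step, mode="default")         # streaming-safe
step_graphs = torch.compile(step, mode="reduce-overhead")  # CUDA Graph capture
print(step_default(torch.randn(4)))
```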


### Multi-language Support

The pipeline supports multiple languages, allowing for automatic language detection or specific language settings. Here are examples for both local (Mac) and server setups:

#### Server Setup


For automatic language detection:

```bash
python s2s_pipeline.py \
--stt_model_name large-v3 \
--language auto \
--mlx_lm_model_name mlx-community/Meta-Llama-3.1-8B-Instruct
```

Or, for one language in particular (Chinese in this example):

```bash
python s2s_pipeline.py \
--stt_model_name large-v3 \
--language zh \
--mlx_lm_model_name mlx-community/Meta-Llama-3.1-8B-Instruct \
```

#### Local Mac Setup

For automatic language detection:

```bash
python s2s_pipeline.py \
--local_mac_optimal_settings \
--device mps \
--stt_model_name large-v3 \
--language auto \
--mlx_lm_model_name mlx-community/Meta-Llama-3.1-8B-Instruct-4bit
```

Or, for one language in particular (Chinese in this example):

```bash
python s2s_pipeline.py \
--local_mac_optimal_settings \
--device mps \
--stt_model_name large-v3 \
--language zh \
--mlx_lm_model_name mlx-community/Meta-Llama-3.1-8B-Instruct-4bit
```


## Command-line Usage

### Model Parameters
36 changes: 33 additions & 3 deletions STT/lightning_whisper_mlx_handler.py
@@ -4,12 +4,22 @@
from lightning_whisper_mlx import LightningWhisperMLX
import numpy as np
from rich.console import Console
from copy import copy
import torch

logger = logging.getLogger(__name__)

console = Console()

SUPPORTED_LANGUAGES = [
"en",
"fr",
"es",
"zh",
"ja",
"ko",
]


class LightningWhisperSTTHandler(BaseHandler):
"""
@@ -19,15 +29,19 @@ class LightningWhisperSTTHandler(BaseHandler):
def setup(
self,
model_name="distil-large-v3",
device="cuda",
device="mps",
torch_dtype="float16",
compile_mode=None,
language=None,
gen_kwargs={},
):
if len(model_name.split("/")) > 1:
model_name = model_name.split("/")[-1]
self.device = device
self.model = LightningWhisperMLX(model=model_name, batch_size=6, quant=None)
self.start_language = language
self.last_language = language

self.warmup()

def warmup(self):
@@ -46,10 +60,26 @@ def process(self, spoken_prompt):
global pipeline_start
pipeline_start = perf_counter()

pred_text = self.model.transcribe(spoken_prompt)["text"].strip()
if self.start_language != 'auto':
transcription_dict = self.model.transcribe(spoken_prompt, language=self.start_language)
else:
transcription_dict = self.model.transcribe(spoken_prompt)
language_code = transcription_dict["language"]
if language_code not in SUPPORTED_LANGUAGES:
logger.warning(f"Whisper detected unsupported language: {language_code}")
if self.last_language in SUPPORTED_LANGUAGES: # reprocess with the last language
transcription_dict = self.model.transcribe(spoken_prompt, language=self.last_language)
else:
transcription_dict = {"text": "", "language": "en"}
else:
self.last_language = language_code

pred_text = transcription_dict["text"].strip()
language_code = transcription_dict["language"]
torch.mps.empty_cache()

logger.debug("finished whisper inference")
console.print(f"[yellow]USER: {pred_text}")
logger.debug(f"Language Code Whisper: {language_code}")

yield pred_text
yield (pred_text, language_code)
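The detection logic above keeps the pipeline inside its supported set: an unsupported detection is retried with the last good language, and dropped if there is none. A pure-logic sketch of that policy (no model calls):

```python
# Pure-logic sketch of the detection/fallback policy above (no model calls).
SUPPORTED_LANGUAGES = ["en", "fr", "es", "zh", "ja", "ko"]

def pick_language(detected, last_language):
    """Return (language_to_use, new_last_language), following the handler."""
    if detected in SUPPORTED_LANGUAGES:
        return detected, detected
    if last_language in SUPPORTED_LANGUAGES:
        return last_language, last_language  # re-transcribe with this language
    return None, last_language              # handler yields ("", "en") here

print(pick_language("de", "fr"))  # -> ('fr', 'fr'): retry with last language
print(pick_language("ja", "fr"))  # -> ('ja', 'ja'): supported detection wins
```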
4 changes: 2 additions & 2 deletions TTS/melo_handler.py
@@ -11,7 +11,7 @@
console = Console()

WHISPER_LANGUAGE_TO_MELO_LANGUAGE = {
"en": "EN_NEWEST",
"en": "EN",
"fr": "FR",
"es": "ES",
"zh": "ZH",
@@ -20,7 +20,7 @@
}

WHISPER_LANGUAGE_TO_MELO_SPEAKER = {
"en": "EN-Newest",
"en": "EN-BR",
"fr": "FR",
"es": "ES",
"zh": "ZH",
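The mapping change swaps the English voice from the `EN_NEWEST` model to `EN` with the `EN-BR` speaker. A hedged sketch of how such mappings typically drive MeloTTS, following the MeloTTS README API (the handler's actual wiring may differ):

```python
# Hedged sketch following the MeloTTS README API; melo_handler.py's actual
# wiring may differ.
from melo.api import TTS

WHISPER_LANGUAGE_TO_MELO_LANGUAGE = {"en": "EN", "fr": "FR"}
WHISPER_LANGUAGE_TO_MELO_SPEAKER = {"en": "EN-BR", "fr": "FR"}

language_code = "en"  # as yielded by the STT stage
model = TTS(language=WHISPER_LANGUAGE_TO_MELO_LANGUAGE[language_code], device="auto")
speaker_id = model.hps.data.spk2id[WHISPER_LANGUAGE_TO_MELO_SPEAKER[language_code]]
model.tts_to_file("Hello there!", speaker_id, "out.wav", speed=1.0)
```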
5 changes: 4 additions & 1 deletion TTS/parler_handler.py
@@ -68,7 +68,7 @@ def setup(

if self.compile_mode not in (None, "default"):
logger.warning(
"Torch compilation modes that captures CUDA graphs are not yet compatible with the STT part. Reverting to 'default'"
"Torch compilation modes that captures CUDA graphs are not yet compatible with the TTS part. Reverting to 'default'"
)
self.compile_mode = "default"

@@ -147,6 +147,9 @@ def warmup(self):
)

def process(self, llm_sentence):
if isinstance(llm_sentence, tuple):
llm_sentence, _ = llm_sentence

console.print(f"[green]ASSISTANT: {llm_sentence}")
nb_tokens = len(self.prompt_tokenizer(llm_sentence).input_ids)

4 changes: 4 additions & 0 deletions VAD/vad_handler.py
@@ -86,3 +86,7 @@ def process(self, audio_chunk):
)
array = enhanced.numpy().squeeze()
yield array

@property
def min_time_to_debug(self):
return 0.00001
2 changes: 1 addition & 1 deletion arguments_classes/vad_arguments.py
@@ -42,6 +42,6 @@ class VADHandlerArguments:
audio_enhancement: bool = field(
default=False,
metadata={
"help": "improves sound quality by applying techniques like noise reduction, equalization, and echo cancellation. Default is True."
"help": "improves sound quality by applying techniques like noise reduction, equalization, and echo cancellation. Default is False."
},
)
7 changes: 6 additions & 1 deletion baseHandler.py
@@ -36,7 +36,8 @@ def run(self):
start_time = perf_counter()
for output in self.process(input):
self._times.append(perf_counter() - start_time)
logger.debug(f"{self.__class__.__name__}: {self.last_time: .3f} s")
if self.last_time > self.min_time_to_debug:
logger.debug(f"{self.__class__.__name__}: {self.last_time: .3f} s")
self.queue_out.put(output)
start_time = perf_counter()

Expand All @@ -46,6 +47,10 @@ def run(self):
@property
def last_time(self):
return self._times[-1]

@property
def min_time_to_debug(self):
return 0.001

def cleanup(self):
pass
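The `min_time_to_debug` property added here gates the per-item latency log, and subclasses such as `VADHandler` (above) override it. A minimal sketch of the pattern:

```python
# Minimal sketch of the min_time_to_debug pattern added in this commit.
class BaseHandler:
    @property
    def min_time_to_debug(self):
        return 0.001  # seconds: skip debug logs for faster passes

class VADHandler(BaseHandler):
    @property
    def min_time_to_debug(self):
        return 0.00001  # log timings even for very fast VAD passes

handler = VADHandler()
last_time = 0.0005
if last_time > handler.min_time_to_debug:  # same gate as BaseHandler.run
    print(f"{handler.__class__.__name__}: {last_time:.4f} s")
```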