Updating the viseme branch #2

Merged 2 commits on Sep 23, 2024
341 changes: 341 additions & 0 deletions TTS/STV/speech_to_visemes.py

Large diffs are not rendered by default.
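The new speech_to_visemes.py module itself is not rendered in the diff. Its public surface can be inferred from the handler call sites below; here is a minimal sketch, assuming the internals (only the `process()` method and the `viseme`/`timestamp` keys are confirmed by the calling code):

```python
# Hypothetical sketch of the SpeechToVisemes surface used by the handlers.
# The class name and process() call match the call sites; the return shape
# is inferred from how the handlers read it, everything else is assumed.
from typing import List, TypedDict

import numpy as np


class VisemeEvent(TypedDict):
    viseme: str       # mouth-shape label, printed as ASSISTANT_MOUTH_SHAPE
    timestamp: float  # time offset of the viseme within the audio chunk


class SpeechToVisemes:
    def process(self, audio_chunk: np.ndarray) -> List[VisemeEvent]:
        """Map 16 kHz mono int16 PCM to a list of timestamped visemes."""
        ...
```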

66 changes: 57 additions & 9 deletions TTS/chatTTS_handler.py
@@ -5,6 +5,7 @@
import numpy as np
from rich.console import Console
import torch
from .STV.speech_to_visemes import SpeechToVisemes

logging.basicConfig(
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
@@ -22,6 +23,7 @@ def setup(
gen_kwargs={}, # Unused
stream=True,
chunk_size=512,
viseme_flag=True,
):
self.should_listen = should_listen
self.device = device
@@ -33,6 +35,9 @@
self.params_infer_code = ChatTTS.Chat.InferCodeParams(
spk_emb=rnd_spk_emb,
)
self.viseme_flag = viseme_flag
if self.viseme_flag:
self.speech_to_visemes = SpeechToVisemes()
self.warmup()

def warmup(self):
@@ -61,22 +66,65 @@ def process(self, llm_sentence):
if gen[0] is None or len(gen[0]) == 0:
self.should_listen.set()
return

# Resample the audio to 16000 Hz
audio_chunk = librosa.resample(gen[0], orig_sr=24000, target_sr=16000)
audio_chunk = (audio_chunk * 32768).astype(np.int16)[0]
while len(audio_chunk) > self.chunk_size:
yield audio_chunk[: self.chunk_size]  # return the first chunk_size bytes of data
audio_chunk = audio_chunk[self.chunk_size :]  # remove the data already returned
yield np.pad(audio_chunk, (0, self.chunk_size - len(audio_chunk)))
# Ensure the audio is converted to mono (single channel)
if len(audio_chunk.shape) > 1:
audio_chunk = librosa.to_mono(audio_chunk)
audio_chunk = (audio_chunk * 32768).astype(np.int16)

# Process visemes if viseme_flag is set
if self.viseme_flag:
visemes = self.speech_to_visemes.process(audio_chunk)
for viseme in visemes:
console.print(f"[blue]ASSISTANT_MOUTH_SHAPE: {viseme['viseme']} -- {viseme['timestamp']}")
else:
visemes = None

# Loop through audio chunks, yielding dict for each chunk
for i in range(0, len(audio_chunk), self.chunk_size):
chunk_data = {
"audio": np.pad(
audio_chunk[i : i + self.chunk_size],
(0, self.chunk_size - len(audio_chunk[i : i + self.chunk_size])),
)
}
# Include text and visemes for the first chunk
if i == 0:
chunk_data["text"] = llm_sentence # Assuming llm_sentence is defined elsewhere
chunk_data["visemes"] = visemes

yield chunk_data
else:
wavs = wavs_gen
if len(wavs[0]) == 0:
self.should_listen.set()
return
audio_chunk = librosa.resample(wavs[0], orig_sr=24000, target_sr=16000)
# Ensure the audio is converted to mono (single channel)
if len(audio_chunk.shape) > 1:
audio_chunk = librosa.to_mono(audio_chunk)
audio_chunk = (audio_chunk * 32768).astype(np.int16)

if self.viseme_flag:
visemes = self.speech_to_visemes.process(audio_chunk)
for viseme in visemes:
console.print(f"[blue]ASSISTANT_MOUTH_SHAPE: {viseme['viseme']} -- {viseme['timestamp']}")
else:
visemes = None

for i in range(0, len(audio_chunk), self.chunk_size):
yield np.pad(
audio_chunk[i : i + self.chunk_size],
(0, self.chunk_size - len(audio_chunk[i : i + self.chunk_size])),
)
chunk_data = {
"audio": np.pad(
audio_chunk[i : i + self.chunk_size],
(0, self.chunk_size - len(audio_chunk[i : i + self.chunk_size])),
)
}
# For the first chunk, include text and visemes
if i == 0:
chunk_data["text"] = llm_sentence
chunk_data["visemes"] = visemes
yield chunk_data

self.should_listen.set()
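With this change the handler yields a dict per chunk instead of a bare int16 array, and only the first chunk of a sentence carries the text and viseme list. A minimal consumer sketch, where `handler`, `sentence`, and `play()` are hypothetical stand-ins rather than pipeline names:

```python
# Sketch of consuming the new per-chunk dict protocol.
for chunk_data in handler.process(sentence):
    pcm = chunk_data["audio"]  # int16 array, padded to chunk_size samples
    if "text" in chunk_data:   # present only on the first chunk
        print("TEXT:", chunk_data["text"])
        print("VISEMES:", chunk_data["visemes"])
    play(pcm)                  # hypothetical audio sink
```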
32 changes: 27 additions & 5 deletions TTS/melo_handler.py
@@ -6,6 +6,8 @@
from rich.console import Console
import torch

from .STV.speech_to_visemes import SpeechToVisemes

logger = logging.getLogger(__name__)

console = Console()
@@ -28,7 +30,6 @@
"ko": "KR",
}


class MeloTTSHandler(BaseHandler):
def setup(
self,
@@ -38,6 +39,7 @@ def setup(
speaker_to_id="en",
gen_kwargs={}, # Unused
blocksize=512,
viseme_flag=True,  # To obtain timestamped visemes
):
self.should_listen = should_listen
self.device = device
@@ -49,6 +51,11 @@
WHISPER_LANGUAGE_TO_MELO_SPEAKER[speaker_to_id]
]
self.blocksize = blocksize

self.viseme_flag = viseme_flag
if self.viseme_flag:
self.speech_to_visemes = SpeechToVisemes()

self.warmup()

def warmup(self):
@@ -100,10 +107,25 @@ def process(self, llm_sentence):
return
audio_chunk = librosa.resample(audio_chunk, orig_sr=44100, target_sr=16000)
audio_chunk = (audio_chunk * 32768).astype(np.int16)

if self.viseme_flag:
visemes = self.speech_to_visemes.process(audio_chunk)
for viseme in visemes:
console.print(f"[blue]ASSISTANT_MOUTH_SHAPE: {viseme['viseme']} -- {viseme['timestamp']}")
else:
visemes = None

for i in range(0, len(audio_chunk), self.blocksize):
yield np.pad(
audio_chunk[i : i + self.blocksize],
(0, self.blocksize - len(audio_chunk[i : i + self.blocksize])),
)
chunk_data = {
"audio": np.pad(
audio_chunk[i : i + self.blocksize],
(0, self.blocksize - len(audio_chunk[i : i + self.blocksize]))
)
}
# For the first chunk, include text and visemes
if i == 0:
chunk_data["text"] = llm_sentence
chunk_data["visemes"] = visemes
yield chunk_data

self.should_listen.set()
29 changes: 25 additions & 4 deletions TTS/parler_handler.py
@@ -14,6 +14,7 @@
from transformers.utils.import_utils import (
is_flash_attn_2_available,
)
from .STV.speech_to_visemes import SpeechToVisemes

torch._inductor.config.fx_graph_cache = True
# Mind this parameter! It should be >= 2 * the number of padded prompt sizes for TTS
@@ -47,6 +48,7 @@ def setup(
),
play_steps_s=1,
blocksize=512,
viseme_flag=True,
):
self.should_listen = should_listen
self.device = device
@@ -78,6 +80,10 @@
self.model.forward, mode=self.compile_mode, fullgraph=True
)

self.viseme_flag = viseme_flag
if self.viseme_flag:
self.speech_to_visemes = SpeechToVisemes()

self.warmup()

def prepare_model_inputs(
@@ -182,10 +188,25 @@ def process(self, llm_sentence):
)
audio_chunk = librosa.resample(audio_chunk, orig_sr=44100, target_sr=16000)
audio_chunk = (audio_chunk * 32768).astype(np.int16)

if self.viseme_flag:
visemes = self.speech_to_visemes.process(audio_chunk)
for viseme in visemes:
console.print(f"[blue]ASSISTANT_MOUTH_SHAPE: {viseme['viseme']} -- {viseme['timestamp']}")
else:
visemes = None

for i in range(0, len(audio_chunk), self.blocksize):
yield np.pad(
audio_chunk[i : i + self.blocksize],
(0, self.blocksize - len(audio_chunk[i : i + self.blocksize])),
)
chunk_data = {
"audio": np.pad(
audio_chunk[i : i + self.blocksize],
(0, self.blocksize - len(audio_chunk[i : i + self.blocksize]))
)
}
# For the first chunk, include text and visemes
if i == 0:
chunk_data["text"] = llm_sentence
chunk_data["visemes"] = visemes
yield chunk_data

self.should_listen.set()
2 changes: 1 addition & 1 deletion arguments_classes/parler_tts_arguments.py
@@ -36,7 +36,7 @@ class ParlerTTSHandlerArguments:
tts_gen_max_new_tokens: int = field(
default=512,
metadata={
"help": "Maximum number of new tokens to generate in a single completion. Default is 256, which corresponds to ~6 secs"
"help": "Maximum number of new tokens to generate in a single completion. Default is 512, which corresponds to ~6 secs"
},
)
description: str = field(
13 changes: 12 additions & 1 deletion connections/local_audio_streamer.py
@@ -27,7 +27,18 @@ def callback(indata, outdata, frames, time, status):
self.input_queue.put(indata.copy())
outdata[:] = 0 * outdata
else:
outdata[:] = self.output_queue.get()[:, np.newaxis]
data = self.output_queue.get()
"""
# Check if text data is present and log it
if data.get('text') is not None:
text = data['text']
logger.info(f"Text: {text}")
# Check if viseme data is present and log it
if data.get('visemes') is not None:
visemes = data['visemes']
logger.info(f"Visemes: {visemes}")
"""
outdata[:] = data['audio'][:, np.newaxis]

logger.debug("Available devices:")
logger.debug(sd.query_devices())
33 changes: 28 additions & 5 deletions connections/socket_sender.py
@@ -1,6 +1,8 @@
import socket
from rich.console import Console
import logging
import pickle
import struct

logger = logging.getLogger(__name__)

@@ -11,7 +13,6 @@ class SocketSender:
"""
Handles sending generated audio packets to the clients.
"""

def __init__(self, stop_event, queue_in, host="0.0.0.0", port=12346):
self.stop_event = stop_event
self.queue_in = queue_in
@@ -28,9 +29,31 @@ def run(self):
logger.info("sender connected")

while not self.stop_event.is_set():
audio_chunk = self.queue_in.get()
self.conn.sendall(audio_chunk)
if isinstance(audio_chunk, bytes) and audio_chunk == b"END":
break
data = self.queue_in.get()
packet = {}
if 'audio' in data and data['audio'] is not None:
audio_chunk = data['audio']
packet['audio'] = data['audio']
if 'text' in data and data['text'] is not None:
packet['text'] = data['text']
if 'visemes' in data and data['visemes'] is not None:
packet['visemes'] = data['visemes']

# Serialize the packet using pickle
serialized_packet = pickle.dumps(packet)

# Compute the length of the serialized packet
packet_length = len(serialized_packet)

# Send the packet length as a 4-byte integer using struct
self.conn.sendall(struct.pack('!I', packet_length))

# Send the serialized packet
self.conn.sendall(serialized_packet)

if 'audio' in data and data['audio'] is not None:
if isinstance(audio_chunk, bytes) and audio_chunk == b"END":
break

self.conn.close()
logger.info("Sender closed")
37 changes: 29 additions & 8 deletions listen_and_play.py
@@ -4,15 +4,16 @@
from dataclasses import dataclass, field
import sounddevice as sd
from transformers import HfArgumentParser

import struct
import pickle

@dataclass
class ListenAndPlayArguments:
send_rate: int = field(default=16000, metadata={"help": "In Hz. Default is 16000."})
recv_rate: int = field(default=16000, metadata={"help": "In Hz. Default is 16000."})
list_play_chunk_size: int = field(
default=1024,
metadata={"help": "The size of data chunks (in bytes). Default is 1024."},
default=512,
metadata={"help": "The size of data chunks (in bytes). Default is 512."},
)
host: str = field(
default="localhost",
@@ -33,7 +34,7 @@ class ListenAndPlayArguments:
def listen_and_play(
send_rate=16000,
recv_rate=44100,
list_play_chunk_size=1024,
list_play_chunk_size=512,
host="localhost",
send_port=12345,
recv_port=12346,
@@ -79,9 +80,29 @@ def receive_full_chunk(conn, chunk_size):
return data

while not stop_event.is_set():
data = receive_full_chunk(recv_socket, list_play_chunk_size * 2)
if data:
recv_queue.put(data)
# Step 1: Receive the first 4 bytes to get the packet length
length_data = receive_full_chunk(recv_socket, 4)
if not length_data:
continue # Handle disconnection or data not available

# Step 2: Unpack the length (4 bytes)
packet_length = struct.unpack('!I', length_data)[0]

# Step 3: Receive the full packet based on the length
serialized_packet = receive_full_chunk(recv_socket, packet_length)
if serialized_packet:
# Step 4: Deserialize the packet using pickle
packet = pickle.loads(serialized_packet)
# Step 5: Extract the packet contents
if 'text' in packet:
pass
# print(packet['text'])
if 'visemes' in packet:
pass
# print(packet['visemes'])

# Step 6: Put the packet audio data into the queue for sending
recv_queue.put(packet['audio'].tobytes())

try:
send_stream = sd.RawInputStream(
@@ -123,4 +144,4 @@ def receive_full_chunk(conn, chunk_size):
if __name__ == "__main__":
parser = HfArgumentParser((ListenAndPlayArguments,))
(listen_and_play_kwargs,) = parser.parse_args_into_dataclasses()
listen_and_play(**vars(listen_and_play_kwargs))
listen_and_play(**vars(listen_and_play_kwargs))