Merge pull request #1 from sensein/speech_to_visemes
adding speech_to_visemes
fabiocat93 authored Sep 9, 2024
2 parents 86c0808 + ba4368b commit 90c38d4
Showing 8 changed files with 520 additions and 33 deletions.
341 changes: 341 additions & 0 deletions TTS/STV/speech_to_visemes.py

Large diffs are not rendered by default.
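The new module itself is not rendered above, but the handler diffs below pin down the surface it exposes. Here is a minimal sketch of that interface as inferred from the call sites; the field types, especially the timestamp format, are assumptions rather than facts from the 341-line file:

from typing import TypedDict

# Sketch of the interface TTS/STV/speech_to_visemes.py appears to expose,
# inferred from the handler call sites below -- not the actual file contents.
class VisemeEvent(TypedDict):
    viseme: str        # mouth-shape label, printed as ASSISTANT_MOUTH_SHAPE
    timestamp: float   # position of the viseme in the audio (assumed seconds)

class SpeechToVisemes:
    def process(self, audio_chunk) -> list[VisemeEvent]:
        """Map 16 kHz int16 mono PCM to a list of timestamped visemes."""
        raise NotImplementedError  # the real implementation is in the module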

66 changes: 57 additions & 9 deletions TTS/chatTTS_handler.py
@@ -5,6 +5,7 @@
import numpy as np
from rich.console import Console
import torch
from .STV.speech_to_visemes import SpeechToVisemes

logging.basicConfig(
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
@@ -22,6 +23,7 @@ def setup(
gen_kwargs={}, # Unused
stream=True,
chunk_size=512,
viseme_flag=True
):
self.should_listen = should_listen
self.device = device
@@ -33,6 +35,9 @@
self.params_infer_code = ChatTTS.Chat.InferCodeParams(
spk_emb=rnd_spk_emb,
)
self.viseme_flag = viseme_flag
if self.viseme_flag:
self.speech_to_visemes = SpeechToVisemes()
self.warmup()

def warmup(self):
@@ -61,22 +66,65 @@ def process(self, llm_sentence):
if gen[0] is None or len(gen[0]) == 0:
self.should_listen.set()
return

# Resample the audio to 16000 Hz
audio_chunk = librosa.resample(gen[0], orig_sr=24000, target_sr=16000)
audio_chunk = (audio_chunk * 32768).astype(np.int16)[0]
while len(audio_chunk) > self.chunk_size:
yield audio_chunk[: self.chunk_size]  # yield the first chunk_size samples
audio_chunk = audio_chunk[self.chunk_size :]  # drop the samples already yielded
yield np.pad(audio_chunk, (0, self.chunk_size - len(audio_chunk)))
# Ensure the audio is converted to mono (single channel)
if len(audio_chunk.shape) > 1:
audio_chunk = librosa.to_mono(audio_chunk)
audio_chunk = (audio_chunk * 32768).astype(np.int16)

# Process visemes if viseme_flag is set
if self.viseme_flag:
visemes = self.speech_to_visemes.process(audio_chunk)
for viseme in visemes:
console.print(f"[blue]ASSISTANT_MOUTH_SHAPE: {viseme['viseme']} -- {viseme['timestamp']}")
else:
visemes = None

# Loop through audio chunks, yielding dict for each chunk
for i in range(0, len(audio_chunk), self.chunk_size):
chunk_data = {
"audio": np.pad(
audio_chunk[i : i + self.chunk_size],
(0, self.chunk_size - len(audio_chunk[i : i + self.chunk_size])),
)
}
# Include text and visemes for the first chunk
if i == 0:
chunk_data["text"] = llm_sentence
chunk_data["visemes"] = visemes

yield chunk_data
else:
wavs = wavs_gen
if len(wavs[0]) == 0:
self.should_listen.set()
return
audio_chunk = librosa.resample(wavs[0], orig_sr=24000, target_sr=16000)
# Ensure the audio is converted to mono (single channel)
if len(audio_chunk.shape) > 1:
audio_chunk = librosa.to_mono(audio_chunk)
audio_chunk = (audio_chunk * 32768).astype(np.int16)

if self.viseme_flag:
visemes = self.speech_to_visemes.process(audio_chunk)
for viseme in visemes:
console.print(f"[blue]ASSISTANT_MOUTH_SHAPE: {viseme['viseme']} -- {viseme['timestamp']}")
else:
visemes = None

for i in range(0, len(audio_chunk), self.chunk_size):
yield np.pad(
audio_chunk[i : i + self.chunk_size],
(0, self.chunk_size - len(audio_chunk[i : i + self.chunk_size])),
)
chunk_data = {
"audio": np.pad(
audio_chunk[i : i + self.chunk_size],
(0, self.chunk_size - len(audio_chunk[i : i + self.chunk_size])),
)
}
# For the first chunk, include text and visemes
if i == 0:
chunk_data["text"] = llm_sentence
chunk_data["visemes"] = visemes
yield chunk_data

self.should_listen.set()
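Note the shape change: the handler now yields a dict per chunk instead of a bare array, with "text" and "visemes" present only on a sentence's first chunk. A minimal consumer sketch under that assumption (handle_chunk is illustrative, not part of this commit):

import numpy as np

def handle_chunk(chunk: dict) -> np.ndarray:
    # Unpack one yielded chunk; "text" and "visemes" only accompany
    # the first chunk of each sentence.
    if chunk.get("text") is not None:
        print(f"speaking: {chunk['text']}")
    for v in chunk.get("visemes") or []:
        print(f"mouth shape {v['viseme']} at {v['timestamp']}")
    return chunk["audio"]  # int16 PCM at 16 kHz, padded to chunk_size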
32 changes: 27 additions & 5 deletions TTS/melo_handler.py
@@ -6,6 +6,8 @@
from rich.console import Console
import torch

from .STV.speech_to_visemes import SpeechToVisemes

logger = logging.getLogger(__name__)

console = Console()
@@ -28,7 +30,6 @@
"ko": "KR",
}


class MeloTTSHandler(BaseHandler):
def setup(
self,
@@ -38,6 +39,7 @@ def setup(
speaker_to_id="en",
gen_kwargs={}, # Unused
blocksize=512,
viseme_flag=True  # To obtain timestamped visemes
):
self.should_listen = should_listen
self.device = device
@@ -49,6 +51,11 @@ def setup(
WHISPER_LANGUAGE_TO_MELO_SPEAKER[speaker_to_id]
]
self.blocksize = blocksize

self.viseme_flag = viseme_flag
if self.viseme_flag:
self.speech_to_visemes = SpeechToVisemes()

self.warmup()

def warmup(self):
@@ -100,10 +107,25 @@ def process(self, llm_sentence):
return
audio_chunk = librosa.resample(audio_chunk, orig_sr=44100, target_sr=16000)
audio_chunk = (audio_chunk * 32768).astype(np.int16)

if self.viseme_flag:
visemes = self.speech_to_visemes.process(audio_chunk)
for viseme in visemes:
console.print(f"[blue]ASSISTANT_MOUTH_SHAPE: {viseme['viseme']} -- {viseme['timestamp']}")
else:
visemes = None

for i in range(0, len(audio_chunk), self.blocksize):
yield np.pad(
audio_chunk[i : i + self.blocksize],
(0, self.blocksize - len(audio_chunk[i : i + self.blocksize])),
)
chunk_data = {
"audio": np.pad(
audio_chunk[i : i + self.blocksize],
(0, self.blocksize - len(audio_chunk[i : i + self.blocksize]))
)
}
# For the first chunk, include text and visemes
if i == 0:
chunk_data["text"] = llm_sentence
chunk_data["visemes"] = visemes
yield chunk_data

self.should_listen.set()
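The handlers share the float-to-int16 conversion shown above. A quick standalone check of what it does, including the edge case at exactly 1.0 (where 32767 would be the overflow-safe multiplier):

import numpy as np

x = np.array([-1.0, 0.0, 0.5, 0.999], dtype=np.float32)
pcm = (x * 32768).astype(np.int16)
print(pcm)  # -> [-32768, 0, 16384, 32735]
# An input of exactly 1.0 would give 32768.0, which falls outside the
# int16 range and typically wraps to -32768 on the cast; multiplying
# by 32767 instead avoids that overflow.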
29 changes: 25 additions & 4 deletions TTS/parler_handler.py
@@ -14,6 +14,7 @@
from transformers.utils.import_utils import (
is_flash_attn_2_available,
)
from .STV.speech_to_visemes import SpeechToVisemes

torch._inductor.config.fx_graph_cache = True
# mind about this parameter ! should be >= 2 * number of padded prompt sizes for TTS
Expand Down Expand Up @@ -47,6 +48,7 @@ def setup(
),
play_steps_s=1,
blocksize=512,
viseme_flag=True
):
self.should_listen = should_listen
self.device = device
@@ -78,6 +80,10 @@ def setup(
self.model.forward, mode=self.compile_mode, fullgraph=True
)

self.viseme_flag = viseme_flag
if self.viseme_flag:
self.speech_to_visemes = SpeechToVisemes()

self.warmup()

def prepare_model_inputs(
Expand Down Expand Up @@ -182,10 +188,25 @@ def process(self, llm_sentence):
)
audio_chunk = librosa.resample(audio_chunk, orig_sr=44100, target_sr=16000)
audio_chunk = (audio_chunk * 32768).astype(np.int16)

if self.viseme_flag:
visemes = self.speech_to_visemes.process(audio_chunk)
for viseme in visemes:
console.print(f"[blue]ASSISTANT_MOUTH_SHAPE: {viseme['viseme']} -- {viseme['timestamp']}")
else:
visemes = None

for i in range(0, len(audio_chunk), self.blocksize):
yield np.pad(
audio_chunk[i : i + self.blocksize],
(0, self.blocksize - len(audio_chunk[i : i + self.blocksize])),
)
chunk_data = {
"audio": np.pad(
audio_chunk[i : i + self.blocksize],
(0, self.blocksize - len(audio_chunk[i : i + self.blocksize]))
)
}
# For the first chunk, include text and visemes
if i == 0:
chunk_data["text"] = llm_sentence
chunk_data["visemes"] = visemes
yield chunk_data

self.should_listen.set()
2 changes: 1 addition & 1 deletion arguments_classes/parler_tts_arguments.py
@@ -36,7 +36,7 @@ class ParlerTTSHandlerArguments:
tts_gen_max_new_tokens: int = field(
default=512,
metadata={
"help": "Maximum number of new tokens to generate in a single completion. Default is 256, which corresponds to ~6 secs"
"help": "Maximum number of new tokens to generate in a single completion. Default is 512, which corresponds to ~6 secs"
},
)
description: str = field(
13 changes: 12 additions & 1 deletion connections/local_audio_streamer.py
@@ -27,7 +27,18 @@ def callback(indata, outdata, frames, time, status):
self.input_queue.put(indata.copy())
outdata[:] = 0 * outdata
else:
outdata[:] = self.output_queue.get()[:, np.newaxis]
data = self.output_queue.get()
"""
# Check if text data is present and log it
if data.get('text') is not None:
text = data['text']
logger.info(f"Text: {text}")
# Check if viseme data is present and log it
if data.get('visemes') is not None:
visemes = data['visemes']
logger.info(f"Visemes: {visemes}")
"""
outdata[:] = data['audio'][:, np.newaxis]

logger.debug("Available devices:")
logger.debug(sd.query_devices())
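The [:, np.newaxis] on the playback path matters because sounddevice's callback buffer is two-dimensional. A shape-only illustration, independent of this codebase:

import numpy as np

# sounddevice hands the callback an outdata buffer of shape (frames, channels);
# the queued audio is a flat (frames,) int16 array, so it needs a channel axis.
mono = np.zeros(512, dtype=np.int16)
assert mono.shape == (512,)
assert mono[:, np.newaxis].shape == (512, 1)  # ready for outdata[:] = ...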
33 changes: 28 additions & 5 deletions connections/socket_sender.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import socket
from rich.console import Console
import logging
import pickle
import struct

logger = logging.getLogger(__name__)

@@ -11,7 +13,6 @@ class SocketSender:
"""
Handles sending generated audio packets to the clients.
"""

def __init__(self, stop_event, queue_in, host="0.0.0.0", port=12346):
self.stop_event = stop_event
self.queue_in = queue_in
@@ -28,9 +29,31 @@ def run(self):
logger.info("sender connected")

while not self.stop_event.is_set():
audio_chunk = self.queue_in.get()
self.conn.sendall(audio_chunk)
if isinstance(audio_chunk, bytes) and audio_chunk == b"END":
break
data = self.queue_in.get()

# The pipeline pushes the END sentinel as raw bytes, not as a dict,
# so check for it before building a packet
if isinstance(data, bytes) and data == b"END":
    break

packet = {}
if 'audio' in data and data['audio'] is not None:
    packet['audio'] = data['audio']
if 'text' in data and data['text'] is not None:
    packet['text'] = data['text']
if 'visemes' in data and data['visemes'] is not None:
    packet['visemes'] = data['visemes']

# Serialize the packet using pickle
serialized_packet = pickle.dumps(packet)

# Send the packet length as a 4-byte unsigned int in network byte
# order, so the receiver knows exactly how many bytes follow
self.conn.sendall(struct.pack('!I', len(serialized_packet)))

# Send the serialized packet
self.conn.sendall(serialized_packet)

self.conn.close()
logger.info("Sender closed")
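The sender's wire format is plain length-prefixed pickle framing. A socket-free round trip of the same scheme, for illustration only:

import pickle
import struct

packet = {"text": "hi", "visemes": [{"viseme": "p", "timestamp": 0.1}]}
payload = pickle.dumps(packet)
frame = struct.pack('!I', len(payload)) + payload   # 4-byte big-endian length

(length,) = struct.unpack('!I', frame[:4])          # receiver reads the header
assert pickle.loads(frame[4:4 + length]) == packet  # then exactly that many bytes

Length prefixes give TCP, which is a boundary-less byte stream, explicit message framing; the usual caveat applies that pickle should only cross trusted connections, since unpickling untrusted bytes can execute arbitrary code.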
37 changes: 29 additions & 8 deletions listen_and_play.py
@@ -4,15 +4,16 @@
from dataclasses import dataclass, field
import sounddevice as sd
from transformers import HfArgumentParser

import struct
import pickle

@dataclass
class ListenAndPlayArguments:
send_rate: int = field(default=16000, metadata={"help": "In Hz. Default is 16000."})
recv_rate: int = field(default=16000, metadata={"help": "In Hz. Default is 16000."})
list_play_chunk_size: int = field(
default=1024,
metadata={"help": "The size of data chunks (in bytes). Default is 1024."},
default=512,
metadata={"help": "The size of data chunks (in bytes). Default is 512."},
)
host: str = field(
default="localhost",
@@ -33,7 +34,7 @@ class ListenAndPlayArguments:
def listen_and_play(
send_rate=16000,
recv_rate=44100,
list_play_chunk_size=1024,
list_play_chunk_size=512,
host="localhost",
send_port=12345,
recv_port=12346,
@@ -79,9 +80,29 @@ def receive_full_chunk(conn, chunk_size):
return data

while not stop_event.is_set():
data = receive_full_chunk(recv_socket, list_play_chunk_size * 2)
if data:
recv_queue.put(data)
# Step 1: Receive the first 4 bytes to get the packet length
length_data = receive_full_chunk(recv_socket, 4)
if not length_data:
continue # Handle disconnection or data not available

# Step 2: Unpack the length (4 bytes)
packet_length = struct.unpack('!I', length_data)[0]

# Step 3: Receive the full packet based on the length
serialized_packet = receive_full_chunk(recv_socket, packet_length)
if serialized_packet:
# Step 4: Deserialize the packet using pickle
packet = pickle.loads(serialized_packet)
# Step 5: Extract the packet contents; text and visemes arrive
# here but are not used by this client yet
# text = packet.get('text')
# visemes = packet.get('visemes')

# Step 6: Put the packet's audio samples into the playback queue
recv_queue.put(packet['audio'].tobytes())

try:
send_stream = sd.RawInputStream(
@@ -123,4 +144,4 @@ def receive_full_chunk(conn, chunk_size):
if __name__ == "__main__":
parser = HfArgumentParser((ListenAndPlayArguments,))
(listen_and_play_kwargs,) = parser.parse_args_into_dataclasses()
listen_and_play(**vars(listen_and_play_kwargs))
listen_and_play(**vars(listen_and_play_kwargs))
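The receive steps above lean on receive_full_chunk reading exactly N bytes, but its body sits outside this diff. A sketch of the exact-read loop such a helper typically implements (an assumption, not the file's actual code):

import socket

def receive_full_chunk(conn: socket.socket, chunk_size: int):
    # Loop on recv() until exactly chunk_size bytes arrive; a single
    # recv() on a TCP socket may return fewer bytes than requested.
    data = b""
    while len(data) < chunk_size:
        part = conn.recv(chunk_size - len(data))
        if not part:
            return None  # connection closed mid-frame
        data += part
    return data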
