Bug fixes for generation
- Fixed TTS request form on the client
- Added full support for native and fast audio modes
- Improved documentation with new data and backstory
- Normalized model generation config params
- Added if/else statements for stability, so non-mandatory dependencies load only when the configuration requires them
yukiarimo committed May 14, 2024
1 parent 6a619b8 commit ff4d846
Showing 6 changed files with 281 additions and 145 deletions.
158 changes: 93 additions & 65 deletions README.md

Large diffs are not rendered by default.

76 changes: 70 additions & 6 deletions lib/audio.py
@@ -1,12 +1,20 @@
import json
import os
import whisper
import torch
import torchaudio
from pydub import AudioSegment

model = whisper.load_model(name="tiny.en", device="cpu")
XTTS_MODEL = None

with open('static/config.json', 'r') as config_file:
    config = json.load(config_file)

# XTTS is only needed in native audio mode, so import it conditionally
if config['server']['yuna_audio_mode'] == "native":
    from TTS.tts.configs.xtts_config import XttsConfig
    from TTS.tts.models.xtts import Xtts

def transcribe_audio(audio_file):
    result = model.transcribe(audio_file)
    return result['text']
@@ -36,10 +44,66 @@ def run_tts(lang, tts_text, speaker_audio_file, output_audio):
    )

    out_path = f"/Users/yuki/Documents/Github/yuna-ai/static/audio/{output_audio}"
    # XTTS inference returns the waveform under the "wav" key
    torchaudio.save(out_path, torch.tensor(out["wav"]).unsqueeze(0), 22000)

    return out_path, speaker_audio_file

def speak_text(text, reference_audio, output_audio, mode, language="en"):
    if mode == "native":
        # Split the text into sentences; protect "..." with a placeholder first,
        # otherwise the "." replacement would break ellipses apart
        sentences = text.replace("\n", " ").replace("...", "<ell>").replace("?", "?|").replace(".", ".|").replace("<ell>", "...|").split("|")

        # Initialize variables
        chunks = []
        current_chunk = ""

        # Iterate over the sentences
        for sentence in sentences:
            # Check if adding the sentence to the current chunk exceeds the character limit
            if len(current_chunk) + len(sentence) <= 200:
                current_chunk += sentence.strip() + " "
            else:
                # If the current chunk is not empty, add it to the chunks list
                if current_chunk.strip():
                    chunks.append(current_chunk.strip())
                current_chunk = sentence.strip() + " "

        # Add the last chunk if it's not empty
        if current_chunk.strip():
            chunks.append(current_chunk.strip())

        # Join small chunks together if possible
        i = 0
        while i < len(chunks) - 1:
            if len(chunks[i]) + len(chunks[i + 1]) <= 200:
                chunks[i] += " " + chunks[i + 1]
                chunks.pop(i + 1)
            else:
                i += 1

        # List to store the names of the generated audio files
        audio_files = []

        for i, chunk in enumerate(chunks):
            audio_file = f"response_{i+1}.wav"
            # Synthesize each chunk with run_tts directly; calling speak_text here
            # would recurse forever (the original hardcodes this reference clip
            # rather than using the reference_audio parameter)
            run_tts(language, chunk, "/Users/yuki/Downloads/chapter2.wav", audio_file)
            audio_files.append("/Users/yuki/Documents/Github/yuna-ai/static/audio/" + audio_file)

        # Concatenate the audio files with a 1-second pause in between
        combined = AudioSegment.empty()
        for audio_file in audio_files:
            combined += AudioSegment.from_wav(audio_file) + AudioSegment.silent(duration=1000)

        # Export the combined audio (extension now matches the export format,
        # and fast mode writes the same audio.aiff path)
        combined.export("/Users/yuki/Documents/Github/yuna-ai/static/audio/audio.aiff", format='aiff')

    elif mode == "fast":
        # macOS-only: shell out to the built-in `say` command
        os.system(f'say -o static/audio/audio.aiff "{text}"')

    print(f"Generated audio saved at: {output_audio}")

if config['server']['yuna_audio_mode'] == "native":
    xtts_checkpoint = "/Users/yuki/Documents/Github/yuna-ai/lib/models/agi/yuna-talk/yuna-talk.pth"
    xtts_config = "/Users/yuki/Documents/Github/yuna-ai/lib/models/agi/yuna-talk/config.json"
    xtts_vocab = "/Users/yuki/Documents/Github/yuna-ai/lib/models/agi/yuna-talk/vocab.json"
    load_model(xtts_checkpoint, xtts_config, xtts_vocab)
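
For orientation, here is a minimal sketch of the static/config.json fields this commit reads in lib/audio.py and lib/generate.py. The key names are taken from the code above; every value is illustrative, not the project's actual configuration:

{
    "server": {
        "yuna_audio_mode": "native",
        "yuna_text_mode": "native",
        "agi_model_dir": "lib/models/agi/"
    },
    "ai": {
        "names": ["Yuki", "Yuna"],
        "emotions": false,
        "context_length": 2048,
        "max_new_tokens": 512,
        "batch_size": 512,
        "gpu_layers": -1,
        "temperature": 0.7,
        "top_k": 40,
        "top_p": 0.9,
        "repetition_penalty": 1.1,
        "seed": -1,
        "stop": ["Yuki:"]
    }
}

And a short usage sketch of the new speak_text signature, assuming it is imported from lib/audio; the reference path is the one hardcoded in the module, and the mode argument selects the XTTS or macOS `say` path:

from lib.audio import speak_text

# native mode: chunked XTTS synthesis, stitched together with pydub
speak_text("Hello! How are you today?", "/Users/yuki/Downloads/orig.wav", "audio.aiff", "native")

# fast mode: macOS-only, shells out to the built-in `say` command
speak_text("Hello! How are you today?", "/Users/yuki/Downloads/orig.wav", "audio.aiff", "fast")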
171 changes: 108 additions & 63 deletions lib/generate.py
@@ -1,9 +1,10 @@
import os
import json
import re
from flask_login import current_user
from transformers import pipeline
from llama_cpp import Llama
from lib.history import ChatHistoryManager
import requests

class ChatGenerator:
    def __init__(self, config):
@@ -15,82 +16,126 @@ def __init__(self, config):
            n_batch=config["ai"]["batch_size"],
            n_gpu_layers=config["ai"]["gpu_layers"],
            verbose=False
        ) if config["server"]["yuna_text_mode"] == "native" else ""
        self.classifier = pipeline("text-classification", model=f"{config['server']['agi_model_dir']}yuna-emotion") if config["ai"]["emotions"] else ""

    def generate(self, chat_id, speech=False, text="", template=None, chat_history_manager=None, useHistory=True):
        chat_history = chat_history_manager.load_chat_history(current_user.get_id(), chat_id)
        response = ''

        if self.config["server"]["yuna_text_mode"] == "native":
            max_length_all_input_and_output = self.config["ai"]["context_length"]
            max_length_of_generated_tokens = self.config["ai"]["max_new_tokens"]
            max_length_of_input_tokens = max_length_all_input_and_output - max_length_of_generated_tokens

            # Tokenize the prompt template
            tokenized_prompt = self.model.tokenize(template.encode('utf-8'))

            # Load the chat history
            text_of_history = ''
            history = ''

            if useHistory:
                for item in chat_history:
                    name = item.get('name', '')
                    message = item.get('message', '')
                    if name and message:
                        history += f'{name}: {message}\n'

            # Build "{history}Yuki: {text}\nYuna:" using the names from the config
            text_of_history = f"{history}{self.config['ai']['names'][0]}: {text}\n{self.config['ai']['names'][1]}:"

            tokenized_history = self.model.tokenize(text_of_history.encode('utf-8'))

            # Calculate the maximum length for the history
            max_length_of_history_tokens = max_length_of_input_tokens - len(tokenized_prompt)

            # Crop the history to fit into max_length_of_history_tokens, counting from the end of the text
            cropped_history = tokenized_history[-max_length_of_history_tokens:]

            # Replace the placeholder in the prompt with the cropped history
            response = template.replace('{user_msg}', self.model.detokenize(cropped_history).decode('utf-8'))

            if template == None:
                print('template is none')

            print('00--------------------00\n', response, '\n00--------------------00')
            response = self.model(
                response,
                stream=False,
                top_k=self.config["ai"]["top_k"],
                top_p=self.config["ai"]["top_p"],
                temperature=self.config["ai"]["temperature"],
                repeat_penalty=self.config["ai"]["repetition_penalty"],
                max_tokens=self.config["ai"]["max_new_tokens"],
                stop=self.config["ai"]["stop"],
            )

            # The model returns a completion dictionary; keep only the generated text
            response = response['choices'][0]['text']
            response = self.clearText(str(response))

            if self.config["ai"]["emotions"]:
                response_add = self.classifier(response)[0]['label']

                # Replace emotion labels with action words
                replacement_dict = {
                    "anger": "*angry*",
                    "disgust": "*disgusted*",
                    "fear": "*scared*",
                    "joy": "*smiling*",
                    "neutral": "",
                    "sadness": "*sad*",
                    "surprise": "*surprised*"
                }

                for word, replacement in replacement_dict.items():
                    response_add = response_add.replace(word, replacement)

                response = response + f" {response_add}"
        else:
            messages = []

            if useHistory:
                for item in chat_history:
                    name = item.get('name', '')
                    message = item.get('message', '')
                    if name and message:
                        role = "user" if name == self.config['ai']['names'][0] else "assistant"
                        messages.append({
                            "role": role,
                            "content": message
                        })

            messages.append({
                "role": "user",
                "content": text
            })

            dataSendAPI = {
                "model": "/Users/yuki/Documents/Github/yuna-ai/lib/models/yuna/yukiarimo/yuna-ai/yuna-ai-v3-q6_k.gguf",
                "messages": messages,
                "temperature": self.config["ai"]["temperature"],
                "max_tokens": -1,  # -1 for unlimited
                "stop": self.config["ai"]["stop"],
                "top_p": self.config["ai"]["top_p"],
                "top_k": self.config["ai"]["top_k"],
                "min_p": 0,
                "presence_penalty": 0,
                "frequency_penalty": 0,
                "logit_bias": {},
                "repeat_penalty": self.config["ai"]["repetition_penalty"],
                "seed": self.config["ai"]["seed"]
            }

            url = "http://localhost:1234/v1/chat/completions"
            headers = {"Content-Type": "application/json"}

            response = requests.post(url, headers=headers, json=dataSendAPI, stream=False)

            if response.status_code == 200:
                response_json = json.loads(response.text)
                response = response_json.get('choices', [{}])[0].get('message', {}).get('content', '')
            else:
                print(f"Request failed with status code: {response.status_code}")

if template != None:
chat_history.append({"name": "Yuki", "message": text})
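Two notes on the generation flow above. In native mode the input budget is context_length minus max_new_tokens; with the illustrative values from the config sketch earlier (2048 and 512), at most 1536 tokens of template plus history survive, and the history is cropped from its end to fit whatever the template leaves over. In API mode the request goes to an OpenAI-compatible chat completions server on localhost:1234 (the port LM Studio uses by default); the parsing code expects a response shaped roughly like this hand-written illustration, not captured server output:

{
    "choices": [
        {
            "index": 0,
            "message": {"role": "assistant", "content": "Hi! How was your day?"},
            "finish_reason": "stop"
        }
    ]
}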
2 changes: 1 addition & 1 deletion lib/history.py
@@ -76,7 +76,7 @@ def list_history_files(self, username):
return history_files

    def generate_speech(self, response):
        speak_text(response, "/Users/yuki/Downloads/orig.wav", "audio.aiff", self.config['server']['yuna_audio_mode'])

    def delete_message(self, username, chat_id, target_message):
        chat_history = self.load_chat_history(username, chat_id)
10 changes: 4 additions & 6 deletions lib/router.py
@@ -52,10 +52,13 @@ def handle_message_request(chat_generator, chat_history_manager, chat_id=None, s
response = chat_generator.generate(chat_id, speech, text, template, chat_history_manager)
return jsonify({'response': response})

@login_required
def handle_audio_request(self):
    task = request.form['task']

    # debug the request
    print(request.form)
    print(request.files)

    if task == 'transcribe':
        if 'audio' not in request.files:
            return jsonify({'error': 'No audio file'}), 400
@@ -68,11 +71,6 @@ def handle_audio_request(self):
return jsonify({'text': result})

    elif task == 'tts':
        print("Running TTS...")
        text = """Huh? Is this a mistake? I looked over at Mom and Dad. They looked…amazed. Was this for real? In the world of Oudegeuz, we have magic. I was surprised when I first awakened to it—there wasn’t any in my last world, after all."""
        # speak_text now takes the audio mode; assuming the handler can reach the loaded config here
        result = speak_text(text, "/Users/yuki/Downloads/orig.wav", "response.wav", self.config['server']['yuna_audio_mode'])
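For completeness, a sketch of how a client could exercise the /audio endpoint after this change. The host and port are illustrative (whatever the Flask app is served on), and the handler expects multipart form data with a task field, as the JavaScript below also shows:

import requests

# transcription: multipart upload with task=transcribe
with open("sample.wav", "rb") as f:
    response = requests.post(
        "http://localhost:4848/audio",  # illustrative host/port
        files={"audio": ("sample.wav", f, "audio/wav")},
        data={"task": "transcribe"},
    )
print(response.json().get("text"))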
9 changes: 5 additions & 4 deletions static/js/index.js
@@ -62,16 +62,14 @@ function stopRecording() {
function sendAudioToServer(audioBlob) {
    const formData = new FormData();
    formData.append('audio', audioBlob);
    formData.append('task', 'transcribe');

    fetch('/audio', {
        method: 'POST',
        body: formData
    })
    .then(response => response.json())
    .then(data => {
        messageManager.sendMessage(data.text, imageData = '', url = '/message')
    })
    .catch(error => {
@@ -401,7 +399,8 @@ class HistoryManager {
        return response.json();
    })
    .then(responseData => {
        alert("New history file created successfully.");
        location.reload();
    })
    .catch(error => {
        console.error('An error occurred:', error);
@@ -766,6 +765,7 @@ function captureAudioViaFile() {
const formData = new FormData();

formData.append('audio', file);
formData.append('task', 'transcribe');

fetch('/audio', {
method: 'POST',
@@ -828,6 +828,7 @@

const formData = new FormData();
formData.append('audio', file);
formData.append('task', 'transcribe');

fetch('/audio', {
method: 'POST',
