Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat/dialog_transformers #27

Merged
merged 8 commits into from
Oct 7, 2023
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,4 @@ dist

# Created by unit tests
.pytest_cache/
/.gtm/
272 changes: 272 additions & 0 deletions ovos_audio/playback.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,272 @@
import random
import threading
from ovos_audio.transformers import TTSTransformersService
from ovos_bus_client.message import Message
from ovos_plugin_manager.templates.tts import TTS
from ovos_utils.log import LOG
from ovos_utils.sound import play_audio
from queue import Empty
from threading import Thread
from time import time, sleep


class PlaybackThread(Thread):
"""Thread class for playing back tts audio and sending
viseme data to enclosure.
"""

def __init__(self, queue=TTS.queue, bus=None):
super(PlaybackThread, self).__init__()
self.queue = queue or TTS.queue
self._terminated = False
self._processing_queue = False
self._paused = False
self.enclosure = None
self.p = None
self._tts = []
self.bus = bus or None
self._now_playing = None
self.active_tts = None
self._started = threading.Event()
self.tts_transform = TTSTransformersService(self.bus)

@property
def is_running(self):
return self._started.is_set() and not self._terminated

def activate_tts(self, tts_id):
self.active_tts = tts_id
tts = self.get_attached_tts()
if tts:
tts.begin_audio()

def deactivate_tts(self):
if self.active_tts:
tts = self.get_attached_tts()
if tts:
tts.end_audio()
self.active_tts = None

def init(self, tts):
"""DEPRECATED! Init the TTS Playback thread."""
JarbasAl marked this conversation as resolved.
Show resolved Hide resolved
self.attach_tts(tts)
self.set_bus(tts.bus)

def set_bus(self, bus):
"""Provide bus instance to the TTS Playback thread.
Args:
bus (MycroftBusClient): bus client
"""
self.bus = bus
self.tts_transform.set_bus(bus)

@property
def tts(self):
tts = self.get_attached_tts()
if not tts and self._tts:
return self._tts[0]
return tts

@tts.setter
def tts(self, val):
self.attach_tts(val)

@property
def attached_tts(self):
return self._tts

def attach_tts(self, tts):
"""Add TTS to be cache checked."""
if tts not in self.attached_tts:
self.attached_tts.append(tts)

def detach_tts(self, tts):
"""Remove TTS from cache check."""
if tts in self.attached_tts:
self.attached_tts.remove(tts)

def get_attached_tts(self, tts_id=None):
tts_id = tts_id or self.active_tts
if not tts_id:
return
for tts in self.attached_tts:
if hasattr(tts, "tts_id"):
# opm plugin
if tts.tts_id == tts_id:
return tts

for tts in self.attached_tts:
if not hasattr(tts, "tts_id"):
# non-opm plugin
if tts.tts_name == tts_id:
return tts

def clear_queue(self):
"""Remove all pending playbacks."""
while not self.queue.empty():
self.queue.get()
try:
self.p.terminate()
except Exception:
pass

def begin_audio(self, message=None):
"""Perform beginning of speech actions."""
if self.bus:
message = message or Message("speak")
self.bus.emit(message.forward("recognizer_loop:audio_output_start"))
else:
LOG.warning("Speech started before bus was attached.")

def end_audio(self, listen, message=None):
"""Perform end of speech output actions.
Will inform the system that speech has ended and trigger the TTS's
cache checks. Listening will be triggered if requested.
Args:
listen (bool): True if listening event should be emitted
"""
if self.bus:
# Send end of speech signals to the system
message = message or Message("speak")
self.bus.emit(message.forward("recognizer_loop:audio_output_end"))
if listen:
self.bus.emit(message.forward('mycroft.mic.listen'))
else:
LOG.warning("Speech started before bus was attached.")

def on_start(self, message=None):
self.blink(0.5)
if not self._processing_queue:
self._processing_queue = True
self.begin_audio(message)

def on_end(self, listen=False, message=None):
if self._processing_queue:
self.end_audio(listen, message)
self._processing_queue = False
# Clear cache for all attached tts objects
# This is basically the only safe time
for tts in self.attached_tts:
tts.cache.curate()
self.blink(0.2)

def _play(self):
try:
data, visemes, listen, tts_id, message = self._now_playing
self.activate_tts(tts_id)
self.on_start(message)

data = self.tts_transform.transform(data, message.context)

self.p = play_audio(data)
if visemes:
self.show_visemes(visemes)
if self.p:
self.p.communicate()
self.p.wait()
self.deactivate_tts()
if self.queue.empty():
self.on_end(listen, message)
except Empty:
pass
except Exception as e:
LOG.exception(e)
if self._processing_queue:
self.on_end()
self._now_playing = None

def run(self, cb=None):
"""Thread main loop. Get audio and extra data from queue and play.

The queue messages is a tuple containing
snd_type: 'mp3' or 'wav' telling the loop what format the data is in
data: path to temporary audio data
videmes: list of visemes to display while playing
JarbasAl marked this conversation as resolved.
Show resolved Hide resolved
listen: if listening should be triggered at the end of the sentence.

Playback of audio is started and the visemes are sent over the bus
the loop then wait for the playback process to finish before starting
checking the next position in queue.

If the queue is empty the tts.end_audio() is called possibly triggering
listening.
"""
self._paused = False
self._started.set()
while not self._terminated:
while self._paused:
sleep(0.2)
try:
# HACK: we do these check to account for direct usages of TTS.queue singletons
speech_data = self.queue.get(timeout=2)
if len(speech_data) == 5 and isinstance(speech_data[-1], Message):
data, visemes, listen, tts_id, message = speech_data
else:
LOG.warning("it seems you interfacing with TTS.queue directly, this is not recommended!\n"
JarbasAl marked this conversation as resolved.
Show resolved Hide resolved
"new expected TTS.queue contents -> data, visemes, listen, tts_id, message")
if len(speech_data) == 6:
# old ovos backwards compat
_, data, visemes, ident, listen, tts_id = speech_data
elif len(speech_data) == 5:
# mycroft style
tts_id = None
_, data, visemes, ident, listen = speech_data
else:
# old mycroft style TODO can this be deprecated? its very very old
JarbasAl marked this conversation as resolved.
Show resolved Hide resolved
listen = False
tts_id = None
_, data, visemes, ident = speech_data

message = Message("speak", context={"session": {"session_id": ident}})

self._now_playing = (data, visemes, listen, tts_id, message)
self._play()
except Exception as e:
pass

def show_visemes(self, pairs):
"""Send viseme data to enclosure

Args:
pairs (list): Visime and timing pair

Returns:
bool: True if button has been pressed.
"""
if self.enclosure:
self.enclosure.mouth_viseme(time(), pairs)

def pause(self):
"""pause thread"""
self._paused = True
if self.p:
self.p.terminate()

def resume(self):
"""resume thread"""
if self._now_playing:
self._play()
self._paused = False

def clear(self):
"""Clear all pending actions for the TTS playback thread."""
self.clear_queue()

def blink(self, rate=1.0):
"""Blink mycroft's eyes"""
if self.enclosure and random.random() < rate:
self.enclosure.eyes_blink("b")

def stop(self):
"""Stop thread"""
self._now_playing = None
self._terminated = True
self.clear_queue()

def shutdown(self):
self.stop()
for tts in self.attached_tts:
self.detach_tts(tts)

def __del__(self):
self.shutdown()
28 changes: 20 additions & 8 deletions ovos_audio/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,11 @@
import os.path
import time
from os.path import exists
from threading import Thread, Lock

from ovos_audio.audio import AudioService
from ovos_audio.playback import PlaybackThread
from ovos_audio.transformers import DialogTransformersService
from ovos_audio.tts import TTSFactory
from ovos_audio.utils import report_timing, validate_message_context
from ovos_bus_client import Message, MessageBusClient
from ovos_bus_client.session import SessionManager
from ovos_config.config import Configuration
Expand All @@ -16,10 +19,7 @@
from ovos_utils.metrics import Stopwatch
from ovos_utils.process_utils import ProcessStatus, StatusCallbackMap
from ovos_utils.sound import play_audio

from ovos_audio.audio import AudioService
from ovos_audio.tts import TTSFactory
from ovos_audio.utils import report_timing, validate_message_context
from threading import Thread, Lock


def on_ready():
Expand Down Expand Up @@ -74,6 +74,8 @@ def __init__(self, ready_hook=on_ready, error_hook=on_error,
self.bus = bus
self.status.bind(self.bus)
self.init_messagebus()
self.dialog_transform = DialogTransformersService(self.bus)
self.playback_thread = PlaybackThread(TTS.queue, self.bus)

try:
self._maybe_reload_tts()
Expand Down Expand Up @@ -270,6 +272,16 @@ def handle_speak(self, message):
stopwatch.start()

utterance = message.data['utterance']

# allow dialog transformers to rewrite speech
utt2, message.context = self.dialog_transform.transform(dialog=utterance,
context=message.context,
sess=sess)
if utterance != utt22:
LOG.debug(f"original dialog: {utterance}")
LOG.info(f"dialog transformed to: {utt2}")
utterance = utt2

listen = message.data.get('expect_response', False)
self.execute_tts(utterance, sess.session_id, listen, message)

Expand All @@ -293,7 +305,7 @@ def _maybe_reload_tts(self):
# Create new tts instance
LOG.info("(re)loading TTS engine")
self.tts = TTSFactory.create(config)
self.tts.init(self.bus)
self.tts.init(self.bus, self.playback_thread)
self._tts_hash = config.get("module", "")

# if fallback TTS is the same as main TTS dont load it
Expand Down Expand Up @@ -341,7 +353,7 @@ def _get_tts_fallback(self):
engine: config.get('tts', {}).get(engine, {})}}
self.fallback_tts = TTSFactory.create(cfg)
self.fallback_tts.validator.validate()
self.fallback_tts.init(self.bus)
self.fallback_tts.init(self.bus, self.playback_thread)

return self.fallback_tts

Expand Down
Loading
Loading