Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

silero: adjust vad activation threshold #639

Merged
merged 15 commits on Aug 17, 2024
6 changes: 6 additions & 0 deletions .changeset/kind-cougars-live.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
---
"livekit-agents": patch
"livekit-plugins-silero": patch
---

silero: adjust vad activation threshold
1 change: 1 addition & 0 deletions examples/voice-assistant/minimal_assistant.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ async def entrypoint(ctx: JobContext):
llm=openai.LLM(),
tts=openai.TTS(),
chat_ctx=initial_ctx,
plotting=True,
)
assistant.start(ctx.room)

Expand Down
51 changes: 38 additions & 13 deletions livekit-agents/livekit/agents/voice_assistant/plotter.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
import asyncio
import contextlib
import io
import multiprocessing as mp
import select
import socket
import time
from dataclasses import dataclass
from typing import ClassVar, Literal, Tuple

from .. import utils
from ..ipc import channel

PlotType = Literal["vad_probability", "raw_vol", "smoothed_vol"]
Expand Down Expand Up @@ -57,7 +61,7 @@ def read(self, b: io.BytesIO) -> None:
}


def _draw_plot(reader):
def _draw_plot(mp_cch):
try:
import matplotlib as mpl # type: ignore
import matplotlib.pyplot as plt # type: ignore
Expand All @@ -77,11 +81,17 @@ def _draw_plot(reader):

max_points = 250

plot_rx = channel.ProcChannel(conn=reader, messages=PLT_MESSAGES)
duplex = utils.aio.duplex_unix._Duplex.open(mp_cch)

poller = select.poll()
poller.register(mp_cch, select.POLLIN)

def _draw_cb(sp, pv):
while reader.poll():
msg = plot_rx.recv()
while True:
if not poller.poll(20):
break

msg = channel.recv_message(duplex, PLT_MESSAGES)
if isinstance(msg, PlotMessage):
data = plot_data.setdefault(msg.which, ([], []))
data[0].append(msg.x)
Expand Down Expand Up @@ -129,7 +139,7 @@ def _draw_cb(sp, pv):

fig.canvas.draw()

timer = fig.canvas.new_timer(interval=150)
timer = fig.canvas.new_timer(interval=33)
timer.add_callback(_draw_cb, sp, pv)
timer.start()
plt.show()
Expand All @@ -140,36 +150,51 @@ def __init__(self, loop: asyncio.AbstractEventLoop) -> None:
self._loop = loop
self._started = False

def start(self):
async def start(self):
if self._started:
return

mp_pch, mp_cch = mp.Pipe(duplex=True)
self._plot_tx = channel.AsyncProcChannel(
conn=mp_pch, loop=self._loop, messages=PLT_MESSAGES
)
mp_pch, mp_cch = socket.socketpair()
self._duplex = await utils.aio.duplex_unix._AsyncDuplex.open(mp_pch)
self._plot_proc = mp.Process(target=_draw_plot, args=(mp_cch,), daemon=True)
self._plot_proc.start()
mp_cch.close()

self._started = True
self._closed = False
self._start_time = time.time()

def plot_value(self, which: PlotType, y: float):
if not self._started:
return

ts = time.time() - self._start_time
asyncio.ensure_future(self._plot_tx.asend(PlotMessage(which=which, x=ts, y=y)))
self._send_message(PlotMessage(which=which, x=ts, y=y))

def plot_event(self, which: EventType):
if not self._started:
return

ts = time.time() - self._start_time
asyncio.ensure_future(self._plot_tx.asend(PlotEventMessage(which=which, x=ts)))
self._send_message(PlotEventMessage(which=which, x=ts))

def _send_message(self, msg: channel.Message) -> None:
if self._closed:
return

async def _asend_message():
try:
await channel.asend_message(self._duplex, msg)
except Exception:
self._closed = True

def terminate(self):
asyncio.ensure_future(_asend_message())

async def terminate(self):
if not self._started:
return

self._plot_proc.terminate()

with contextlib.suppress(utils.aio.duplex_unix.DuplexClosed):
await self._duplex.aclose()
Original file line number Diff line number Diff line change
Expand Up @@ -391,7 +391,7 @@ def _on_final_transcript(ev: stt.SpeechEvent) -> None:
@utils.log_exceptions(logger=logger)
async def _main_task(self) -> None:
if self._opts.plotting:
self._plotter.start()
await self._plotter.start()

audio_source = rtc.AudioSource(self._tts.sample_rate, self._tts.num_channels)
track = rtc.LocalAudioTrack.create_audio_track("assistant_voice", audio_source)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def load(
min_silence_duration: float = 0.25,
padding_duration: float = 0.1,
max_buffered_speech: float = 60.0,
activation_threshold: float = 0.25,
activation_threshold: float = 0.5,
sample_rate: int = 16000,
force_cpu: bool = True,
) -> "VAD":
Expand Down
Loading