Skip to content

Commit

Permalink
Python API for speaker diarization.
Browse files Browse the repository at this point in the history
  • Loading branch information
csukuangfj committed Oct 9, 2024
1 parent 59407ed commit d839378
Show file tree
Hide file tree
Showing 14 changed files with 315 additions and 9 deletions.
15 changes: 15 additions & 0 deletions .github/scripts/test-python.sh
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,21 @@ log() {
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}

log "test offline speaker diarization"

curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-segmentation-models/sherpa-onnx-pyannote-segmentation-3-0.tar.bz2
tar xvf sherpa-onnx-pyannote-segmentation-3-0.tar.bz2
rm sherpa-onnx-pyannote-segmentation-3-0.tar.bz2

curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-recongition-models/3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx

curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-segmentation-models/0-four-speakers-zh.wav

python3 ./python-api-examples/offline-speaker-diarization.py

rm -rf *.wav *.onnx ./sherpa-onnx-pyannote-segmentation-3-0


log "test_clustering"
pushd /tmp/
mkdir test-cluster
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/windows-x64.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ jobs:
shell: bash
run: |
du -h -d1 .
export PATH=$PWD/build/bin:$PATH
export PATH=$PWD/build/bin/Release:$PATH
export EXE=sherpa-onnx-offline-speaker-diarization.exe
.github/scripts/test-speaker-diarization.sh
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/windows-x86.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ jobs:
shell: bash
run: |
du -h -d1 .
export PATH=$PWD/build/bin:$PATH
export PATH=$PWD/build/bin/Release:$PATH
export EXE=sherpa-onnx-offline-speaker-diarization.exe
.github/scripts/test-speaker-diarization.sh
Expand Down
118 changes: 118 additions & 0 deletions python-api-examples/offline-speaker-diarization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
#!/usr/bin/env python3
# Copyright (c) 2024 Xiaomi Corporation

"""
This file shows how to use sherpa-onnx Python API for
offline/non-streaming speaker diarization.
Usage:
Step 1: Download a speaker segmentation model
Please visit https://github.com/k2-fsa/sherpa-onnx/releases/tag/speaker-segmentation-models
for a list of available models. The following is an example
wget https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-segmentation-models/sherpa-onnx-pyannote-segmentation-3-0.tar.bz2
tar xvf sherpa-onnx-pyannote-segmentation-3-0.tar.bz2
rm sherpa-onnx-pyannote-segmentation-3-0.tar.bz2
Step 2: Download a speaker embedding extractor model
Please visit https://github.com/k2-fsa/sherpa-onnx/releases/tag/speaker-recongition-models
for a list of available models. The following is an example
wget https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-recongition-models/3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx
Step 3. Download test wave files
Please visit https://github.com/k2-fsa/sherpa-onnx/releases/tag/speaker-segmentation-models
for a list of available test wave files. The following is an example
wget https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-segmentation-models/0-four-speakers-zh.wav
Step 4. Run it
python3 ./python-api-examples/offline-speaker-diarization.py
"""
from pathlib import Path

import sherpa_onnx
import soundfile as sf


def init_speaker_diarization(num_speakers: int = -1, cluster_threshold: float = 0.5):
"""
Args:
num_speakers:
If you know the actual number of speakers in the wave file, then please
specify it. Otherwise, leave it to -1
cluster_threshold:
If num_speakers is -1, then this threshold is used for clustering.
A smaller cluster_threshold leads to more clusters, i.e., more speakers.
A larger cluster_threshold leads to fewer clusters, i.e., fewer speakers.
"""
segmentation_model = "./sherpa-onnx-pyannote-segmentation-3-0/model.onnx"
embedding_extractor_model = (
"./3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx"
)

config = sherpa_onnx.OfflineSpeakerDiarizationConfig(
segmentation=sherpa_onnx.OfflineSpeakerSegmentationModelConfig(
pyannote=sherpa_onnx.OfflineSpeakerSegmentationPyannoteModelConfig(
model=segmentation_model
),
),
embedding=sherpa_onnx.SpeakerEmbeddingExtractorConfig(
model=embedding_extractor_model
),
clustering=sherpa_onnx.FastClusteringConfig(
num_clusters=num_speakers, threshold=cluster_threshold
),
min_duration_on=0.3,
min_duration_off=0.5,
)
if not config.validate():
raise RuntimeError(
"Please check your config and make sure all required files exist"
)

return sherpa_onnx.OfflineSpeakerDiarization(config)


def progress_callback(num_processed_chunk: int, num_total_chunks: int) -> int:
progress = num_processed_chunk / num_total_chunks * 100
print(f"Progress: {progress:.3f}%")
return 0


def main():
wave_filename = "./0-four-speakers-zh.wav"
if not Path(wave_filename).is_file():
raise RuntimeError(f"{wave_filename} does not exist")

audio, sample_rate = sf.read(wave_filename, dtype="float32", always_2d=True)
audio = audio[:, 0] # only use the first channel

# Since we know there are 4 speakers in the above test wave file, we use
# num_speakers 4 here
sd = init_speaker_diarization(num_speakers=4)
if sample_rate != sd.sample_rate:
raise RuntimeError(
f"Expected samples rate: {sd.sample_rate}, given: {sample_rate}"
)

show_porgress = True

if show_porgress:
result = sd.process(audio, callback=progress_callback).sort_by_start_time()
else:
result = sd.process(audio).sort_by_start_time()

for r in result:
print(f"{r.start:.3f} -- {r.end:.3f} speaker_{r.speaker:02}")
# print(r) # this one is simpler


if __name__ == "__main__":
main()
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ class OfflineSpeakerDiarizationPyannoteImpl
auto chunk_speaker_samples_list_pair = GetChunkSpeakerSampleIndexes(labels);
Matrix2D embeddings =
ComputeEmbeddings(audio, n, chunk_speaker_samples_list_pair.second,
callback, callback_arg);
std::move(callback), callback_arg);

std::vector<int32_t> cluster_labels = clustering_.Cluster(
&embeddings(0, 0), embeddings.rows(), embeddings.cols());
Expand Down
2 changes: 2 additions & 0 deletions sherpa-onnx/csrc/offline-speaker-diarization-result.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ class OfflineSpeakerDiarizationSegment {
const std::string &Text() const { return text_; }
float Duration() const { return end_ - start_; }

void SetText(const std::string &text) { text_ = text; }

std::string ToString() const;

private:
Expand Down
7 changes: 5 additions & 2 deletions sherpa-onnx/csrc/offline-speaker-diarization.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,13 @@ struct OfflineSpeakerDiarizationConfig {
OfflineSpeakerDiarizationConfig(
const OfflineSpeakerSegmentationModelConfig &segmentation,
const SpeakerEmbeddingExtractorConfig &embedding,
const FastClusteringConfig &clustering)
const FastClusteringConfig &clustering, float min_duration_on,
float min_duration_off)
: segmentation(segmentation),
embedding(embedding),
clustering(clustering) {}
clustering(clustering),
min_duration_on(min_duration_on),
min_duration_off(min_duration_off) {}

void Register(ParseOptions *po);
bool Validate() const;
Expand Down
2 changes: 2 additions & 0 deletions sherpa-onnx/python/csrc/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@ endif()
if(SHERPA_ONNX_ENABLE_SPEAKER_DIARIZATION)
list(APPEND srcs
fast-clustering.cc
offline-speaker-diarization-result.cc
offline-speaker-diarization.cc
)
endif()

Expand Down
32 changes: 32 additions & 0 deletions sherpa-onnx/python/csrc/offline-speaker-diarization-result.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
// sherpa-onnx/python/csrc/offline-speaker-diarization-result.cc
//
// Copyright (c) 2024 Xiaomi Corporation

#include "sherpa-onnx/python/csrc/offline-speaker-diarization-result.h"

#include "sherpa-onnx/csrc/offline-speaker-diarization-result.h"

namespace sherpa_onnx {

static void PybindOfflineSpeakerDiarizationSegment(py::module *m) {
using PyClass = OfflineSpeakerDiarizationSegment;
py::class_<PyClass>(*m, "OfflineSpeakerDiarizationSegment")
.def_property_readonly("start", &PyClass::Start)
.def_property_readonly("end", &PyClass::End)
.def_property_readonly("duration", &PyClass::Duration)
.def_property_readonly("speaker", &PyClass::Speaker)
.def_property("text", &PyClass::Text, &PyClass::SetText)
.def("__str__", &PyClass::ToString);
}

void PybindOfflineSpeakerDiarizationResult(py::module *m) {
PybindOfflineSpeakerDiarizationSegment(m);
using PyClass = OfflineSpeakerDiarizationResult;
py::class_<PyClass>(*m, "OfflineSpeakerDiarizationResult")
.def_property_readonly("num_speakers", &PyClass::NumSpeakers)
.def_property_readonly("num_segments", &PyClass::NumSegments)
.def("sort_by_start_time", &PyClass::SortByStartTime)
.def("sort_by_speaker", &PyClass::SortBySpeaker);
}

} // namespace sherpa_onnx
16 changes: 16 additions & 0 deletions sherpa-onnx/python/csrc/offline-speaker-diarization-result.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
// sherpa-onnx/python/csrc/offline-speaker-diarization-result.h
//
// Copyright (c) 2024 Xiaomi Corporation

#ifndef SHERPA_ONNX_PYTHON_CSRC_OFFLINE_SPEAKER_DIARIZATION_RESULT_H_
#define SHERPA_ONNX_PYTHON_CSRC_OFFLINE_SPEAKER_DIARIZATION_RESULT_H_

#include "sherpa-onnx/python/csrc/sherpa-onnx.h"

namespace sherpa_onnx {

void PybindOfflineSpeakerDiarizationResult(py::module *m);

}

#endif // SHERPA_ONNX_PYTHON_CSRC_OFFLINE_SPEAKER_DIARIZATION_RESULT_H_
92 changes: 92 additions & 0 deletions sherpa-onnx/python/csrc/offline-speaker-diarization.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
// sherpa-onnx/python/csrc/offline-speaker-diarization.cc
//
// Copyright (c) 2024 Xiaomi Corporation

#include "sherpa-onnx/python/csrc/offline-speaker-diarization.h"

#include <string>
#include <vector>

#include "sherpa-onnx/csrc/offline-speaker-diarization.h"
#include "sherpa-onnx/csrc/offline-speaker-segmentation-model-config.h"
#include "sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model-config.h"

namespace sherpa_onnx {

static void PybindOfflineSpeakerSegmentationPyannoteModelConfig(py::module *m) {
using PyClass = OfflineSpeakerSegmentationPyannoteModelConfig;
py::class_<PyClass>(*m, "OfflineSpeakerSegmentationPyannoteModelConfig")
.def(py::init<>())
.def(py::init<const std::string &>(), py::arg("model"))
.def_readwrite("model", &PyClass::model)
.def("__str__", &PyClass::ToString)
.def("validate", &PyClass::Validate);
}

static void PybindOfflineSpeakerSegmentationModelConfig(py::module *m) {
PybindOfflineSpeakerSegmentationPyannoteModelConfig(m);

using PyClass = OfflineSpeakerSegmentationModelConfig;
py::class_<PyClass>(*m, "OfflineSpeakerSegmentationModelConfig")
.def(py::init<>())
.def(py::init<const OfflineSpeakerSegmentationPyannoteModelConfig &,
int32_t, bool, const std::string &>(),
py::arg("pyannote"), py::arg("num_threads") = 1,
py::arg("debug") = false, py::arg("provider") = "cpu")
.def_readwrite("pyannote", &PyClass::pyannote)
.def_readwrite("num_threads", &PyClass::num_threads)
.def_readwrite("debug", &PyClass::debug)
.def_readwrite("provider", &PyClass::provider)
.def("__str__", &PyClass::ToString)
.def("validate", &PyClass::Validate);
}

static void PybindOfflineSpeakerDiarizationConfig(py::module *m) {
PybindOfflineSpeakerSegmentationModelConfig(m);

using PyClass = OfflineSpeakerDiarizationConfig;
py::class_<PyClass>(*m, "OfflineSpeakerDiarizationConfig")
.def(py::init<const OfflineSpeakerSegmentationModelConfig &,
const SpeakerEmbeddingExtractorConfig &,
const FastClusteringConfig &, float, float>(),
py::arg("segmentation"), py::arg("embedding"), py::arg("clustering"),
py::arg("min_duration_on") = 0.3, py::arg("min_duration_off") = 0.5)
.def_readwrite("segmentation", &PyClass::segmentation)
.def_readwrite("embedding", &PyClass::embedding)
.def_readwrite("clustering", &PyClass::clustering)
.def_readwrite("min_duration_on", &PyClass::min_duration_on)
.def_readwrite("min_duration_off", &PyClass::min_duration_off)
.def("__str__", &PyClass::ToString)
.def("validate", &PyClass::Validate);
}

void PybindOfflineSpeakerDiarization(py::module *m) {
PybindOfflineSpeakerDiarizationConfig(m);

using PyClass = OfflineSpeakerDiarization;
py::class_<PyClass>(*m, "OfflineSpeakerDiarization")
.def(py::init<const OfflineSpeakerDiarizationConfig &>(),
py::arg("config"))
.def_property_readonly("sample_rate", &PyClass::SampleRate)
.def(
"process",
[](const PyClass &self, const std::vector<float> samples,
std::function<int32_t(int32_t, int32_t)> callback) {
if (!callback) {
return self.Process(samples.data(), samples.size());
}

std::function<int32_t(int32_t, int32_t, void *)> callback_wrapper =
[callback](int32_t processed_chunks, int32_t num_chunks,
void *) -> int32_t {
callback(processed_chunks, num_chunks);
return 0;
};

return self.Process(samples.data(), samples.size(),
callback_wrapper);
},
py::arg("samples"), py::arg("callback") = py::none());
}

} // namespace sherpa_onnx
16 changes: 16 additions & 0 deletions sherpa-onnx/python/csrc/offline-speaker-diarization.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
// sherpa-onnx/python/csrc/offline-speaker-diarization.h
//
// Copyright (c) 2024 Xiaomi Corporation

#ifndef SHERPA_ONNX_PYTHON_CSRC_OFFLINE_SPEAKER_DIARIZATION_H_
#define SHERPA_ONNX_PYTHON_CSRC_OFFLINE_SPEAKER_DIARIZATION_H_

#include "sherpa-onnx/python/csrc/sherpa-onnx.h"

namespace sherpa_onnx {

void PybindOfflineSpeakerDiarization(py::module *m);

}

#endif // SHERPA_ONNX_PYTHON_CSRC_OFFLINE_SPEAKER_DIARIZATION_H_
12 changes: 8 additions & 4 deletions sherpa-onnx/python/csrc/sherpa-onnx.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@

#if SHERPA_ONNX_ENABLE_SPEAKER_DIARIZATION == 1
#include "sherpa-onnx/python/csrc/fast-clustering.h"
#include "sherpa-onnx/python/csrc/offline-speaker-diarization-result.h"
#include "sherpa-onnx/python/csrc/offline-speaker-diarization.h"
#endif

namespace sherpa_onnx {
Expand Down Expand Up @@ -74,14 +76,16 @@ PYBIND11_MODULE(_sherpa_onnx, m) {
PybindOfflineTts(&m);
#endif

#if SHERPA_ONNX_ENABLE_SPEAKER_DIARIZATION == 1
PybindFastClustering(&m);
#endif

PybindSpeakerEmbeddingExtractor(&m);
PybindSpeakerEmbeddingManager(&m);
PybindSpokenLanguageIdentification(&m);

#if SHERPA_ONNX_ENABLE_SPEAKER_DIARIZATION == 1
PybindFastClustering(&m);
PybindOfflineSpeakerDiarizationResult(&m);
PybindOfflineSpeakerDiarization(&m);
#endif

PybindAlsa(&m);
}

Expand Down
Loading

0 comments on commit d839378

Please sign in to comment.