Skip to content

Commit

Permalink
fix: compatible_engineにsafety requirementとアサートを入れる
Browse files Browse the repository at this point in the history
  • Loading branch information
qryxip committed Nov 7, 2024
1 parent 006b4f0 commit 601acd9
Showing 1 changed file with 56 additions and 9 deletions.
65 changes: 56 additions & 9 deletions crates/voicevox_core_c_api/src/compatible_engine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -221,22 +221,31 @@ pub extern "C" fn supported_devices() -> *const c_char {
});
}

// SAFETY: voicevox_core_c_apiを構成するライブラリの中に、これと同名のシンボルは存在しない
#[unsafe(no_mangle)]
pub extern "C" fn yukarin_s_forward(
/// # Safety
///
/// - `phoneme_list`はRustの`&[i64; length as usize]`として解釈できなければならない。
/// - `speaker_id`はRustの`&[i64; 1]`として解釈できなければならない。
/// - `output`はRustの`&mut [f32; length as usize]`として解釈できなければならない。
#[unsafe(no_mangle)] // SAFETY: voicevox_core_c_apiを構成するライブラリの中に、これと同名のシンボルは存在しない
pub unsafe extern "C" fn yukarin_s_forward(
length: i64,
phoneme_list: *mut i64,
speaker_id: *mut i64,
output: *mut f32,
) -> bool {
init_logger_once();
assert_aligned(phoneme_list);
assert_aligned(speaker_id);
assert_aligned(output);
let synthesizer = &*lock_synthesizer();
let result = ensure_initialized!(synthesizer).predict_duration(
// SAFETY: The safety contract must be upheld by the caller.
unsafe { std::slice::from_raw_parts_mut(phoneme_list, length as usize) },
StyleId::new(unsafe { *speaker_id as u32 }),
);
match result {
Ok(output_vec) => {
// SAFETY: The safety contract must be upheld by the caller.
let output_slice = unsafe { std::slice::from_raw_parts_mut(output, length as usize) };
output_slice.clone_from_slice(&output_vec);
true
Expand All @@ -248,9 +257,18 @@ pub extern "C" fn yukarin_s_forward(
}
}

// SAFETY: voicevox_core_c_apiを構成するライブラリの中に、これと同名のシンボルは存在しない
#[unsafe(no_mangle)]
pub extern "C" fn yukarin_sa_forward(
/// # Safety
///
/// - `vowel_phoneme_list`はRustの`&[i64; length as usize]`として解釈できなければならない。
/// - `consonant_phoneme_list`はRustの`&[i64; length as usize]`として解釈できなければならない。
/// - `start_accent_list`はRustの`&[i64; length as usize]`として解釈できなければならない。
/// - `end_accent_list`はRustの`&[i64; length as usize]`として解釈できなければならない。
/// - `start_accent_phrase_list`はRustの`&[i64; length as usize]`として解釈できなければならない。
/// - `end_accent_phrase_list`はRustの`&[i64; length as usize]`として解釈できなければならない。
/// - `speaker_id`はRustの`&[i64; 1]`として解釈できなければならない。
/// - `output`はRustの`&mut [f32; length as usize]`として解釈できなければならない。
#[unsafe(no_mangle)] // SAFETY: voicevox_core_c_apiを構成するライブラリの中に、これと同名のシンボルは存在しない
pub unsafe extern "C" fn yukarin_sa_forward(
length: i64,
vowel_phoneme_list: *mut i64,
consonant_phoneme_list: *mut i64,
Expand All @@ -262,9 +280,18 @@ pub extern "C" fn yukarin_sa_forward(
output: *mut f32,
) -> bool {
init_logger_once();
assert_aligned(vowel_phoneme_list);
assert_aligned(consonant_phoneme_list);
assert_aligned(start_accent_list);
assert_aligned(end_accent_list);
assert_aligned(start_accent_phrase_list);
assert_aligned(end_accent_phrase_list);
assert_aligned(speaker_id);
assert_aligned(output);
let synthesizer = &*lock_synthesizer();
let result = ensure_initialized!(synthesizer).predict_intonation(
length as usize,
// SAFETY: The safety contract must be upheld by the caller.
unsafe { std::slice::from_raw_parts(vowel_phoneme_list, length as usize) },
unsafe { std::slice::from_raw_parts(consonant_phoneme_list, length as usize) },
unsafe { std::slice::from_raw_parts(start_accent_list, length as usize) },
Expand All @@ -275,6 +302,7 @@ pub extern "C" fn yukarin_sa_forward(
);
match result {
Ok(output_vec) => {
// SAFETY: The safety contract must be upheld by the caller.
let output_slice = unsafe { std::slice::from_raw_parts_mut(output, length as usize) };
output_slice.clone_from_slice(&output_vec);
true
Expand All @@ -286,9 +314,14 @@ pub extern "C" fn yukarin_sa_forward(
}
}

// SAFETY: voicevox_core_c_apiを構成するライブラリの中に、これと同名のシンボルは存在しない
#[unsafe(no_mangle)]
pub extern "C" fn decode_forward(
/// # Safety
///
/// - `f0`はRustの`&[f32; length as usize]`として解釈できなければならない。
/// - `phoneme`はRustの`&[f32; phoneme_size * length as usize]`として解釈できなければならない。
/// - `speaker_id`はRustの`&[i64; 1]`として解釈できなければならない。
/// - `output`はRustの`&mut [f32; length as usize * 256]`として解釈できなければならない。
#[unsafe(no_mangle)] // SAFETY: voicevox_core_c_apiを構成するライブラリの中に、これと同名のシンボルは存在しない
pub unsafe extern "C" fn decode_forward(
length: i64,
phoneme_size: i64,
f0: *mut f32,
Expand All @@ -297,18 +330,24 @@ pub extern "C" fn decode_forward(
output: *mut f32,
) -> bool {
init_logger_once();
assert_aligned(f0);
assert_aligned(phoneme);
assert_aligned(speaker_id);
assert_aligned(output);
let length = length as usize;
let phoneme_size = phoneme_size as usize;
let synthesizer = &*lock_synthesizer();
let result = ensure_initialized!(synthesizer).decode(
length,
phoneme_size,
// SAFETY: The safety contract must be upheld by the caller.
unsafe { std::slice::from_raw_parts(f0, length) },
unsafe { std::slice::from_raw_parts(phoneme, phoneme_size * length) },
StyleId::new(unsafe { *speaker_id as u32 }),
);
match result {
Ok(output_vec) => {
// SAFETY: The safety contract must be upheld by the caller.
let output_slice = unsafe { std::slice::from_raw_parts_mut(output, length * 256) };
output_slice.clone_from_slice(&output_vec);
true
Expand All @@ -319,3 +358,11 @@ pub extern "C" fn decode_forward(
}
}
}

#[track_caller]
fn assert_aligned(ptr: *mut impl Sized) {
assert!(
ptr.is_aligned(),
"all of the pointers passed to this library **must** be aligned",
);
}

0 comments on commit 601acd9

Please sign in to comment.