Skip to content

Commit

Permalink
ONNX Runtimeとモデルのシグネチャを隔離する (#675)
Browse files (browse the repository at this point in the history)
  • Loading branch information
qryxip authored Nov 16, 2023
1 parent 302437f commit 68edcc9
Show file tree
Hide file tree
Showing 21 changed files with 1,610 additions and 639 deletions.
59 changes: 59 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

12 changes: 3 additions & 9 deletions Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,13 +1,5 @@
[workspace]
members = [
"crates/downloader",
"crates/test_util",
"crates/voicevox_core",
"crates/voicevox_core_c_api",
"crates/voicevox_core_java_api",
"crates/voicevox_core_python_api",
"crates/xtask"
]
members = ["crates/*"]
resolver = "2"

[workspace.dependencies]
Expand All @@ -18,7 +10,9 @@ derive_more = "0.99.17"
easy-ext = "1.0.1"
fs-err = { version = "2.9.0", features = ["tokio"] }
futures = "0.3.26"
indexmap = { version = "2.0.0", features = ["serde"] }
itertools = "0.10.5"
ndarray = "0.15.6"
once_cell = "1.18.0"
regex = "1.10.0"
rstest = "0.15.0"
Expand Down
6 changes: 5 additions & 1 deletion crates/voicevox_core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,14 @@ derive-new = "0.5.9"
derive_more.workspace = true
duplicate = "1.0.0"
easy-ext.workspace = true
educe = "0.4.23"
enum-map = "3.0.0-beta.1"
fs-err.workspace = true
futures.workspace = true
indexmap = { version = "2.0.0", features = ["serde"] }
indexmap.workspace = true
itertools.workspace = true
nanoid = "0.4.0"
ndarray.workspace = true
once_cell.workspace = true
regex.workspace = true
serde.workspace = true
Expand All @@ -31,6 +34,7 @@ thiserror.workspace = true
tokio.workspace = true
tracing.workspace = true
uuid.workspace = true
voicevox_core_macros = { path = "../voicevox_core_macros" }

[dependencies.onnxruntime]
git = "https://github.com/VOICEVOX/onnxruntime-rs.git"
Expand Down
26 changes: 5 additions & 21 deletions crates/voicevox_core/src/devices.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use serde::{Deserialize, Serialize};

use super::*;
use crate::{infer::InferenceRuntime, synthesizer::InferenceRuntimeImpl};

/// このライブラリで利用可能なデバイスの情報。
///
Expand All @@ -11,21 +12,21 @@ pub struct SupportedDevices {
/// CPUが利用可能。
///
/// 常に`true`。
cpu: bool,
pub cpu: bool,
/// CUDAが利用可能。
///
/// ONNX Runtimeの[CUDA Execution Provider] (`CUDAExecutionProvider`)に対応する。必要な環境につ
/// いてはそちらを参照。
///
/// [CUDA Execution Provider]: https://onnxruntime.ai/docs/execution-providers/CUDA-ExecutionProvider.html
cuda: bool,
pub cuda: bool,
/// DirectMLが利用可能。
///
/// ONNX Runtimeの[DirectML Execution Provider] (`DmlExecutionProvider`)に対応する。必要な環境に
/// ついてはそちらを参照。
///
/// [DirectML Execution Provider]: https://onnxruntime.ai/docs/execution-providers/DirectML-ExecutionProvider.html
dml: bool,
pub dml: bool,
}

impl SupportedDevices {
Expand All @@ -42,24 +43,7 @@ impl SupportedDevices {
/// # Result::<_, anyhow::Error>::Ok(())
/// ```
pub fn create() -> Result<Self> {
let mut cuda_support = false;
let mut dml_support = false;
for provider in onnxruntime::session::get_available_providers()
.map_err(ErrorRepr::GetSupportedDevices)?
.iter()
{
match provider.as_str() {
"CUDAExecutionProvider" => cuda_support = true,
"DmlExecutionProvider" => dml_support = true,
_ => {}
}
}

Ok(SupportedDevices {
cpu: true,
cuda: cuda_support,
dml: dml_support,
})
<InferenceRuntimeImpl as InferenceRuntime>::supported_devices()
}

pub fn to_json(&self) -> serde_json::Value {
Expand Down
54 changes: 18 additions & 36 deletions crates/voicevox_core/src/engine/synthesis_engine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use std::sync::Arc;
use super::full_context_label::Utterance;
use super::open_jtalk::OpenJtalk;
use super::*;
use crate::infer::InferenceRuntime;
use crate::numerics::F32Ext as _;
use crate::InferenceCore;

Expand All @@ -14,19 +15,16 @@ const MORA_PHONEME_LIST: &[&str] = &[
"a", "i", "u", "e", "o", "N", "A", "I", "U", "E", "O", "cl", "pau",
];

pub const DEFAULT_SAMPLING_RATE: u32 = 24000;

#[derive(new)]
pub struct SynthesisEngine {
inference_core: InferenceCore,
pub(crate) struct SynthesisEngine<R: InferenceRuntime> {
inference_core: InferenceCore<R>,
open_jtalk: Arc<OpenJtalk>,
}

#[allow(unsafe_code)]
unsafe impl Send for SynthesisEngine {}

impl SynthesisEngine {
pub const DEFAULT_SAMPLING_RATE: u32 = 24000;

pub fn inference_core(&self) -> &InferenceCore {
impl<R: InferenceRuntime> SynthesisEngine<R> {
pub fn inference_core(&self) -> &InferenceCore<R> {
&self.inference_core
}

Expand Down Expand Up @@ -123,7 +121,7 @@ impl SynthesisEngine {
accent_phrases: &[AccentPhraseModel],
style_id: StyleId,
) -> Result<Vec<AccentPhraseModel>> {
let (_, phoneme_data_list) = SynthesisEngine::initial_process(accent_phrases);
let (_, phoneme_data_list) = Self::initial_process(accent_phrases);

let (_, _, vowel_indexes_data) = split_mora(&phoneme_data_list);

Expand Down Expand Up @@ -185,36 +183,20 @@ impl SynthesisEngine {
accent_phrases: &[AccentPhraseModel],
style_id: StyleId,
) -> Result<Vec<AccentPhraseModel>> {
let (_, phoneme_data_list) = SynthesisEngine::initial_process(accent_phrases);
let (_, phoneme_data_list) = Self::initial_process(accent_phrases);

let mut base_start_accent_list = vec![0];
let mut base_end_accent_list = vec![0];
let mut base_start_accent_phrase_list = vec![0];
let mut base_end_accent_phrase_list = vec![0];
for accent_phrase in accent_phrases {
let mut accent = usize::from(*accent_phrase.accent() != 1);
SynthesisEngine::create_one_accent_list(
&mut base_start_accent_list,
accent_phrase,
accent as i32,
);
Self::create_one_accent_list(&mut base_start_accent_list, accent_phrase, accent as i32);

accent = *accent_phrase.accent() - 1;
SynthesisEngine::create_one_accent_list(
&mut base_end_accent_list,
accent_phrase,
accent as i32,
);
SynthesisEngine::create_one_accent_list(
&mut base_start_accent_phrase_list,
accent_phrase,
0,
);
SynthesisEngine::create_one_accent_list(
&mut base_end_accent_phrase_list,
accent_phrase,
-1,
);
Self::create_one_accent_list(&mut base_end_accent_list, accent_phrase, accent as i32);
Self::create_one_accent_list(&mut base_start_accent_phrase_list, accent_phrase, 0);
Self::create_one_accent_list(&mut base_end_accent_phrase_list, accent_phrase, -1);
}
base_start_accent_list.push(0);
base_end_accent_list.push(0);
Expand Down Expand Up @@ -328,7 +310,7 @@ impl SynthesisEngine {
query.accent_phrases().clone()
};

let (flatten_moras, phoneme_data_list) = SynthesisEngine::initial_process(&accent_phrases);
let (flatten_moras, phoneme_data_list) = Self::initial_process(&accent_phrases);

let mut phoneme_length_list = vec![pre_phoneme_length];
let mut f0_list = vec![0.];
Expand Down Expand Up @@ -440,7 +422,7 @@ impl SynthesisEngine {
let num_channels: u16 = if output_stereo { 2 } else { 1 };
let bit_depth: u16 = 16;
let repeat_count: u32 =
(output_sampling_rate / Self::DEFAULT_SAMPLING_RATE) * num_channels as u32;
(output_sampling_rate / DEFAULT_SAMPLING_RATE) * num_channels as u32;
let block_size: u16 = bit_depth * num_channels / 8;

let bytes_size = wave.len() as u32 * repeat_count * 2;
Expand Down Expand Up @@ -647,12 +629,12 @@ mod tests {
use ::test_util::OPEN_JTALK_DIC_DIR;
use pretty_assertions::assert_eq;

use crate::*;
use crate::{synthesizer::InferenceRuntimeImpl, *};

#[rstest]
#[tokio::test]
async fn is_openjtalk_dict_loaded_works() {
let core = InferenceCore::new(false, 0).unwrap();
let core = InferenceCore::<InferenceRuntimeImpl>::new(false, 0).unwrap();
let synthesis_engine =
SynthesisEngine::new(core, OpenJtalk::new(OPEN_JTALK_DIC_DIR).unwrap().into());

Expand All @@ -662,7 +644,7 @@ mod tests {
#[rstest]
#[tokio::test]
async fn create_accent_phrases_works() {
let core = InferenceCore::new(false, 0).unwrap();
let core = InferenceCore::<InferenceRuntimeImpl>::new(false, 0).unwrap();

let model = &VoiceModel::sample().await.unwrap();
core.load_model(model).await.unwrap();
Expand Down
3 changes: 1 addition & 2 deletions crates/voicevox_core/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ use self::engine::{FullContextLabelError, KanaParseError};
use super::*;
//use engine::
use duplicate::duplicate_item;
use onnxruntime::OrtError;
use std::path::PathBuf;
use thiserror::Error;
use uuid::Uuid;
Expand Down Expand Up @@ -65,7 +64,7 @@ pub(crate) enum ErrorRepr {
LoadModel(#[from] LoadModelError),

#[error("サポートされているデバイス情報取得中にエラーが発生しました")]
GetSupportedDevices(#[source] OrtError),
GetSupportedDevices(#[source] anyhow::Error),

#[error(
"`{style_id}`に対するスタイルが見つかりませんでした。音声モデルが読み込まれていないか、読\
Expand Down
Loading

0 comments on commit 68edcc9

Please sign in to comment.