From 68edcc959aa5c730e618d16fbdaef7f6ddc1e33e Mon Sep 17 00:00:00 2001 From: Ryo Yamashita Date: Thu, 16 Nov 2023 11:48:48 +0900 Subject: [PATCH] =?UTF-8?q?ONNX=20Runtime=E3=81=A8=E3=83=A2=E3=83=87?= =?UTF-8?q?=E3=83=AB=E3=81=AE=E3=82=B7=E3=82=B0=E3=83=8D=E3=83=81=E3=83=A3?= =?UTF-8?q?=E3=82=92=E9=9A=94=E9=9B=A2=E3=81=99=E3=82=8B=20(#675)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- Cargo.lock | 59 ++ Cargo.toml | 12 +- crates/voicevox_core/Cargo.toml | 6 +- crates/voicevox_core/src/devices.rs | 26 +- .../src/engine/synthesis_engine.rs | 54 +- crates/voicevox_core/src/error.rs | 3 +- crates/voicevox_core/src/infer.rs | 188 +++++++ crates/voicevox_core/src/infer/domain.rs | 86 +++ .../src/{status => infer}/model_file.rs | 0 crates/voicevox_core/src/infer/runtimes.rs | 3 + .../src/infer/runtimes/onnxruntime.rs | 239 +++++++++ crates/voicevox_core/src/infer/status.rs | 408 ++++++++++++++ crates/voicevox_core/src/inference_core.rs | 122 +++-- crates/voicevox_core/src/lib.rs | 2 +- crates/voicevox_core/src/status.rs | 504 ------------------ crates/voicevox_core/src/synthesizer.rs | 14 +- crates/voicevox_core/src/voice_model.rs | 23 +- crates/voicevox_core_c_api/Cargo.toml | 2 +- crates/voicevox_core_macros/Cargo.toml | 15 + .../src/inference_domain.rs | 379 +++++++++++++ crates/voicevox_core_macros/src/lib.rs | 104 ++++ 21 files changed, 1610 insertions(+), 639 deletions(-) create mode 100644 crates/voicevox_core/src/infer.rs create mode 100644 crates/voicevox_core/src/infer/domain.rs rename crates/voicevox_core/src/{status => infer}/model_file.rs (100%) create mode 100644 crates/voicevox_core/src/infer/runtimes.rs create mode 100644 crates/voicevox_core/src/infer/runtimes/onnxruntime.rs create mode 100644 crates/voicevox_core/src/infer/status.rs delete mode 100644 crates/voicevox_core/src/status.rs create mode 100644 crates/voicevox_core_macros/Cargo.toml create mode 100644 
crates/voicevox_core_macros/src/inference_domain.rs create mode 100644 crates/voicevox_core_macros/src/lib.rs diff --git a/Cargo.lock b/Cargo.lock index 50868f63b..c7c85b43b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1205,6 +1205,18 @@ version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "49457524c7e65648794c98283282a0b7c73b10018e7091f1cdcfff314fd7ae59" +[[package]] +name = "educe" +version = "0.4.23" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0f0042ff8246a363dbe77d2ceedb073339e85a804b9a47636c6e016a9a32c05f" +dependencies = [ + "enum-ordinalize", + "proc-macro2", + "quote", + "syn 1.0.102", +] + [[package]] name = "either" version = "1.8.0" @@ -1226,6 +1238,39 @@ dependencies = [ "cfg-if", ] +[[package]] +name = "enum-map" +version = "3.0.0-beta.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4e698c4fb1d30d2aeaf3b169ca72fbc019a049d7c85acc7f91d5f58a22e3ee13" +dependencies = [ + "enum-map-derive", +] + +[[package]] +name = "enum-map-derive" +version = "1.0.0-0.gat.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "48c69b3965971f5d0ea6a6dd26b55cdd517ae0e1425dc8d94e482a5915bd7ddf" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.38", +] + +[[package]] +name = "enum-ordinalize" +version = "3.1.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1bf1fa3f06bbff1ea5b1a9c7b14aa992a39657db60a2759457328d7e058f49ee" +dependencies = [ + "num-bigint", + "num-traits", + "proc-macro2", + "quote", + "syn 2.0.38", +] + [[package]] name = "env_logger" version = "0.9.1" @@ -4273,6 +4318,8 @@ dependencies = [ "derive_more", "duplicate", "easy-ext", + "educe", + "enum-map", "fs-err", "futures", "heck", @@ -4280,6 +4327,7 @@ dependencies = [ "indexmap 2.0.0", "itertools", "nanoid", + "ndarray", "once_cell", "onnxruntime", "open_jtalk", @@ -4294,6 +4342,7 @@ dependencies = [ "tokio", "tracing", "uuid", + 
"voicevox_core_macros", "windows", ] @@ -4355,6 +4404,16 @@ dependencies = [ "voicevox_core", ] +[[package]] +name = "voicevox_core_macros" +version = "0.0.0" +dependencies = [ + "indexmap 2.0.0", + "proc-macro2", + "quote", + "syn 2.0.38", +] + [[package]] name = "voicevox_core_python_api" version = "0.0.0" diff --git a/Cargo.toml b/Cargo.toml index b6237098a..65353e6ad 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,13 +1,5 @@ [workspace] -members = [ - "crates/downloader", - "crates/test_util", - "crates/voicevox_core", - "crates/voicevox_core_c_api", - "crates/voicevox_core_java_api", - "crates/voicevox_core_python_api", - "crates/xtask" -] +members = ["crates/*"] resolver = "2" [workspace.dependencies] @@ -18,7 +10,9 @@ derive_more = "0.99.17" easy-ext = "1.0.1" fs-err = { version = "2.9.0", features = ["tokio"] } futures = "0.3.26" +indexmap = { version = "2.0.0", features = ["serde"] } itertools = "0.10.5" +ndarray = "0.15.6" once_cell = "1.18.0" regex = "1.10.0" rstest = "0.15.0" diff --git a/crates/voicevox_core/Cargo.toml b/crates/voicevox_core/Cargo.toml index 3a23b794a..763b69605 100644 --- a/crates/voicevox_core/Cargo.toml +++ b/crates/voicevox_core/Cargo.toml @@ -17,11 +17,14 @@ derive-new = "0.5.9" derive_more.workspace = true duplicate = "1.0.0" easy-ext.workspace = true +educe = "0.4.23" +enum-map = "3.0.0-beta.1" fs-err.workspace = true futures.workspace = true -indexmap = { version = "2.0.0", features = ["serde"] } +indexmap.workspace = true itertools.workspace = true nanoid = "0.4.0" +ndarray.workspace = true once_cell.workspace = true regex.workspace = true serde.workspace = true @@ -31,6 +34,7 @@ thiserror.workspace = true tokio.workspace = true tracing.workspace = true uuid.workspace = true +voicevox_core_macros = { path = "../voicevox_core_macros" } [dependencies.onnxruntime] git = "https://github.com/VOICEVOX/onnxruntime-rs.git" diff --git a/crates/voicevox_core/src/devices.rs b/crates/voicevox_core/src/devices.rs index 70847cb81..545b5e485 
100644 --- a/crates/voicevox_core/src/devices.rs +++ b/crates/voicevox_core/src/devices.rs @@ -1,6 +1,7 @@ use serde::{Deserialize, Serialize}; use super::*; +use crate::{infer::InferenceRuntime, synthesizer::InferenceRuntimeImpl}; /// このライブラリで利用可能なデバイスの情報。 /// @@ -11,21 +12,21 @@ pub struct SupportedDevices { /// CPUが利用可能。 /// /// 常に`true`。 - cpu: bool, + pub cpu: bool, /// CUDAが利用可能。 /// /// ONNX Runtimeの[CUDA Execution Provider] (`CUDAExecutionProvider`)に対応する。必要な環境につ /// いてはそちらを参照。 /// /// [CUDA Execution Provider]: https://onnxruntime.ai/docs/execution-providers/CUDA-ExecutionProvider.html - cuda: bool, + pub cuda: bool, /// DirectMLが利用可能。 /// /// ONNX Runtimeの[DirectML Execution Provider] (`DmlExecutionProvider`)に対応する。必要な環境に /// ついてはそちらを参照。 /// /// [DirectML Execution Provider]: https://onnxruntime.ai/docs/execution-providers/DirectML-ExecutionProvider.html - dml: bool, + pub dml: bool, } impl SupportedDevices { @@ -42,24 +43,7 @@ impl SupportedDevices { /// # Result::<_, anyhow::Error>::Ok(()) /// ``` pub fn create() -> Result { - let mut cuda_support = false; - let mut dml_support = false; - for provider in onnxruntime::session::get_available_providers() - .map_err(ErrorRepr::GetSupportedDevices)? 
- .iter() - { - match provider.as_str() { - "CUDAExecutionProvider" => cuda_support = true, - "DmlExecutionProvider" => dml_support = true, - _ => {} - } - } - - Ok(SupportedDevices { - cpu: true, - cuda: cuda_support, - dml: dml_support, - }) + ::supported_devices() } pub fn to_json(&self) -> serde_json::Value { diff --git a/crates/voicevox_core/src/engine/synthesis_engine.rs b/crates/voicevox_core/src/engine/synthesis_engine.rs index 22ced6f84..c70742f16 100644 --- a/crates/voicevox_core/src/engine/synthesis_engine.rs +++ b/crates/voicevox_core/src/engine/synthesis_engine.rs @@ -5,6 +5,7 @@ use std::sync::Arc; use super::full_context_label::Utterance; use super::open_jtalk::OpenJtalk; use super::*; +use crate::infer::InferenceRuntime; use crate::numerics::F32Ext as _; use crate::InferenceCore; @@ -14,19 +15,16 @@ const MORA_PHONEME_LIST: &[&str] = &[ "a", "i", "u", "e", "o", "N", "A", "I", "U", "E", "O", "cl", "pau", ]; +pub const DEFAULT_SAMPLING_RATE: u32 = 24000; + #[derive(new)] -pub struct SynthesisEngine { - inference_core: InferenceCore, +pub(crate) struct SynthesisEngine { + inference_core: InferenceCore, open_jtalk: Arc, } -#[allow(unsafe_code)] -unsafe impl Send for SynthesisEngine {} - -impl SynthesisEngine { - pub const DEFAULT_SAMPLING_RATE: u32 = 24000; - - pub fn inference_core(&self) -> &InferenceCore { +impl SynthesisEngine { + pub fn inference_core(&self) -> &InferenceCore { &self.inference_core } @@ -123,7 +121,7 @@ impl SynthesisEngine { accent_phrases: &[AccentPhraseModel], style_id: StyleId, ) -> Result> { - let (_, phoneme_data_list) = SynthesisEngine::initial_process(accent_phrases); + let (_, phoneme_data_list) = Self::initial_process(accent_phrases); let (_, _, vowel_indexes_data) = split_mora(&phoneme_data_list); @@ -185,7 +183,7 @@ impl SynthesisEngine { accent_phrases: &[AccentPhraseModel], style_id: StyleId, ) -> Result> { - let (_, phoneme_data_list) = SynthesisEngine::initial_process(accent_phrases); + let (_, phoneme_data_list) = 
Self::initial_process(accent_phrases); let mut base_start_accent_list = vec![0]; let mut base_end_accent_list = vec![0]; @@ -193,28 +191,12 @@ impl SynthesisEngine { let mut base_end_accent_phrase_list = vec![0]; for accent_phrase in accent_phrases { let mut accent = usize::from(*accent_phrase.accent() != 1); - SynthesisEngine::create_one_accent_list( - &mut base_start_accent_list, - accent_phrase, - accent as i32, - ); + Self::create_one_accent_list(&mut base_start_accent_list, accent_phrase, accent as i32); accent = *accent_phrase.accent() - 1; - SynthesisEngine::create_one_accent_list( - &mut base_end_accent_list, - accent_phrase, - accent as i32, - ); - SynthesisEngine::create_one_accent_list( - &mut base_start_accent_phrase_list, - accent_phrase, - 0, - ); - SynthesisEngine::create_one_accent_list( - &mut base_end_accent_phrase_list, - accent_phrase, - -1, - ); + Self::create_one_accent_list(&mut base_end_accent_list, accent_phrase, accent as i32); + Self::create_one_accent_list(&mut base_start_accent_phrase_list, accent_phrase, 0); + Self::create_one_accent_list(&mut base_end_accent_phrase_list, accent_phrase, -1); } base_start_accent_list.push(0); base_end_accent_list.push(0); @@ -328,7 +310,7 @@ impl SynthesisEngine { query.accent_phrases().clone() }; - let (flatten_moras, phoneme_data_list) = SynthesisEngine::initial_process(&accent_phrases); + let (flatten_moras, phoneme_data_list) = Self::initial_process(&accent_phrases); let mut phoneme_length_list = vec![pre_phoneme_length]; let mut f0_list = vec![0.]; @@ -440,7 +422,7 @@ impl SynthesisEngine { let num_channels: u16 = if output_stereo { 2 } else { 1 }; let bit_depth: u16 = 16; let repeat_count: u32 = - (output_sampling_rate / Self::DEFAULT_SAMPLING_RATE) * num_channels as u32; + (output_sampling_rate / DEFAULT_SAMPLING_RATE) * num_channels as u32; let block_size: u16 = bit_depth * num_channels / 8; let bytes_size = wave.len() as u32 * repeat_count * 2; @@ -647,12 +629,12 @@ mod tests { use 
::test_util::OPEN_JTALK_DIC_DIR; use pretty_assertions::assert_eq; - use crate::*; + use crate::{synthesizer::InferenceRuntimeImpl, *}; #[rstest] #[tokio::test] async fn is_openjtalk_dict_loaded_works() { - let core = InferenceCore::new(false, 0).unwrap(); + let core = InferenceCore::::new(false, 0).unwrap(); let synthesis_engine = SynthesisEngine::new(core, OpenJtalk::new(OPEN_JTALK_DIC_DIR).unwrap().into()); @@ -662,7 +644,7 @@ mod tests { #[rstest] #[tokio::test] async fn create_accent_phrases_works() { - let core = InferenceCore::new(false, 0).unwrap(); + let core = InferenceCore::::new(false, 0).unwrap(); let model = &VoiceModel::sample().await.unwrap(); core.load_model(model).await.unwrap(); diff --git a/crates/voicevox_core/src/error.rs b/crates/voicevox_core/src/error.rs index 44451ece5..043b51991 100644 --- a/crates/voicevox_core/src/error.rs +++ b/crates/voicevox_core/src/error.rs @@ -2,7 +2,6 @@ use self::engine::{FullContextLabelError, KanaParseError}; use super::*; //use engine:: use duplicate::duplicate_item; -use onnxruntime::OrtError; use std::path::PathBuf; use thiserror::Error; use uuid::Uuid; @@ -65,7 +64,7 @@ pub(crate) enum ErrorRepr { LoadModel(#[from] LoadModelError), #[error("サポートされているデバイス情報取得中にエラーが発生しました")] - GetSupportedDevices(#[source] OrtError), + GetSupportedDevices(#[source] anyhow::Error), #[error( "`{style_id}`に対するスタイルが見つかりませんでした。音声モデルが読み込まれていないか、読\ diff --git a/crates/voicevox_core/src/infer.rs b/crates/voicevox_core/src/infer.rs new file mode 100644 index 000000000..c6b81348a --- /dev/null +++ b/crates/voicevox_core/src/infer.rs @@ -0,0 +1,188 @@ +pub(crate) mod domain; +mod model_file; +pub(crate) mod runtimes; +pub(crate) mod status; + +use std::{borrow::Cow, fmt::Debug}; + +use derive_new::new; +use enum_map::{Enum, EnumMap}; +use ndarray::{Array, ArrayD, Dimension, ShapeError}; +use thiserror::Error; + +use crate::SupportedDevices; + +pub(crate) trait InferenceRuntime: 'static { + type Session: Sized + Send + 'static; + type 
RunContext<'a>: From<&'a mut Self::Session>; + + fn supported_devices() -> crate::Result; + + #[allow(clippy::type_complexity)] + fn new_session( + model: impl FnOnce() -> std::result::Result, DecryptModelError>, + options: InferenceSessionOptions, + ) -> anyhow::Result<( + Self::Session, + Vec>, + Vec>, + )>; + + fn push_input( + input: Array, + ctx: &mut Self::RunContext<'_>, + ); + + fn run(ctx: Self::RunContext<'_>) -> anyhow::Result>; +} + +/// ある`VoiceModel`が提供する推論操作の集合を示す。 +pub(crate) trait InferenceDomain { + type Operation: InferenceOperation; +} + +/// `InferenceDomain`の推論操作を表す列挙型。 +/// +/// それぞれのバリアントには、対応する`InferenceSignature`が存在するべきである。 +/// +/// `::macros::InferenceOperation`により導出される。 +pub(crate) trait InferenceOperation: Copy + Enum { + /// `{InferenceInputSignature,InferenceOutputSignature}::PARAM_INFOS`を集めたもの。 + #[allow(clippy::type_complexity)] + const PARAM_INFOS: EnumMap< + Self, + ( + &'static [ParamInfo], + &'static [ParamInfo], + ), + >; +} + +/// `InferenceDomain`の推論操作を表す列挙型。 +/// +/// `::macros::InferenceOperation`により、具体型ごと生成される。 +pub(crate) trait InferenceSignature: Sized + Send + 'static { + type Domain: InferenceDomain; + type Input: InferenceInputSignature; + type Output: InferenceOutputSignature; + const OPERATION: ::Operation; +} + +/// 推論操作の入力シグネチャ。 +/// +/// `::macros::InferenceInputSignature`により導出される。 +pub(crate) trait InferenceInputSignature: Send + 'static { + type Signature: InferenceSignature; + const PARAM_INFOS: &'static [ParamInfo]; + fn make_run_context(self, sess: &mut R::Session) -> R::RunContext<'_>; +} + +pub(crate) trait InputScalar: sealed::InputScalar + Debug + 'static { + const KIND: InputScalarKind; +} + +impl InputScalar for i64 { + const KIND: InputScalarKind = InputScalarKind::Int64; +} + +impl InputScalar for f32 { + const KIND: InputScalarKind = InputScalarKind::Float32; +} + +#[derive(Clone, Copy, PartialEq, derive_more::Display)] +pub(crate) enum InputScalarKind { + #[display(fmt = "int64_t")] + Int64, + + 
#[display(fmt = "float")] + Float32, +} + +/// 推論操作の出力シグネチャ。 +/// +/// `::macros::InferenceOutputSignature`により、`TryFrom`も含めて導出される。 +pub(crate) trait InferenceOutputSignature: + TryFrom, Error = anyhow::Error> + Send +{ + const PARAM_INFOS: &'static [ParamInfo]; +} + +pub(crate) trait OutputScalar: Sized { + const KIND: OutputScalarKind; + fn extract(tensor: OutputTensor) -> std::result::Result, ExtractError>; +} + +impl OutputScalar for f32 { + const KIND: OutputScalarKind = OutputScalarKind::Float32; + + fn extract(tensor: OutputTensor) -> std::result::Result, ExtractError> { + match tensor { + OutputTensor::Float32(tensor) => Ok(tensor), + } + } +} + +#[derive(Clone, Copy, PartialEq, derive_more::Display)] +pub(crate) enum OutputScalarKind { + #[display(fmt = "float")] + Float32, +} + +pub(crate) enum OutputTensor { + Float32(ArrayD), +} + +impl TryFrom for Array { + type Error = ExtractError; + + fn try_from(tensor: OutputTensor) -> Result { + let this = A::extract(tensor)?.into_dimensionality()?; + Ok(this) + } +} + +pub(crate) struct ParamInfo { + name: Cow<'static, str>, + dt: D, + ndim: Option, +} + +impl ParamInfo { + fn accepts(&self, other: &Self) -> bool { + self.name == other.name + && self.dt == other.dt + && (self.ndim.is_none() || self.ndim == other.ndim) + } +} + +#[derive(new, Clone, Copy, PartialEq, Debug)] +pub(crate) struct InferenceSessionOptions { + pub(crate) cpu_num_threads: u16, + pub(crate) use_gpu: bool, +} + +#[derive(Error, Debug)] +pub(crate) enum ExtractError { + #[error(transparent)] + Shape(#[from] ShapeError), +} + +#[derive(Error, Debug)] +#[error("不正なモデルファイルです")] +pub(crate) struct DecryptModelError; + +// FIXME: `onnxruntime::TypeToTensorElementDataType`に依存する代わりに、`InputScalar`から`runtimes` +// まではvisitor patternでつなぐ +mod sealed { + pub(crate) trait InputScalar: OnnxruntimeInputScalar {} + + impl InputScalar for i64 {} + impl InputScalar for f32 {} + + pub(crate) trait OnnxruntimeInputScalar: + 
onnxruntime::TypeToTensorElementDataType + { + } + + impl OnnxruntimeInputScalar for T {} +} diff --git a/crates/voicevox_core/src/infer/domain.rs b/crates/voicevox_core/src/infer/domain.rs new file mode 100644 index 000000000..bb83886dd --- /dev/null +++ b/crates/voicevox_core/src/infer/domain.rs @@ -0,0 +1,86 @@ +use enum_map::Enum; +use macros::{InferenceInputSignature, InferenceOperation, InferenceOutputSignature}; +use ndarray::{Array0, Array1, Array2}; + +use super::{ + InferenceDomain, InferenceInputSignature as _, InferenceOutputSignature as _, OutputTensor, +}; + +pub(crate) enum InferenceDomainImpl {} + +impl InferenceDomain for InferenceDomainImpl { + type Operation = InferenceOperationImpl; +} + +#[derive(Clone, Copy, Enum, InferenceOperation)] +#[inference_operation( + type Domain = InferenceDomainImpl; +)] +pub(crate) enum InferenceOperationImpl { + #[inference_operation( + type Input = PredictDurationInput; + type Output = PredictDurationOutput; + )] + PredictDuration, + + #[inference_operation( + type Input = PredictIntonationInput; + type Output = PredictIntonationOutput; + )] + PredictIntonation, + + #[inference_operation( + type Input = DecodeInput; + type Output = DecodeOutput; + )] + Decode, +} + +#[derive(InferenceInputSignature)] +#[inference_input_signature( + type Signature = PredictDuration; +)] +pub(crate) struct PredictDurationInput { + pub(crate) phoneme_list: Array1, + pub(crate) speaker_id: Array1, +} + +#[derive(InferenceOutputSignature)] +pub(crate) struct PredictDurationOutput { + pub(crate) phoneme_length: Array1, +} + +#[derive(InferenceInputSignature)] +#[inference_input_signature( + type Signature = PredictIntonation; +)] +pub(crate) struct PredictIntonationInput { + pub(crate) length: Array0, + pub(crate) vowel_phoneme_list: Array1, + pub(crate) consonant_phoneme_list: Array1, + pub(crate) start_accent_list: Array1, + pub(crate) end_accent_list: Array1, + pub(crate) start_accent_phrase_list: Array1, + pub(crate) 
end_accent_phrase_list: Array1, + pub(crate) speaker_id: Array1, +} + +#[derive(InferenceOutputSignature)] +pub(crate) struct PredictIntonationOutput { + pub(crate) f0_list: Array1, +} + +#[derive(InferenceInputSignature)] +#[inference_input_signature( + type Signature = Decode; +)] +pub(crate) struct DecodeInput { + pub(crate) f0: Array2, + pub(crate) phoneme: Array2, + pub(crate) speaker_id: Array1, +} + +#[derive(InferenceOutputSignature)] +pub(crate) struct DecodeOutput { + pub(crate) wave: Array1, +} diff --git a/crates/voicevox_core/src/status/model_file.rs b/crates/voicevox_core/src/infer/model_file.rs similarity index 100% rename from crates/voicevox_core/src/status/model_file.rs rename to crates/voicevox_core/src/infer/model_file.rs diff --git a/crates/voicevox_core/src/infer/runtimes.rs b/crates/voicevox_core/src/infer/runtimes.rs new file mode 100644 index 000000000..7934027b6 --- /dev/null +++ b/crates/voicevox_core/src/infer/runtimes.rs @@ -0,0 +1,3 @@ +mod onnxruntime; + +pub(crate) use self::onnxruntime::Onnxruntime; diff --git a/crates/voicevox_core/src/infer/runtimes/onnxruntime.rs b/crates/voicevox_core/src/infer/runtimes/onnxruntime.rs new file mode 100644 index 000000000..ca5b28aaa --- /dev/null +++ b/crates/voicevox_core/src/infer/runtimes/onnxruntime.rs @@ -0,0 +1,239 @@ +use std::{fmt::Debug, vec}; + +use anyhow::anyhow; +use ndarray::{Array, Dimension}; +use once_cell::sync::Lazy; +use onnxruntime::{ + environment::Environment, GraphOptimizationLevel, LoggingLevel, TensorElementDataType, +}; + +use crate::{devices::SupportedDevices, error::ErrorRepr}; + +use self::assert_send::AssertSend; + +use super::super::{ + DecryptModelError, InferenceRuntime, InferenceSessionOptions, InputScalar, InputScalarKind, + OutputScalarKind, OutputTensor, ParamInfo, +}; + +#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Debug)] +pub(crate) enum Onnxruntime {} + +impl InferenceRuntime for Onnxruntime { + type Session = AssertSend>; + type 
RunContext<'a> = OnnxruntimeRunContext<'a>; + + fn supported_devices() -> crate::Result { + let mut cuda_support = false; + let mut dml_support = false; + for provider in onnxruntime::session::get_available_providers() + .map_err(Into::into) + .map_err(ErrorRepr::GetSupportedDevices)? + .iter() + { + match provider.as_str() { + "CUDAExecutionProvider" => cuda_support = true, + "DmlExecutionProvider" => dml_support = true, + _ => {} + } + } + + Ok(SupportedDevices { + cpu: true, + cuda: cuda_support, + dml: dml_support, + }) + } + + fn new_session( + model: impl FnOnce() -> std::result::Result, DecryptModelError>, + options: InferenceSessionOptions, + ) -> anyhow::Result<( + Self::Session, + Vec>, + Vec>, + )> { + let mut builder = ENVIRONMENT + .new_session_builder()? + .with_optimization_level(GraphOptimizationLevel::Basic)? + .with_intra_op_num_threads(options.cpu_num_threads.into())? + .with_inter_op_num_threads(options.cpu_num_threads.into())?; + + if options.use_gpu { + #[cfg(feature = "directml")] + { + use onnxruntime::ExecutionMode; + + builder = builder + .with_disable_mem_pattern()? + .with_execution_mode(ExecutionMode::ORT_SEQUENTIAL)? 
+ .with_append_execution_provider_directml(0)?; + } + + #[cfg(not(feature = "directml"))] + { + builder = builder.with_append_execution_provider_cuda(Default::default())?; + } + } + + let model = model()?; + let sess = AssertSend::from(builder.with_model_from_memory(model)?); + + let input_param_infos = sess + .inputs + .iter() + .map(|info| { + let dt = match info.input_type { + TensorElementDataType::Float => Ok(InputScalarKind::Float32), + TensorElementDataType::Uint8 => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8"), + TensorElementDataType::Int8 => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8"), + TensorElementDataType::Uint16 => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16"), + TensorElementDataType::Int16 => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16"), + TensorElementDataType::Int32 => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32"), + TensorElementDataType::Int64 => Ok(InputScalarKind::Int64), + TensorElementDataType::String => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING"), + TensorElementDataType::Double => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE"), + TensorElementDataType::Uint32 => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32"), + TensorElementDataType::Uint64 => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64"), + } + .map_err(|actual| { + anyhow!("unsupported input datatype `{actual}` for `{}`", info.name) + })?; + + Ok(ParamInfo { + name: info.name.clone().into(), + dt, + ndim: Some(info.dimensions.len()), + }) + }) + .collect::>()?; + + let output_param_infos = sess + .outputs + .iter() + .map(|info| { + let dt = match info.output_type { + TensorElementDataType::Float => Ok(OutputScalarKind::Float32), + TensorElementDataType::Uint8 => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8"), + TensorElementDataType::Int8 => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8"), + TensorElementDataType::Uint16 => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16"), + TensorElementDataType::Int16 => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16"), + TensorElementDataType::Int32 => 
Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32"), + TensorElementDataType::Int64 => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64"), + TensorElementDataType::String => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING"), + TensorElementDataType::Double => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE"), + TensorElementDataType::Uint32 => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32"), + TensorElementDataType::Uint64 => Err("ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64"), + } + .map_err(|actual| { + anyhow!("unsupported output datatype `{actual}` for `{}`", info.name) + })?; + + Ok(ParamInfo { + name: info.name.clone().into(), + dt, + ndim: Some(info.dimensions.len()), + }) + }) + .collect::>()?; + + return Ok((sess, input_param_infos, output_param_infos)); + + static ENVIRONMENT: Lazy = Lazy::new(|| { + Environment::builder() + .with_name(env!("CARGO_PKG_NAME")) + .with_log_level(LOGGING_LEVEL) + .build() + .unwrap() + }); + + const LOGGING_LEVEL: LoggingLevel = if cfg!(debug_assertions) { + LoggingLevel::Verbose + } else { + LoggingLevel::Warning + }; + } + + fn push_input( + input: Array, + ctx: &mut Self::RunContext<'_>, + ) { + ctx.inputs + .push(Box::new(onnxruntime::session::NdArray::new(input))); + } + + fn run( + OnnxruntimeRunContext { sess, mut inputs }: OnnxruntimeRunContext<'_>, + ) -> anyhow::Result> { + // FIXME: 現状では`f32`のみ対応。実行時にsessionからdatatypeが取れるので、別の型の対応も + // おそらく可能ではあるが、それが必要になるよりもortクレートへの引越しが先になると思われる + // のでこのままにする。 + + if !sess + .outputs + .iter() + .all(|info| matches!(info.output_type, TensorElementDataType::Float)) + { + unimplemented!( + "currently only `ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT` is supported for output", + ); + } + + let outputs = sess.run::(inputs.iter_mut().map(|t| &mut **t as &mut _).collect())?; + + Ok(outputs + .iter() + .map(|o| OutputTensor::Float32((*o).clone().into_owned())) + .collect()) + } +} + +pub(crate) struct OnnxruntimeRunContext<'sess> { + sess: &'sess mut AssertSend>, + inputs: Vec>, +} + +impl<'sess> From<&'sess mut AssertSend>> + for 
OnnxruntimeRunContext<'sess> +{ + fn from(sess: &'sess mut AssertSend>) -> Self { + Self { + sess, + inputs: vec![], + } + } +} + +// FIXME: 以下のことをちゃんと確認した後、onnxruntime-rs側で`Session`が`Send`であると宣言する。 +// https://github.com/VOICEVOX/voicevox_core/issues/307#issuecomment-1276184614 +mod assert_send { + use std::ops::{Deref, DerefMut}; + + pub(crate) struct AssertSend(T); + + impl From> + for AssertSend> + { + fn from(session: onnxruntime::session::Session<'static>) -> Self { + Self(session) + } + } + + impl Deref for AssertSend { + type Target = T; + + fn deref(&self) -> &Self::Target { + &self.0 + } + } + + impl DerefMut for AssertSend { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } + } + + // SAFETY: `Session` is probably "send"able. + #[allow(unsafe_code)] + unsafe impl Send for AssertSend {} +} diff --git a/crates/voicevox_core/src/infer/status.rs b/crates/voicevox_core/src/infer/status.rs new file mode 100644 index 000000000..7903cb8ff --- /dev/null +++ b/crates/voicevox_core/src/infer/status.rs @@ -0,0 +1,408 @@ +use std::{ + collections::{BTreeMap, HashMap}, + fmt::Display, + marker::PhantomData, + sync::Arc, +}; + +use anyhow::bail; +use educe::Educe; +use enum_map::{Enum as _, EnumMap}; +use itertools::{iproduct, Itertools as _}; + +use crate::{ + error::{ErrorRepr, LoadModelError, LoadModelErrorKind, LoadModelResult}, + infer::{InferenceOperation, ParamInfo}, + manifest::ModelInnerId, + metas::{SpeakerMeta, StyleId, StyleMeta, VoiceModelMeta}, + voice_model::{VoiceModel, VoiceModelId}, + Result, +}; + +use super::{ + model_file, InferenceDomain, InferenceInputSignature, InferenceRuntime, + InferenceSessionOptions, InferenceSignature, +}; + +pub(crate) struct Status { + loaded_models: std::sync::Mutex>, + session_options: EnumMap, +} + +impl Status { + pub fn new(session_options: EnumMap) -> Self { + Self { + loaded_models: Default::default(), + session_options, + } + } + + pub async fn load_model( + &self, + model: &VoiceModel, + 
model_bytes: &EnumMap>, + ) -> Result<()> { + self.loaded_models + .lock() + .unwrap() + .ensure_acceptable(model)?; + + let session_set = + SessionSet::new(model_bytes, &self.session_options).map_err(|source| { + LoadModelError { + path: model.path().clone(), + context: LoadModelErrorKind::InvalidModelData, + source: Some(source), + } + })?; + + self.loaded_models + .lock() + .unwrap() + .insert(model, session_set)?; + Ok(()) + } + + pub fn unload_model(&self, voice_model_id: &VoiceModelId) -> Result<()> { + self.loaded_models.lock().unwrap().remove(voice_model_id) + } + + pub fn metas(&self) -> VoiceModelMeta { + self.loaded_models.lock().unwrap().metas() + } + + pub(crate) fn ids_for(&self, style_id: StyleId) -> Result<(VoiceModelId, ModelInnerId)> { + self.loaded_models.lock().unwrap().ids_for(style_id) + } + + pub fn is_loaded_model(&self, voice_model_id: &VoiceModelId) -> bool { + self.loaded_models + .lock() + .unwrap() + .contains_voice_model(voice_model_id) + } + + pub fn is_loaded_model_by_style_id(&self, style_id: StyleId) -> bool { + self.loaded_models.lock().unwrap().contains_style(style_id) + } + + pub fn validate_speaker_id(&self, style_id: StyleId) -> bool { + self.is_loaded_model_by_style_id(style_id) + } + + /// # Panics + /// + /// `self`が`model_id`を含んでいないとき、パニックする。 + pub(crate) async fn run_session( + &self, + model_id: &VoiceModelId, + input: I, + ) -> Result<::Output> + where + I: InferenceInputSignature, + I::Signature: InferenceSignature, + { + let sess = self.loaded_models.lock().unwrap().get(model_id); + + tokio::task::spawn_blocking(move || sess.run(input)) + .await + .unwrap() + } +} + +/// 読み込んだモデルの`Session`とそのメタ情報を保有し、追加/削除/取得の操作を提供する。 +/// +/// この構造体のメソッドは、すべて一瞬で完了すべきである。 +#[derive(Educe)] +#[educe(Default(bound = "R: InferenceRuntime, D: InferenceDomain"))] +struct LoadedModels( + BTreeMap>, +); + +struct LoadedModel { + model_inner_ids: BTreeMap, + metas: VoiceModelMeta, + session_set: SessionSet, +} + +impl LoadedModels { + fn 
metas(&self) -> VoiceModelMeta { + self.0 + .values() + .flat_map(|LoadedModel { metas, .. }| metas) + .cloned() + .collect() + } + + fn ids_for(&self, style_id: StyleId) -> Result<(VoiceModelId, ModelInnerId)> { + let ( + model_id, + LoadedModel { + model_inner_ids, .. + }, + ) = self + .0 + .iter() + .find(|(_, LoadedModel { metas, .. })| { + metas + .iter() + .flat_map(SpeakerMeta::styles) + .any(|style| *style.id() == style_id) + }) + .ok_or(ErrorRepr::StyleNotFound { style_id })?; + + let model_inner_id = *model_inner_ids + .get(&style_id) + .expect("`model_inner_ids` should contains all of the style IDs in the model"); + + Ok((model_id.clone(), model_inner_id)) + } + + /// # Panics + /// + /// `self`が`model_id`を含んでいないとき、パニックする。 + fn get(&self, model_id: &VoiceModelId) -> SessionCell + where + I: InferenceInputSignature, + I::Signature: InferenceSignature, + { + self.0[model_id].session_set.get() + } + + fn contains_voice_model(&self, model_id: &VoiceModelId) -> bool { + self.0.contains_key(model_id) + } + + fn contains_style(&self, style_id: StyleId) -> bool { + self.styles().any(|style| *style.id() == style_id) + } + + /// 与えられた`VoiceModel`を受け入れ可能かをチェックする。 + /// + /// # Errors + /// + /// 音声モデルIDかスタイルIDが`model`と重複するとき、エラーを返す。 + fn ensure_acceptable(&self, model: &VoiceModel) -> LoadModelResult<()> { + let loaded = self.styles(); + let external = model.metas().iter().flat_map(|speaker| speaker.styles()); + + let error = |context| LoadModelError { + path: model.path().clone(), + context, + source: None, + }; + + if self.0.contains_key(model.id()) { + return Err(error(LoadModelErrorKind::ModelAlreadyLoaded { + id: model.id().clone(), + })); + } + if let Some((style, _)) = + iproduct!(loaded, external).find(|(loaded, external)| loaded.id() == external.id()) + { + return Err(error(LoadModelErrorKind::StyleAlreadyLoaded { + id: *style.id(), + })); + } + Ok(()) + } + + fn insert(&mut self, model: &VoiceModel, session_set: SessionSet) -> Result<()> { + 
self.ensure_acceptable(model)?; + + let prev = self.0.insert( + model.id().clone(), + LoadedModel { + model_inner_ids: model.model_inner_ids(), + metas: model.metas().clone(), + session_set, + }, + ); + assert!(prev.is_none()); + Ok(()) + } + + fn remove(&mut self, model_id: &VoiceModelId) -> Result<()> { + if self.0.remove(model_id).is_none() { + return Err(ErrorRepr::ModelNotFound { + model_id: model_id.clone(), + } + .into()); + } + Ok(()) + } + + fn styles(&self) -> impl Iterator { + self.0 + .values() + .flat_map(|LoadedModel { metas, .. }| metas) + .flat_map(|speaker| speaker.styles()) + } +} + +struct SessionSet( + EnumMap>>, +); + +impl SessionSet { + fn new( + model_bytes: &EnumMap>, + options: &EnumMap, + ) -> anyhow::Result { + let mut sessions = model_bytes + .iter() + .map(|(op, model_bytes)| { + let (expected_input_param_infos, expected_output_param_infos) = + ::PARAM_INFOS[op]; + + let (sess, actual_input_param_infos, actual_output_param_infos) = + R::new_session(|| model_file::decrypt(model_bytes), options[op])?; + + check_param_infos(expected_input_param_infos, &actual_input_param_infos)?; + check_param_infos(expected_output_param_infos, &actual_output_param_infos)?; + + Ok((op.into_usize(), std::sync::Mutex::new(sess).into())) + }) + .collect::>>()?; + + return Ok(Self(EnumMap::::from_fn(|k| { + sessions.remove(&k.into_usize()).expect("should exist") + }))); + + fn check_param_infos( + expected: &[ParamInfo], + actual: &[ParamInfo], + ) -> anyhow::Result<()> { + if !(expected.len() == actual.len() + && itertools::zip_eq(expected, actual) + .all(|(expected, actual)| expected.accepts(actual))) + { + let expected = display_param_infos(expected); + let actual = display_param_infos(actual); + bail!("expected {{{expected}}}, got {{{actual}}}") + } + Ok(()) + } + + fn display_param_infos(infos: &[ParamInfo]) -> impl Display { + infos + .iter() + .map(|ParamInfo { name, dt, ndim }| { + let brackets = match *ndim { + Some(ndim) => "[]".repeat(ndim), + None 
=> "[]...".to_owned(), + }; + format!("{name}: {dt}{brackets}") + }) + .join(", ") + } + } +} + +impl SessionSet { + fn get(&self) -> SessionCell + where + I: InferenceInputSignature, + I::Signature: InferenceSignature, + { + SessionCell { + inner: self.0[I::Signature::OPERATION].clone(), + marker: PhantomData, + } + } +} + +struct SessionCell { + inner: Arc>, + marker: PhantomData, +} + +impl SessionCell { + fn run(self, input: I) -> crate::Result<::Output> { + let inner = &mut self.inner.lock().unwrap(); + let ctx = input.make_run_context::(inner); + R::run(ctx) + .and_then(TryInto::try_into) + .map_err(|e| ErrorRepr::InferenceFailed(e).into()) + } +} + +#[cfg(test)] +mod tests { + use enum_map::enum_map; + use pretty_assertions::assert_eq; + use rstest::rstest; + + use crate::{ + infer::domain::{InferenceDomainImpl, InferenceOperationImpl}, + macros::tests::assert_debug_fmt_eq, + synthesizer::InferenceRuntimeImpl, + test_util::open_default_vvm_file, + }; + + use super::{super::InferenceSessionOptions, Status}; + + #[rstest] + #[case(true, 0)] + #[case(true, 1)] + #[case(true, 8)] + #[case(false, 2)] + #[case(false, 4)] + #[case(false, 8)] + #[case(false, 0)] + fn status_new_works(#[case] use_gpu: bool, #[case] cpu_num_threads: u16) { + let light_session_options = InferenceSessionOptions::new(cpu_num_threads, false); + let heavy_session_options = InferenceSessionOptions::new(cpu_num_threads, use_gpu); + let session_options = enum_map! 
{ + InferenceOperationImpl::PredictDuration + | InferenceOperationImpl::PredictIntonation => light_session_options, + InferenceOperationImpl::Decode => heavy_session_options, + }; + let status = Status::::new(session_options); + + assert_eq!( + light_session_options, + status.session_options[InferenceOperationImpl::PredictDuration], + ); + assert_eq!( + light_session_options, + status.session_options[InferenceOperationImpl::PredictIntonation], + ); + assert_eq!( + heavy_session_options, + status.session_options[InferenceOperationImpl::Decode], + ); + + assert!(status.loaded_models.lock().unwrap().0.is_empty()); + } + + #[rstest] + #[tokio::test] + async fn status_load_model_works() { + let status = Status::::new( + enum_map!(_ => InferenceSessionOptions::new(0, false)), + ); + let model = &open_default_vvm_file().await; + let model_bytes = &model.read_inference_models().await.unwrap(); + let result = status.load_model(model, model_bytes).await; + assert_debug_fmt_eq!(Ok(()), result); + assert_eq!(1, status.loaded_models.lock().unwrap().0.len()); + } + + #[rstest] + #[tokio::test] + async fn status_is_model_loaded_works() { + let status = Status::::new( + enum_map!(_ => InferenceSessionOptions::new(0, false)), + ); + let vvm = open_default_vvm_file().await; + let model_bytes = &vvm.read_inference_models().await.unwrap(); + assert!( + !status.is_loaded_model(vvm.id()), + "model should not be loaded" + ); + let result = status.load_model(&vvm, model_bytes).await; + assert_debug_fmt_eq!(Ok(()), result); + assert!(status.is_loaded_model(vvm.id()), "model should be loaded"); + } +} diff --git a/crates/voicevox_core/src/inference_core.rs b/crates/voicevox_core/src/inference_core.rs index 4b0d08be2..875c9ba64 100644 --- a/crates/voicevox_core/src/inference_core.rs +++ b/crates/voicevox_core/src/inference_core.rs @@ -1,17 +1,37 @@ -use self::status::*; +use enum_map::enum_map; + +use crate::infer::{ + domain::{ + DecodeInput, DecodeOutput, InferenceDomainImpl, 
InferenceOperationImpl, + PredictDurationInput, PredictDurationOutput, PredictIntonationInput, + PredictIntonationOutput, + }, + status::Status, + InferenceRuntime, InferenceSessionOptions, +}; + use super::*; -use onnxruntime::{ndarray, session::NdArray}; const PHONEME_LENGTH_MINIMAL: f32 = 0.01; -pub struct InferenceCore { - status: Status, +pub(crate) struct InferenceCore { + status: Status, } -impl InferenceCore { +impl InferenceCore { pub(crate) fn new(use_gpu: bool, cpu_num_threads: u16) -> Result { if !use_gpu || Self::can_support_gpu_feature()? { - let status = Status::new(use_gpu, cpu_num_threads); + // 軽いモデルはこちらを使う + let light_session_options = InferenceSessionOptions::new(cpu_num_threads, false); + + // 重いモデルはこちらを使う + let heavy_session_options = InferenceSessionOptions::new(cpu_num_threads, use_gpu); + + let status = Status::new(enum_map! { + InferenceOperationImpl::PredictDuration + | InferenceOperationImpl::PredictIntonation => light_session_options, + InferenceOperationImpl::Decode => heavy_session_options, + }); Ok(Self { status }) } else { Err(ErrorRepr::GpuSupport.into()) @@ -31,7 +51,8 @@ impl InferenceCore { } pub async fn load_model(&self, model: &VoiceModel) -> Result<()> { - self.status.load_model(model).await + let model_bytes = &model.read_inference_models().await?; + self.status.load_model(model, model_bytes).await } pub fn unload_model(&self, voice_model_id: &VoiceModelId) -> Result<()> { @@ -60,13 +81,19 @@ impl InferenceCore { let (model_id, model_inner_id) = self.status.ids_for(style_id)?; - let phoneme_vector_array = NdArray::new(ndarray::arr1(phoneme_vector)); - let speaker_id_array = NdArray::new(ndarray::arr1(&[model_inner_id.raw_id().into()])); - - let mut output = self + let PredictDurationOutput { + phoneme_length: output, + } = self .status - .predict_duration_session_run(&model_id, phoneme_vector_array, speaker_id_array) + .run_session( + &model_id, + PredictDurationInput { + phoneme_list: ndarray::arr1(phoneme_vector), + 
speaker_id: ndarray::arr1(&[model_inner_id.raw_id().into()]), + }, + ) .await?; + let mut output = output.into_raw_vec(); for output_item in output.iter_mut() { if *output_item < PHONEME_LENGTH_MINIMAL { @@ -95,29 +122,24 @@ impl InferenceCore { let (model_id, model_inner_id) = self.status.ids_for(style_id)?; - let length_array = NdArray::new(ndarray::arr0(length as i64)); - let vowel_phoneme_vector_array = NdArray::new(ndarray::arr1(vowel_phoneme_vector)); - let consonant_phoneme_vector_array = NdArray::new(ndarray::arr1(consonant_phoneme_vector)); - let start_accent_vector_array = NdArray::new(ndarray::arr1(start_accent_vector)); - let end_accent_vector_array = NdArray::new(ndarray::arr1(end_accent_vector)); - let start_accent_phrase_vector_array = - NdArray::new(ndarray::arr1(start_accent_phrase_vector)); - let end_accent_phrase_vector_array = NdArray::new(ndarray::arr1(end_accent_phrase_vector)); - let speaker_id_array = NdArray::new(ndarray::arr1(&[model_inner_id.raw_id().into()])); - - self.status - .predict_intonation_session_run( + let PredictIntonationOutput { f0_list: output } = self + .status + .run_session( &model_id, - length_array, - vowel_phoneme_vector_array, - consonant_phoneme_vector_array, - start_accent_vector_array, - end_accent_vector_array, - start_accent_phrase_vector_array, - end_accent_phrase_vector_array, - speaker_id_array, + PredictIntonationInput { + length: ndarray::arr0(length as i64), + vowel_phoneme_list: ndarray::arr1(vowel_phoneme_vector), + consonant_phoneme_list: ndarray::arr1(consonant_phoneme_vector), + start_accent_list: ndarray::arr1(start_accent_vector), + end_accent_list: ndarray::arr1(end_accent_vector), + start_accent_phrase_list: ndarray::arr1(start_accent_phrase_vector), + end_accent_phrase_list: ndarray::arr1(end_accent_phrase_vector), + speaker_id: ndarray::arr1(&[model_inner_id.raw_id().into()]), + }, ) - .await + .await?; + + Ok(output.into_raw_vec()) } pub async fn decode( @@ -150,22 +172,26 @@ impl InferenceCore 
{ padding_size, ); - let f0_array = NdArray::new( - ndarray::arr1(&f0_with_padding) - .into_shape([length_with_padding, 1]) - .unwrap(), - ); - let phoneme_array = NdArray::new( - ndarray::arr1(&phoneme_with_padding) - .into_shape([length_with_padding, phoneme_size]) - .unwrap(), - ); - let speaker_id_array = NdArray::new(ndarray::arr1(&[model_inner_id.raw_id().into()])); + let DecodeOutput { wave: output } = self + .status + .run_session( + &model_id, + DecodeInput { + f0: ndarray::arr1(&f0_with_padding) + .into_shape([length_with_padding, 1]) + .unwrap(), + phoneme: ndarray::arr1(&phoneme_with_padding) + .into_shape([length_with_padding, phoneme_size]) + .unwrap(), + speaker_id: ndarray::arr1(&[model_inner_id.raw_id().into()]), + }, + ) + .await?; - self.status - .decode_session_run(&model_id, f0_array, phoneme_array, speaker_id_array) - .await - .map(|output| Self::trim_padding_from_output(output, padding_size)) + Ok(Self::trim_padding_from_output( + output.into_raw_vec(), + padding_size, + )) } fn make_f0_with_padding( diff --git a/crates/voicevox_core/src/lib.rs b/crates/voicevox_core/src/lib.rs index 798515fb9..dc54551a7 100644 --- a/crates/voicevox_core/src/lib.rs +++ b/crates/voicevox_core/src/lib.rs @@ -6,13 +6,13 @@ mod devices; /// cbindgen:ignore mod engine; mod error; +mod infer; mod inference_core; mod macros; mod manifest; mod metas; mod numerics; mod result; -mod status; mod synthesizer; mod user_dict; mod version; diff --git a/crates/voicevox_core/src/status.rs b/crates/voicevox_core/src/status.rs deleted file mode 100644 index 64a402683..000000000 --- a/crates/voicevox_core/src/status.rs +++ /dev/null @@ -1,504 +0,0 @@ -use super::*; -use itertools::iproduct; -use once_cell::sync::Lazy; -use onnxruntime::{ - environment::Environment, - ndarray::{Ix0, Ix1, Ix2}, - session::{NdArray, Session}, - GraphOptimizationLevel, LoggingLevel, -}; -use std::sync::Arc; -use std::{env, path::Path}; -use tracing::error; - -mod model_file; - -cfg_if! 
{ - if #[cfg(not(feature="directml"))]{ - use onnxruntime::CudaProviderOptions; - } -} -use std::collections::BTreeMap; - -pub struct Status { - loaded_models: std::sync::Mutex, - light_session_options: SessionOptions, // 軽いモデルはこちらを使う - heavy_session_options: SessionOptions, // 重いモデルはこちらを使う -} - -#[derive(new, Getters)] -struct SessionOptions { - cpu_num_threads: u16, - use_gpu: bool, -} - -#[derive(thiserror::Error, Debug)] -#[error("不正なモデルファイルです")] -struct DecryptModelError; - -static ENVIRONMENT: Lazy = Lazy::new(|| { - cfg_if! { - if #[cfg(debug_assertions)]{ - const LOGGING_LEVEL: LoggingLevel = LoggingLevel::Verbose; - } else{ - const LOGGING_LEVEL: LoggingLevel = LoggingLevel::Warning; - } - } - Environment::builder() - .with_name(env!("CARGO_PKG_NAME")) - .with_log_level(LOGGING_LEVEL) - .build() - .unwrap() -}); - -impl Status { - pub fn new(use_gpu: bool, cpu_num_threads: u16) -> Self { - Self { - loaded_models: Default::default(), - light_session_options: SessionOptions::new(cpu_num_threads, false), - heavy_session_options: SessionOptions::new(cpu_num_threads, use_gpu), - } - } - - pub async fn load_model(&self, model: &VoiceModel) -> Result<()> { - self.loaded_models - .lock() - .unwrap() - .ensure_acceptable(model)?; - - let models = model.read_inference_models().await?; - - let predict_duration_session = self.new_session( - models.predict_duration_model(), - &self.light_session_options, - model.path(), - )?; - let predict_intonation_session = self.new_session( - models.predict_intonation_model(), - &self.light_session_options, - model.path(), - )?; - let decode_model = self.new_session( - models.decode_model(), - &self.heavy_session_options, - model.path(), - )?; - - self.loaded_models.lock().unwrap().insert( - model, - predict_duration_session, - predict_intonation_session, - decode_model, - )?; - Ok(()) - } - - pub fn unload_model(&self, voice_model_id: &VoiceModelId) -> Result<()> { - self.loaded_models.lock().unwrap().remove(voice_model_id) - } - 
- pub fn metas(&self) -> VoiceModelMeta { - self.loaded_models.lock().unwrap().metas() - } - - pub(crate) fn ids_for(&self, style_id: StyleId) -> Result<(VoiceModelId, ModelInnerId)> { - self.loaded_models.lock().unwrap().ids_for(style_id) - } - - pub fn is_loaded_model(&self, voice_model_id: &VoiceModelId) -> bool { - self.loaded_models - .lock() - .unwrap() - .contains_voice_model(voice_model_id) - } - - pub fn is_loaded_model_by_style_id(&self, style_id: StyleId) -> bool { - self.loaded_models.lock().unwrap().contains_style(style_id) - } - - fn new_session( - &self, - model: &[u8], - session_options: &SessionOptions, - path: impl AsRef, - ) -> LoadModelResult> { - self.new_session_from_bytes(|| model_file::decrypt(model), session_options) - .map_err(|source| LoadModelError { - path: path.as_ref().to_owned(), - context: LoadModelErrorKind::InvalidModelData, - source: Some(source), - }) - } - - fn new_session_from_bytes( - &self, - model_bytes: impl FnOnce() -> std::result::Result, DecryptModelError>, - session_options: &SessionOptions, - ) -> anyhow::Result> { - let session_builder = ENVIRONMENT - .new_session_builder()? - .with_optimization_level(GraphOptimizationLevel::Basic)? - .with_intra_op_num_threads(*session_options.cpu_num_threads() as i32)? - .with_inter_op_num_threads(*session_options.cpu_num_threads() as i32)?; - - let session_builder = if *session_options.use_gpu() { - cfg_if! { - if #[cfg(feature = "directml")]{ - session_builder - .with_disable_mem_pattern()? - .with_execution_mode(onnxruntime::ExecutionMode::ORT_SEQUENTIAL)? - .with_append_execution_provider_directml(0)? - } else { - let options = CudaProviderOptions::default(); - session_builder.with_append_execution_provider_cuda(options)? - } - } - } else { - session_builder - }; - - Ok(session_builder.with_model_from_memory(model_bytes()?)?) 
- } - - pub fn validate_speaker_id(&self, style_id: StyleId) -> bool { - self.is_loaded_model_by_style_id(style_id) - } - - /// # Panics - /// - /// `self`が`model_id`を含んでいないとき、パニックする。 - pub async fn predict_duration_session_run( - &self, - model_id: &VoiceModelId, - mut phoneme_vector_array: NdArray, - mut speaker_id_array: NdArray, - ) -> Result> { - let predict_duration = self.loaded_models.lock().unwrap().get( - model_id, - |SessionSet { - predict_duration, .. - }| predict_duration, - ); - - tokio::task::spawn_blocking(move || { - let mut predict_duration = predict_duration.lock().unwrap(); - - let output_tensors = predict_duration - .run(vec![&mut phoneme_vector_array, &mut speaker_id_array]) - .map_err(|e| ErrorRepr::InferenceFailed(e.into()))?; - Ok(output_tensors[0].as_slice().unwrap().to_owned()) - }) - .await - .unwrap() - } - - /// # Panics - /// - /// `self`が`model_id`を含んでいないとき、パニックする。 - #[allow(clippy::too_many_arguments)] - pub async fn predict_intonation_session_run( - &self, - model_id: &VoiceModelId, - mut length_array: NdArray, - mut vowel_phoneme_vector_array: NdArray, - mut consonant_phoneme_vector_array: NdArray, - mut start_accent_vector_array: NdArray, - mut end_accent_vector_array: NdArray, - mut start_accent_phrase_vector_array: NdArray, - mut end_accent_phrase_vector_array: NdArray, - mut speaker_id_array: NdArray, - ) -> Result> { - let predict_intonation = self.loaded_models.lock().unwrap().get( - model_id, - |SessionSet { - predict_intonation, .. 
- }| predict_intonation, - ); - - tokio::task::spawn_blocking(move || { - let mut predict_intonation = predict_intonation.lock().unwrap(); - - let output_tensors = predict_intonation - .run(vec![ - &mut length_array, - &mut vowel_phoneme_vector_array, - &mut consonant_phoneme_vector_array, - &mut start_accent_vector_array, - &mut end_accent_vector_array, - &mut start_accent_phrase_vector_array, - &mut end_accent_phrase_vector_array, - &mut speaker_id_array, - ]) - .map_err(|e| ErrorRepr::InferenceFailed(e.into()))?; - Ok(output_tensors[0].as_slice().unwrap().to_owned()) - }) - .await - .unwrap() - } - - /// # Panics - /// - /// `self`が`model_id`を含んでいないとき、パニックする。 - pub async fn decode_session_run( - &self, - model_id: &VoiceModelId, - mut f0_array: NdArray, - mut phoneme_array: NdArray, - mut speaker_id_array: NdArray, - ) -> Result> { - let decode = self - .loaded_models - .lock() - .unwrap() - .get(model_id, |SessionSet { decode, .. }| decode); - - tokio::task::spawn_blocking(move || { - let mut decode = decode.lock().unwrap(); - - let output_tensors = decode - .run(vec![ - &mut f0_array, - &mut phoneme_array, - &mut speaker_id_array, - ]) - .map_err(|e| ErrorRepr::InferenceFailed(e.into()))?; - Ok(output_tensors[0].as_slice().unwrap().to_owned()) - }) - .await - .unwrap() - } -} - -/// 読み込んだモデルの`Session`とそのメタ情報を保有し、追加/削除/取得の操作を提供する。 -/// -/// この構造体のメソッドは、すべて一瞬で完了すべきである。 -#[derive(Default)] -struct LoadedModels(BTreeMap); - -struct LoadedModel { - model_inner_ids: BTreeMap, - metas: VoiceModelMeta, - session_set: SessionSet, -} - -impl LoadedModels { - fn metas(&self) -> VoiceModelMeta { - self.0 - .values() - .flat_map(|LoadedModel { metas, .. }| metas) - .cloned() - .collect() - } - - fn ids_for(&self, style_id: StyleId) -> Result<(VoiceModelId, ModelInnerId)> { - let ( - model_id, - LoadedModel { - model_inner_ids, .. - }, - ) = self - .0 - .iter() - .find(|(_, LoadedModel { metas, .. 
})| { - metas - .iter() - .flat_map(SpeakerMeta::styles) - .any(|style| *style.id() == style_id) - }) - .ok_or(ErrorRepr::StyleNotFound { style_id })?; - - let model_inner_id = *model_inner_ids - .get(&style_id) - .expect("`model_inner_ids` should contains all of the style IDs in the model"); - - Ok((model_id.clone(), model_inner_id)) - } - - /// # Panics - /// - /// `self`が`model_id`を含んでいないとき、パニックする。 - fn get( - &self, - model_id: &VoiceModelId, - which: fn(&SessionSet) -> &Arc>>>, - ) -> Arc>>> { - which(&self.0[model_id].session_set).clone() - } - - fn contains_voice_model(&self, model_id: &VoiceModelId) -> bool { - self.0.contains_key(model_id) - } - - fn contains_style(&self, style_id: StyleId) -> bool { - self.styles().any(|style| *style.id() == style_id) - } - - /// 与えられた`VoiceModel`を受け入れ可能かをチェックする。 - /// - /// # Errors - /// - /// 音声モデルIDかスタイルIDが`model`と重複するとき、エラーを返す。 - fn ensure_acceptable(&self, model: &VoiceModel) -> LoadModelResult<()> { - let loaded = self.styles(); - let external = model.metas().iter().flat_map(|speaker| speaker.styles()); - - let error = |context| LoadModelError { - path: model.path().clone(), - context, - source: None, - }; - - if self.0.contains_key(model.id()) { - return Err(error(LoadModelErrorKind::ModelAlreadyLoaded { - id: model.id().clone(), - })); - } - if let Some((style, _)) = - iproduct!(loaded, external).find(|(loaded, external)| loaded.id() == external.id()) - { - return Err(error(LoadModelErrorKind::StyleAlreadyLoaded { - id: *style.id(), - })); - } - Ok(()) - } - - fn insert( - &mut self, - model: &VoiceModel, - predict_duration: Session<'static>, - predict_intonation: Session<'static>, - decode: Session<'static>, - ) -> Result<()> { - self.ensure_acceptable(model)?; - - let prev = self.0.insert( - model.id().clone(), - LoadedModel { - model_inner_ids: model.model_inner_ids(), - metas: model.metas().clone(), - session_set: SessionSet { - predict_duration: Arc::new(std::sync::Mutex::new(predict_duration.into())), - 
predict_intonation: Arc::new(std::sync::Mutex::new(predict_intonation.into())), - decode: Arc::new(std::sync::Mutex::new(decode.into())), - }, - }, - ); - assert!(prev.is_none()); - Ok(()) - } - - fn remove(&mut self, model_id: &VoiceModelId) -> Result<()> { - if self.0.remove(model_id).is_none() { - return Err(ErrorRepr::ModelNotFound { - model_id: model_id.clone(), - } - .into()); - } - Ok(()) - } - - fn styles(&self) -> impl Iterator { - self.0 - .values() - .flat_map(|LoadedModel { metas, .. }| metas) - .flat_map(|speaker| speaker.styles()) - } -} - -struct SessionSet { - predict_duration: Arc>>>, - predict_intonation: Arc>>>, - decode: Arc>>>, -} - -// FIXME: 以下のことをちゃんと確認した後、onnxruntime-rs側で`Session`が`Send`であると宣言する。 -// https://github.com/VOICEVOX/voicevox_core/issues/307#issuecomment-1276184614 - -use self::assert_send::AssertSend; - -mod assert_send { - use std::ops::{Deref, DerefMut}; - - use onnxruntime::session::Session; - - pub(super) struct AssertSend(T); - - impl From> for AssertSend> { - fn from(session: Session<'static>) -> Self { - Self(session) - } - } - - impl Deref for AssertSend { - type Target = T; - - fn deref(&self) -> &Self::Target { - &self.0 - } - } - - impl DerefMut for AssertSend { - fn deref_mut(&mut self) -> &mut Self::Target { - &mut self.0 - } - } - - // SAFETY: `Session` is probably "send"able. 
- #[allow(unsafe_code)] - unsafe impl Send for AssertSend {} -} - -#[cfg(test)] -mod tests { - - use super::*; - use crate::macros::tests::assert_debug_fmt_eq; - use pretty_assertions::assert_eq; - - #[rstest] - #[case(true, 0)] - #[case(true, 1)] - #[case(true, 8)] - #[case(false, 2)] - #[case(false, 4)] - #[case(false, 8)] - #[case(false, 0)] - fn status_new_works(#[case] use_gpu: bool, #[case] cpu_num_threads: u16) { - let status = Status::new(use_gpu, cpu_num_threads); - assert_eq!(false, status.light_session_options.use_gpu); - assert_eq!(use_gpu, status.heavy_session_options.use_gpu); - assert_eq!( - cpu_num_threads, - status.light_session_options.cpu_num_threads - ); - assert_eq!( - cpu_num_threads, - status.heavy_session_options.cpu_num_threads - ); - assert!(status.loaded_models.lock().unwrap().0.is_empty()); - } - - #[rstest] - #[tokio::test] - async fn status_load_model_works() { - let status = Status::new(false, 0); - let result = status.load_model(&open_default_vvm_file().await).await; - assert_debug_fmt_eq!(Ok(()), result); - assert_eq!(1, status.loaded_models.lock().unwrap().0.len()); - } - - #[rstest] - #[tokio::test] - async fn status_is_model_loaded_works() { - let status = Status::new(false, 0); - let vvm = open_default_vvm_file().await; - assert!( - !status.is_loaded_model(vvm.id()), - "model should not be loaded" - ); - let result = status.load_model(&vvm).await; - assert_debug_fmt_eq!(Ok(()), result); - assert!(status.is_loaded_model(vvm.id()), "model should be loaded"); - } -} diff --git a/crates/voicevox_core/src/synthesizer.rs b/crates/voicevox_core/src/synthesizer.rs index 98c3a5f82..594cfd856 100644 --- a/crates/voicevox_core/src/synthesizer.rs +++ b/crates/voicevox_core/src/synthesizer.rs @@ -1,6 +1,12 @@ use std::sync::Arc; -use crate::engine::{create_kana, parse_kana, AccentPhraseModel, OpenJtalk, SynthesisEngine}; +use crate::{ + engine::{ + create_kana, parse_kana, AccentPhraseModel, OpenJtalk, SynthesisEngine, + 
DEFAULT_SAMPLING_RATE, + }, + infer::runtimes::Onnxruntime, +}; use super::*; @@ -67,9 +73,11 @@ pub struct InitializeOptions { pub cpu_num_threads: u16, } +pub(crate) type InferenceRuntimeImpl = Onnxruntime; + /// 音声シンセサイザ。 pub struct Synthesizer { - synthesis_engine: SynthesisEngine, + synthesis_engine: SynthesisEngine, use_gpu: bool, } @@ -555,7 +563,7 @@ impl AudioQueryModel { 1., 0.1, 0.1, - SynthesisEngine::DEFAULT_SAMPLING_RATE, + DEFAULT_SAMPLING_RATE, false, Some(kana), ) diff --git a/crates/voicevox_core/src/voice_model.rs b/crates/voicevox_core/src/voice_model.rs index 45e7dad17..829bbf43d 100644 --- a/crates/voicevox_core/src/voice_model.rs +++ b/crates/voicevox_core/src/voice_model.rs @@ -1,8 +1,10 @@ use async_zip::{read::fs::ZipFileReader, ZipEntry}; +use enum_map::EnumMap; use futures::future::join3; use serde::{de::DeserializeOwned, Deserialize}; use super::*; +use crate::infer::domain::InferenceOperationImpl; use std::{ collections::{BTreeMap, HashMap}, io, @@ -35,15 +37,10 @@ pub struct VoiceModel { path: PathBuf, } -#[derive(Getters)] -pub(crate) struct InferenceModels { - decode_model: Vec, - predict_duration_model: Vec, - predict_intonation_model: Vec, -} - impl VoiceModel { - pub(crate) async fn read_inference_models(&self) -> LoadModelResult { + pub(crate) async fn read_inference_models( + &self, + ) -> LoadModelResult>> { let reader = VvmEntryReader::open(&self.path).await?; let (decode_model_result, predict_duration_model_result, predict_intonation_model_result) = join3( @@ -53,11 +50,11 @@ impl VoiceModel { ) .await; - Ok(InferenceModels { - predict_duration_model: predict_duration_model_result?, - predict_intonation_model: predict_intonation_model_result?, - decode_model: decode_model_result?, - }) + Ok(EnumMap::from_array([ + predict_duration_model_result?, + predict_intonation_model_result?, + decode_model_result?, + ])) } /// VVMファイルから`VoiceModel`をコンストラクトする。 pub async fn from_path(path: impl AsRef) -> Result { diff --git 
a/crates/voicevox_core_c_api/Cargo.toml b/crates/voicevox_core_c_api/Cargo.toml index f187f0001..fad0e1b7b 100644 --- a/crates/voicevox_core_c_api/Cargo.toml +++ b/crates/voicevox_core_c_api/Cargo.toml @@ -52,7 +52,7 @@ easy-ext.workspace = true inventory = "0.3.4" libloading = "0.7.3" libtest-mimic = "0.6.0" -ndarray = "0.15.6" +ndarray.workspace = true ndarray-stats = "0.5.1" regex.workspace = true serde.workspace = true diff --git a/crates/voicevox_core_macros/Cargo.toml b/crates/voicevox_core_macros/Cargo.toml new file mode 100644 index 000000000..957fa3eb8 --- /dev/null +++ b/crates/voicevox_core_macros/Cargo.toml @@ -0,0 +1,15 @@ +[package] +name = "voicevox_core_macros" +version.workspace = true +edition.workspace = true +publish.workspace = true + +[lib] +name = "macros" +proc-macro = true + +[dependencies] +indexmap.workspace = true +proc-macro2 = "1.0.69" +quote = "1.0.33" +syn = { version = "2.0.38", features = ["extra-traits"] } diff --git a/crates/voicevox_core_macros/src/inference_domain.rs b/crates/voicevox_core_macros/src/inference_domain.rs new file mode 100644 index 000000000..4a447d37d --- /dev/null +++ b/crates/voicevox_core_macros/src/inference_domain.rs @@ -0,0 +1,379 @@ +use indexmap::IndexMap; +use quote::quote; +use syn::{ + parse::{Parse, ParseStream}, + spanned::Spanned as _, + Attribute, Data, DataEnum, DataStruct, DataUnion, DeriveInput, Field, Fields, Generics, + ItemType, Type, Variant, +}; + +pub(crate) fn derive_inference_operation( + input: &DeriveInput, +) -> syn::Result { + let DeriveInput { + attrs, + vis, + ident: operation_ty_name, + generics, + data, + .. + } = input; + + deny_generics(generics)?; + + let AssocTypeDomain(domain_ty) = attrs + .iter() + .find(|a| a.path().is_ident("inference_operation")) + .ok_or_else(|| { + syn::Error::new( + proc_macro2::Span::call_site(), + "missing `#[inference_operation(…)]`", + ) + })? + .parse_args()?; + + let variants = unit_enum_variants(data)? 
+ .into_iter() + .map(|(attrs, variant_name)| { + let AssocTypes { input, output } = attrs + .iter() + .find(|a| a.path().is_ident("inference_operation")) + .ok_or_else(|| { + syn::Error::new( + proc_macro2::Span::call_site(), + "missing `#[inference_operation(…)]`", + ) + })? + .parse_args()?; + + Ok((variant_name, (input, output))) + }) + .collect::>>()?; + + let variant_names = &variants.keys().collect::>(); + + let signatures = variants + .iter() + .map(|(variant_name, (input_ty, output_ty))| { + quote! { + #vis enum #variant_name {} + + impl crate::infer::InferenceSignature for #variant_name { + type Domain = #domain_ty; + type Input = #input_ty; + type Output = #output_ty; + + const OPERATION: ::Operation = + #operation_ty_name :: #variant_name; + } + } + }); + + return Ok(quote! { + impl crate::infer::InferenceOperation for #operation_ty_name { + const PARAM_INFOS: ::enum_map::EnumMap< + Self, + ( + &'static [crate::infer::ParamInfo], + &'static [crate::infer::ParamInfo], + ), + > = ::enum_map::EnumMap::from_array([ + #(( + <#variant_names as crate::infer::InferenceSignature>::Input::PARAM_INFOS, + <#variant_names as crate::infer::InferenceSignature>::Output::PARAM_INFOS + )),* + ]); + } + + #(#signatures)* + }); + + struct AssocTypeDomain(Type); + + impl Parse for AssocTypeDomain { + fn parse(input: ParseStream<'_>) -> syn::Result { + let ItemType { ident, ty, .. } = input.parse()?; + + if ident != "Domain" { + return Err(syn::Error::new(ident.span(), "expected `Domain`")); + } + Ok(Self(*ty)) + } + } + + struct AssocTypes { + input: Type, + output: Type, + } + + impl Parse for AssocTypes { + fn parse(stream: ParseStream<'_>) -> syn::Result { + let mut input = None; + let mut output = None; + + while !stream.is_empty() { + let ItemType { + ident, + generics, + ty, + .. 
+ } = stream.parse()?; + + deny_generics(&generics)?; + + *match &*ident.to_string() { + "Input" => &mut input, + "Output" => &mut output, + _ => { + return Err(syn::Error::new( + ident.span(), + "expected `Input` or `Output`", + )) + } + } = Some(*ty); + } + + let input = + input.ok_or_else(|| syn::Error::new(stream.span(), "missing `type Input = …;`"))?; + + let output = output + .ok_or_else(|| syn::Error::new(stream.span(), "missing `type Output = …;`"))?; + + Ok(Self { input, output }) + } + } + + fn deny_generics(generics: &Generics) -> syn::Result<()> { + if !generics.params.is_empty() { + return Err(syn::Error::new(generics.params.span(), "must be empty")); + } + if let Some(where_clause) = &generics.where_clause { + return Err(syn::Error::new(where_clause.span(), "must be empty")); + } + Ok(()) + } +} + +pub(crate) fn derive_inference_input_signature( + input: &DeriveInput, +) -> syn::Result { + let DeriveInput { + attrs, + ident, + generics, + data, + .. + } = input; + + let AssocTypeSignature(signature) = attrs + .iter() + .find(|a| a.path().is_ident("inference_input_signature")) + .ok_or_else(|| { + syn::Error::new( + proc_macro2::Span::call_site(), + "missing `#[inference_input_signature(…)]`", + ) + })? + .parse_args()?; + + let (impl_generics, ty_generics, where_clause) = generics.split_for_impl(); + + let fields = struct_fields(data)?; + + let param_infos = fields + .iter() + .map(|(name, ty)| { + let name = name.to_string(); + quote! { + crate::infer::ParamInfo { + name: ::std::borrow::Cow::Borrowed(#name), + dt: <<#ty as __ArrayExt>::Scalar as crate::infer::InputScalar>::KIND, + ndim: <<#ty as __ArrayExt>::Dimension as ::ndarray::Dimension>::NDIM, + }, + } + }) + .collect::(); + + let field_names = fields.iter().map(|(name, _)| name); + + return Ok(quote! 
{ + impl #impl_generics crate::infer::InferenceInputSignature for #ident #ty_generics + #where_clause + { + type Signature = #signature; + + const PARAM_INFOS: &'static [crate::infer::ParamInfo< + crate::infer::InputScalarKind + >] = { + trait __ArrayExt { + type Scalar: crate::infer::InputScalar; + type Dimension: ::ndarray::Dimension + 'static; + } + + impl __ArrayExt + for ::ndarray::Array + { + type Scalar = A; + type Dimension = D; + } + + &[#param_infos] + }; + + fn make_run_context( + self, + sess: &mut R::Session, + ) -> R::RunContext<'_> { + let mut ctx = as ::std::convert::From<_>>::from(sess); + #( + R::push_input(self.#field_names, &mut ctx); + )* + ctx + } + } + }); + + struct AssocTypeSignature(Type); + + impl Parse for AssocTypeSignature { + fn parse(input: ParseStream<'_>) -> syn::Result { + let ItemType { ident, ty, .. } = input.parse()?; + + if ident != "Signature" { + return Err(syn::Error::new(ident.span(), "expected `Signature`")); + } + Ok(Self(*ty)) + } + } +} + +pub(crate) fn derive_inference_output_signature( + input: &DeriveInput, +) -> syn::Result { + let DeriveInput { + ident, + generics, + data, + .. + } = input; + + let (impl_generics, ty_generics, where_clause) = generics.split_for_impl(); + + let fields = struct_fields(data)?; + let num_fields = fields.len(); + + let param_infos = fields + .iter() + .map(|(name, ty)| { + let name = name.to_string(); + quote! { + crate::infer::ParamInfo { + name: ::std::borrow::Cow::Borrowed(#name), + dt: <<#ty as __ArrayExt>::Scalar as crate::infer::OutputScalar>::KIND, + ndim: <<#ty as __ArrayExt>::Dimension as ::ndarray::Dimension>::NDIM, + }, + } + }) + .collect::(); + + let field_names = fields.iter().map(|(name, _)| name); + + Ok(quote! 
{ + impl #impl_generics crate::infer::InferenceOutputSignature for #ident #ty_generics + #where_clause + { + const PARAM_INFOS: &'static [crate::infer::ParamInfo< + crate::infer::OutputScalarKind + >] = { + trait __ArrayExt { + type Scalar: crate::infer::OutputScalar; + type Dimension: ::ndarray::Dimension + 'static; + } + + impl __ArrayExt + for ::ndarray::Array + { + type Scalar = A; + type Dimension = D; + } + + &[#param_infos] + }; + } + + impl #impl_generics ::std::convert::TryFrom<::std::vec::Vec> + for #ident #ty_generics + #where_clause + { + type Error = ::anyhow::Error; + + fn try_from( + tensors: ::std::vec::Vec, + ) -> ::std::result::Result { + ::anyhow::ensure!( + tensors.len() == #num_fields, + "expected {} tensor(s), got {}", + #num_fields, + tensors.len(), + ); + + let tensors = &mut ::std::iter::IntoIterator::into_iter(tensors); + ::std::result::Result::Ok(Self { + #( + #field_names: ::std::convert::TryInto::try_into( + ::std::iter::Iterator::next(tensors) + .expect("the length should have been checked"), + )?, + )* + }) + } + } + }) +} + +fn struct_fields(data: &Data) -> syn::Result> { + let fields = match data { + Data::Struct(DataStruct { + fields: Fields::Named(fields), + .. + }) => fields, + Data::Struct(DataStruct { fields, .. }) => { + return Err(syn::Error::new(fields.span(), "expect named fields")); + } + Data::Enum(DataEnum { enum_token, .. }) => { + return Err(syn::Error::new(enum_token.span(), "expected a struct")); + } + Data::Union(DataUnion { union_token, .. }) => { + return Err(syn::Error::new(union_token.span(), "expected a struct")); + } + }; + + Ok(fields + .named + .iter() + .map(|Field { ident, ty, .. }| (ident.as_ref().expect("should be named"), ty)) + .collect()) +} + +fn unit_enum_variants(data: &Data) -> syn::Result> { + let variants = match data { + Data::Struct(DataStruct { struct_token, .. }) => { + return Err(syn::Error::new(struct_token.span(), "expected an enum")); + } + Data::Enum(DataEnum { variants, .. 
}) => variants, + Data::Union(DataUnion { union_token, .. }) => { + return Err(syn::Error::new(union_token.span(), "expected an enum")); + } + }; + + for Variant { fields, .. } in variants { + if *fields != Fields::Unit { + return Err(syn::Error::new(fields.span(), "must be unit")); + } + } + + Ok(variants + .iter() + .map(|Variant { attrs, ident, .. }| (&**attrs, ident)) + .collect()) +} diff --git a/crates/voicevox_core_macros/src/lib.rs b/crates/voicevox_core_macros/src/lib.rs new file mode 100644 index 000000000..5f2f26809 --- /dev/null +++ b/crates/voicevox_core_macros/src/lib.rs @@ -0,0 +1,104 @@ +#![warn(rust_2018_idioms)] + +mod inference_domain; + +use syn::parse_macro_input; + +/// Rust APIクレート内で、`crate::infer::InferenceDomain`の導出などを行う。 +/// +/// 次のことを行う。 +/// +/// - `InferenceDomain`の導出 +/// - 各バリアントに対する`InferenceInputSignature`の実装を、型ごと生成 +/// +/// # Example +/// +/// ``` +/// use enum_map::Enum; +/// use macros::InferenceOperation; +/// +/// pub(crate) enum InferenceDomainImpl {} +/// +/// impl InferenceDomain for InferenceDomainImpl { +/// type Operation = InferenceOperationImpl; +/// } +/// +/// #[derive(Clone, Copy, Enum, InferenceOperation)] +/// #[inference_operation( +/// type Domain = InferenceDomainImpl; +/// )] +/// pub(crate) enum InferenceOperationImpl { +/// #[inference_operation( +/// type Input = PredictDurationInput; +/// type Output = PredictDurationOutput; +/// )] +/// PredictDuration, +/// +/// #[inference_operation( +/// type Input = PredictIntonationInput; +/// type Output = PredictIntonationOutput; +/// )] +/// PredictIntonation, +/// +/// #[inference_operation( +/// type Input = DecodeInput; +/// type Output = DecodeOutput; +/// )] +/// Decode, +/// } +/// ``` +#[cfg(not(doctest))] +#[proc_macro_derive(InferenceOperation, attributes(inference_operation))] +pub fn derive_inference_operation(input: proc_macro::TokenStream) -> proc_macro::TokenStream { + let input = &parse_macro_input!(input); + 
from_syn(inference_domain::derive_inference_operation(input)) +} + +/// Rust APIクレート内で、`crate::infer::InferenceInputSignature`を導出する。 +/// +/// # Example +/// +/// ``` +/// use macros::InferenceInputSignature; +/// +/// #[derive(InferenceInputSignature)] +/// #[inference_input_signature( +/// type Signature = PredictDuration; +/// )] +/// pub(crate) struct PredictDurationInput { +/// pub(crate) phoneme_list: Array1, +/// pub(crate) speaker_id: Array1, +/// } +/// ``` +#[cfg(not(doctest))] +#[proc_macro_derive(InferenceInputSignature, attributes(inference_input_signature))] +pub fn derive_inference_input_signature(input: proc_macro::TokenStream) -> proc_macro::TokenStream { + let input = &parse_macro_input!(input); + from_syn(inference_domain::derive_inference_input_signature(input)) +} + +/// Rust APIクレート内で`crate::infer::InferenceOutputSignature`を、`TryFrom`ごと導出 +/// する。 +/// +/// # Example +/// +/// ``` +/// use macros::InferenceOutputSignature; +/// +/// #[derive(InferenceOutputSignature)] +/// pub(crate) struct PredictDurationOutput { +/// pub(crate) phoneme_length: Array1, +/// } +/// ``` +#[cfg(not(doctest))] +#[proc_macro_derive(InferenceOutputSignature)] +pub fn derive_inference_output_signature( + input: proc_macro::TokenStream, +) -> proc_macro::TokenStream { + let input = &parse_macro_input!(input); + from_syn(inference_domain::derive_inference_output_signature(input)) +} + +fn from_syn(result: syn::Result) -> proc_macro::TokenStream { + result.unwrap_or_else(|e| e.to_compile_error()).into() +}