From 40f1e3bcbf07ee0d2f87ef7f25652d4e3458781d Mon Sep 17 00:00:00 2001 From: Ryo Yamashita Date: Thu, 23 Jan 2025 21:31:00 +0900 Subject: [PATCH] =?UTF-8?q?feat!:=20voicevox=5Fvvm=E3=81=8B=E3=82=89VVM?= =?UTF-8?q?=E3=82=92=E3=83=80=E3=82=A6=E3=83=B3=E3=83=AD=E3=83=BC=E3=83=89?= =?UTF-8?q?=E3=81=99=E3=82=8B=20(#928)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ドキュメントには手を付けないままにしてある。 BREAKING-CHANGE: `models`のダウンロード元をvoicevox_vvmに。またディレクトリ構造はvoicevox_vvmに従う。 BREAKING-CHANGE: `models`のダウンロード先のディレクトリ名を"model"から"models"に。 Refs: #825 --- Cargo.lock | 26 +++- Cargo.toml | 3 + crates/downloader/Cargo.toml | 4 + crates/downloader/src/main.rs | 264 +++++++++++++++++++++++++++++++--- 4 files changed, 275 insertions(+), 22 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index f3d2e0cc3..46c857f2e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1,6 +1,6 @@ # This file is automatically @generated by Cargo. # It is not intended for manual editing. -version = 3 +version = 4 [[package]] name = "addr2line" @@ -1110,10 +1110,12 @@ name = "downloader" version = "0.0.0" dependencies = [ "anyhow", + "base64 0.22.1", "binstall-tar", "bytes", "clap", "comrak", + "easy-ext", "flate2", "fs-err", "futures-core", @@ -1124,8 +1126,10 @@ dependencies = [ "parse-display", "rayon", "reqwest", + "rprompt", "rstest", "scraper", + "semver", "strum", "tokio", "tracing", @@ -3247,6 +3251,16 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "rprompt" +version = "2.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0d24bc146fc6baf226b8bca973d5e7655bd2077a8d94d9809a060c185108e611" +dependencies = [ + "rtoolbox", + "windows-sys 0.48.0", +] + [[package]] name = "rstest" version = "0.15.0" @@ -3284,6 +3298,16 @@ dependencies = [ "syn 2.0.79", ] +[[package]] +name = "rtoolbox" +version = "0.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c247d24e63230cdb56463ae328478bd5eac8b8faa8c69461a77e8e323afac90e" +dependencies = [ + "libc", + "windows-sys 0.48.0", +] + [[package]] name = "rustc-demangle" version = "0.1.21" diff --git a/Cargo.toml b/Cargo.toml index 5ba302cca..9ad280b2f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,6 +11,7 @@ assert_cmd = "2.0.16" async-fs = "2.1.2" async-lock = "3.4.0" async_zip = "=0.0.16" +base64 = "0.22.1" bindgen = "0.69.4" binstall-tar = "0.4.42" blocking = "1.6.1" @@ -73,9 +74,11 @@ rayon = "1.10.0" ref-cast = "1.0.23" regex = "1.11.0" reqwest = { version = "0.11.27", default-features = false } +rprompt = "2.1.1" rstest = "0.15.0" rstest_reuse = "0.6.0" scraper = "0.19.1" +semver = "1.0.14" serde = "1.0.210" serde_json = "1.0.128" serde_with = "3.10.0" diff --git a/crates/downloader/Cargo.toml b/crates/downloader/Cargo.toml index 5935780e9..46a58e87b 100644 --- a/crates/downloader/Cargo.toml +++ b/crates/downloader/Cargo.toml @@ -10,10 +10,12 @@ path = "src/main.rs" [dependencies] anyhow.workspace = true +base64.workspace = true binstall-tar.workspace = true bytes.workspace = true clap = { workspace = true, features = ["derive"] } comrak.workspace = true +easy-ext.workspace = true flate2.workspace = true fs-err = { workspace = true, features = ["tokio"] } futures-core.workspace = true @@ -24,7 +26,9 @@ octocrab = { workspace = true, default-features = false, features = ["rustls-tls parse-display.workspace = true rayon.workspace = true reqwest = { workspace = true, default-features = false, features = ["rustls-tls", "stream"] } +rprompt.workspace = true scraper.workspace = 
true +semver.workspace = true strum = { workspace = true, features = ["derive"] } tokio = { workspace = true, features = ["macros", "rt", "rt-multi-thread", "sync"] } tracing.workspace = true diff --git a/crates/downloader/src/main.rs b/crates/downloader/src/main.rs index 7e19081ea..a5d5a86d7 100644 --- a/crates/downloader/src/main.rs +++ b/crates/downloader/src/main.rs @@ -3,28 +3,35 @@ use std::{ collections::{BTreeSet, HashSet}, env, future::Future, - io::{self, Cursor, Read}, + io::{self, Cursor, IsTerminal as _, Read}, path::{Path, PathBuf}, sync::{Arc, LazyLock}, time::Duration, }; -use anyhow::{anyhow, bail, Context as _}; +use anyhow::{anyhow, bail, ensure, Context as _}; +use base64::{prelude::BASE64_STANDARD, Engine as _}; use bytes::Bytes; use clap::{Parser as _, ValueEnum}; +use easy_ext::ext; use flate2::read::GzDecoder; use futures_core::Stream; -use futures_util::{stream::FuturesOrdered, StreamExt as _, TryStreamExt as _}; +use futures_util::{ + stream::{FuturesOrdered, FuturesUnordered}, + StreamExt as _, TryStreamExt as _, +}; use indicatif::{MultiProgress, ProgressBar, ProgressStyle}; use itertools::Itertools as _; use octocrab::{ models::{ - repos::{Asset, Release}, + repos::{Asset, CommitObject, Content, Release, Tag}, AssetId, }, + repos::RepoHandler, Octocrab, }; use rayon::iter::{IntoParallelIterator as _, ParallelIterator as _}; +use semver::VersionReq; use strum::{Display, IntoStaticStr}; use tokio::task::{JoinError, JoinSet}; use tracing::{info, warn}; @@ -41,6 +48,14 @@ const LIB_NAME: &str = "voicevox_core"; const DEFAULT_CORE_REPO: &str = "VOICEVOX/voicevox_core"; const DEFAULT_ONNXRUNTIME_BUILDER_REPO: &str = "VOICEVOX/onnxruntime-builder"; const DEFAULT_ADDITIONAL_LIBRARIES_REPO: &str = "VOICEVOX/voicevox_additional_libraries"; +const DEFAULT_MODELS_REPO: &str = "VOICEVOX/voicevox_vvm"; + +static ALLOWED_MODELS_VERSIONS: LazyLock = + LazyLock::new(|| "=0.0.1-preview.2".parse().unwrap()); +const MODELS_README_FILENAME: &str = "README.md"; +const MODELS_DIR_NAME: &str = "vvms"; +const MODELS_TERMS_NAME: &str = "VOICEVOX VVM TERMS OF USE"; +const MODELS_TERMS_FILE: &str = "terms.md"; static OPEN_JTALK_DIC_URL: LazyLock = LazyLock::new(|| { "https://jaist.dl.sourceforge.net/project/open-jtalk/Dictionary/open_jtalk_dic-1.11/open_jtalk_dic_utf_8-1.11.tar.gz" @@ -122,14 +137,17 @@ struct Args { default_value(DEFAULT_ADDITIONAL_LIBRARIES_REPO) )] additional_libraries_repo: RepoName, + + #[arg(long, value_name("REPOSITORY"), default_value(DEFAULT_MODELS_REPO))] + models_repo: RepoName, } #[derive(ValueEnum, Clone, Copy, PartialEq, Eq, Hash)] enum DownloadTarget { Core, - Models, Onnxruntime, AdditionalLibraries, + Models, Dict, } @@ -206,6 +224,7 @@ async fn main() -> anyhow::Result<()> { os, core_repo, onnxruntime_builder_repo, + models_repo, additional_libraries_repo, } = Args::parse(); let devices = devices.into_iter().collect::>(); @@ -226,16 +245,16 @@ async fn main() -> anyhow::Result<()> { DownloadTarget::value_variants().iter().copied().collect() }; - if !(targets.contains(&DownloadTarget::Core) || targets.contains(&DownloadTarget::Models)) { + if !targets.contains(&DownloadTarget::Core) { if version != "latest" { warn!( - "`--version={version}`が指定されていますが、`core`も`models`もダウンロード対象から\ + "`--version={version}`が指定されていますが、`core`はダウンロード対象から\ 除外されています", ); } if core_repo.to_string() != DEFAULT_CORE_REPO { warn!( - "`--core-repo={core_repo}`が指定されていますが、`core`も`models`もダウンロード対象\ + "`--core-repo={core_repo}`が指定されていますが、`core`はダウンロード対象\ から除外されています", ); } @@ -268,11 +287,6 
@@ async fn main() -> anyhow::Result<()> { }) .await?; - let model = find_gh_asset(octocrab, &core_repo, &version, |tag, _| { - Ok(format!("model-{tag}.zip")) - }) - .await?; - let onnxruntime = find_gh_asset( octocrab, &onnxruntime_builder_repo, @@ -284,6 +298,8 @@ async fn main() -> anyhow::Result<()> { ) .await?; + let models = find_models(octocrab, &models_repo).await?; + let additional_libraries = devices .iter() .filter(|&&device| device != Device::Cpu) @@ -325,6 +341,7 @@ async fn main() -> anyhow::Result<()> { .format(", "), ); } + info!("ダウンロードモデルバージョン: {}", models.tag); let progresses = MultiProgress::new(); @@ -338,14 +355,6 @@ async fn main() -> anyhow::Result<()> { &progresses, )?); } - if targets.contains(&DownloadTarget::Models) { - tasks.spawn(download_and_extract_from_gh( - model, - Stripping::FirstDir, - &output.join("model"), - &progresses, - )?); - } if targets.contains(&DownloadTarget::Onnxruntime) { tasks.spawn(download_and_extract_from_gh( onnxruntime, @@ -364,6 +373,13 @@ async fn main() -> anyhow::Result<()> { )?); } } + if targets.contains(&DownloadTarget::Models) { + tasks.spawn(download_models( + models, + &output.join("models"), + &progresses, + )?); + } if targets.contains(&DownloadTarget::Dict) { tasks.spawn(download_and_extract_from_url( &OPEN_JTALK_DIC_URL, @@ -524,6 +540,140 @@ fn find_onnxruntime( .with_context(|| "指定されたOS, アーキテクチャ, デバイスを含むものが見つかりませんでした") } +/// ダウンロードすべきモデル、利用規約を見つける。その際ユーザーに利用規約の同意を求める。 +async fn find_models(octocrab: &Octocrab, repo: &RepoName) -> anyhow::Result { + let repos = octocrab.repos(&repo.owner, &repo.repo); + + let (tag, sha) = repos + .list_tags() + .send() + .await? + .into_iter() + .map( + |Tag { + name, + commit: CommitObject { sha, .. }, + .. + }| { + let tag = name + .parse() + .with_context(|| format!("`{repo}` contains non-SemVer tags"))?; + Ok((tag, sha)) + }, + ) + .collect::>>()? + .into_iter() + .filter(|(version, _)| ALLOWED_MODELS_VERSIONS.matches(version)) + .sorted() + .last() + .with_context(|| format!("`{repo}`"))?; + let tag = tag.to_string(); + + let terms = repos.fetch_file_content(&sha, MODELS_TERMS_FILE).await?; + ensure_confirmation(&terms, MODELS_TERMS_NAME)?; + + let readme = repos + .fetch_file_content(&sha, MODELS_README_FILENAME) + .await?; + + let models = repos + .get_content() + .r#ref(&sha) + .path(MODELS_DIR_NAME) + .send() + .await? + .items + .into_iter() + .map( + |Content { + name, + size, + download_url, + r#type, + .. + }| { + ensure!(r#type == "file", "found directory"); + Ok(GhContent { + name, + download_url: download_url.expect("should present"), + size: size as _, + }) + }, + ) + .collect::>()?; + + return Ok(ModelsWithTerms { + tag, + readme, + terms, + models, + }); + + #[ext] + impl RepoHandler<'_> { + async fn fetch_file_content(&self, sha: &str, path: &str) -> anyhow::Result { + let Content { + encoding, content, .. + } = self + .get_content() + .r#ref(sha) + .path(path) + .send() + .await? 
+ .items + .into_iter() + .exactly_one() + .map_err(|_| anyhow!("could not find `{path}`"))?; + + ensure!( + encoding.as_deref() == Some("base64"), + r#"expected `encoding="base64"`"#, + ); + + let content = content.expect("should present").replace('\n', ""); + let content = BASE64_STANDARD.decode(content)?; + let content = String::from_utf8(content) + .with_context(|| format!("`{path}` is not valid UTF-8"))?; + Ok(content) + } + } +} + +fn ensure_confirmation(terms: &str, terms_name: &'static str) -> anyhow::Result<()> { + eprintln!( + "----------BEGIN {terms_name}----------\n\ + {terms}\n\ + ----------END {terms_name}----------", + terms = terms.trim_end(), + ); + if !ask_yn()? { + bail!("you must agree with the term of use"); + } + return Ok(()); + + const PROMPT: &str = + "上記の利用規約に同意しますか? (Do you agree with the above terms of use?) (y/N): "; + + fn ask_yn() -> anyhow::Result { + loop { + let input = rprompt::prompt_reply_from_bufread( + &mut io::stdin().lock(), + &mut io::stderr(), + PROMPT, + )?; + if ["y", "yes"].contains(&&*input.to_lowercase()) { + break Ok(true); + } + if ["n", "no", ""].contains(&&*input.to_lowercase()) { + break Ok(false); + } + if !io::stdin().is_terminal() { + bail!("the stdin is not a TTY but received invalid input: {input:?}"); + } + } + } +} + fn download_and_extract_from_gh( GhAsset { octocrab, @@ -592,6 +742,65 @@ fn download_and_extract_from_url( }) } +fn download_models( + ModelsWithTerms { + readme, + terms, + models, + .. + }: ModelsWithTerms, + output: &Path, + progresses: &MultiProgress, +) -> anyhow::Result>> { + let output = output.to_owned(); + let reqwest = reqwest::Client::builder().build()?; + + let models = models + .into_iter() + .map(|model| { + let pb = add_progress_bar(progresses, model.size as _, model.name.clone()); + (model, pb) + }) + .collect::>(); + + Ok(async move { + fs_err::tokio::create_dir_all(&output.join(MODELS_DIR_NAME)).await?; + fs_err::tokio::write(output.join(MODELS_README_FILENAME), readme).await?; + fs_err::tokio::write(output.join(MODELS_TERMS_FILE), terms).await?; + let reqwest = &reqwest; + let output = &output; + models + .into_iter() + .map( + |( + GhContent { + name, + download_url, + size, + }, + pb, + )| async move { + let res = reqwest.get(download_url).send().await?.error_for_status()?; + let bytes_stream = res.bytes_stream().map_err(Into::into); + let pb = with_style(pb, &PROGRESS_STYLE1).await?; + let model = download(bytes_stream, Some(size), pb.clone()).await?; + let pb = tokio::task::spawn_blocking(move || { + pb.set_style(PROGRESS_STYLE2.clone()); + pb.set_message("Writing..."); + pb + }) + .await?; + fs_err::tokio::write(output.join(MODELS_DIR_NAME).join(name), model).await?; + tokio::task::spawn_blocking(move || pb.finish_with_message("Done!")).await?; + Ok(()) + }, + ) + .collect::>() + .try_collect::<()>() + .await + }) +} + fn add_progress_bar( progresses: &MultiProgress, len: u64, @@ -771,6 +980,19 @@ struct GhAsset { size: usize, } +struct ModelsWithTerms { + tag: String, + readme: String, + terms: String, + models: Vec, +} + +struct GhContent { + name: String, + download_url: String, + size: u64, +} + #[derive(Clone, Copy)] enum ArchiveKind { Zip,
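
Note on the new layout (illustrative sketch, not part of the patch): with this change the `models` target is fetched from VOICEVOX/voicevox_vvm and written under `models/` (previously `model/`), containing the repository's `README.md`, the accepted `terms.md`, and a `vvms/` directory holding the individual `.vvm` files. Below is a minimal, std-only Rust sketch of how a consumer might sanity-check that layout after running the downloader; the helper name `verify_models_layout` and the example output path are hypothetical and not defined anywhere in this patch.

    use std::{fs, io, path::Path};

    /// Returns `Ok(true)` if `output` (the directory the downloader wrote into)
    /// contains the layout produced by the new `models` target:
    /// `models/README.md`, `models/terms.md`, and at least one `.vvm` file under
    /// `models/vvms/`. Hypothetical helper for illustration only.
    fn verify_models_layout(output: &Path) -> io::Result<bool> {
        let models = output.join("models");
        let has_readme = models.join("README.md").is_file();
        let has_terms = models.join("terms.md").is_file();
        let vvms_dir = models.join("vvms");
        let has_vvms = vvms_dir.is_dir()
            && fs::read_dir(&vvms_dir)?
                .filter_map(Result::ok)
                .any(|entry| {
                    entry.path().extension().and_then(|ext| ext.to_str()) == Some("vvm")
                });
        Ok(has_readme && has_terms && has_vvms)
    }

    fn main() -> io::Result<()> {
        // Point this at whatever directory the downloader was told to write to.
        let ok = verify_models_layout(Path::new("./voicevox_core"))?;
        println!("models layout ok: {ok}");
        Ok(())
    }

Since `find_models` prompts for agreement via `rprompt` reading from stdin and bails when stdin is not a TTY and the reply is not recognized, non-interactive callers would presumably need to pipe an explicit "y" to accept the terms; the exact invocation is outside the scope of this patch.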