diff --git a/src/main.rs b/src/main.rs index 4d4b028..ffa8048 100644 --- a/src/main.rs +++ b/src/main.rs @@ -13,7 +13,7 @@ use crate::icon_manager::ICON_MANAGER; use anyhow::Result; use axum::{ response::{IntoResponse, Redirect}, - routing::{get, post}, + routing::{delete, get, post, put}, Router, }; use clap::Parser; @@ -55,6 +55,18 @@ async fn main() -> Result<()> { "/import_user_dict", post(routes::user_dict::import_user_dict), ) + .route( + "/user_dict_word", + post(routes::user_dict::post_user_dict_word), + ) + .route( + "/user_dict_word/:word_uuid", + delete(routes::user_dict::delete_user_dict_word), + ) + .route( + "/user_dict_word/:word_uuid", + put(routes::user_dict::put_user_dict_word), + ) .route("/audio_query", post(routes::audio_query::post_audio_query)) .route( "/accent_phrases", diff --git a/src/routes/user_dict.rs b/src/routes/user_dict.rs index 7f4d53f..e5d5f32 100644 --- a/src/routes/user_dict.rs +++ b/src/routes/user_dict.rs @@ -1,6 +1,7 @@ use crate::routes::audio_query::OPEN_JTALK; use crate::voicevox::user_dict::{UserDict, UserDictWord, UserDictWordType}; +use axum::extract::{Path, Query}; use axum::response::Json; use once_cell::sync::Lazy; use serde::{Deserialize, Serialize}; @@ -36,9 +37,18 @@ pub struct VvUserDictWord { mora_count: usize, surface: String, pronunciation: String, + #[serde(skip_deserializing)] part_of_speech_detail_1: String, } +#[derive(Debug, Serialize, Deserialize)] +pub struct VvUserDictWordParam { + priority: u32, + accent_type: usize, + surface: String, + pronunciation: String, +} + impl From for UserDictWord { fn from(word: VvUserDictWord) -> UserDictWord { UserDictWord::new( @@ -51,6 +61,7 @@ impl From for UserDictWord { "動詞" => UserDictWordType::Verb, "形容詞" => UserDictWordType::Adjective, "語尾" => UserDictWordType::Suffix, + "" => UserDictWordType::ProperNoun, _ => { warn!("Unknown word type: {}", &word.part_of_speech_detail_1); UserDictWordType::CommonNoun @@ -82,6 +93,19 @@ impl From for VvUserDictWord { } } +impl From for UserDictWord { + fn from(word: VvUserDictWordParam) -> UserDictWord { + UserDictWord::new( + &word.surface[..], + word.pronunciation, + word.accent_type, + UserDictWordType::CommonNoun, + word.priority, + ) + .unwrap() + } +} + pub async fn get_user_dict() -> Json> { let user_dict = USER_DICT.lock().await; @@ -110,9 +134,84 @@ pub async fn import_user_dict(Json(payload): Json) -> Result { + let mut user_dict = USER_DICT.lock().await; + + let word: UserDictWord = param.into(); + + let word_uuid = user_dict + .add_word(word) + .map_err(|e| Error::DictionaryOperationFailed(e.into()))?; + + user_dict + .save(&USER_DICT_PATH) + .map_err(|e| Error::DictionaryOperationFailed(e.into()))?; + + OPEN_JTALK + .lock() + .await + .use_user_dict(&user_dict) + .map_err(|e| Error::DictionaryOperationFailed(e.into()))?; + + Ok(word_uuid.hyphenated().to_string()) +} + +pub async fn delete_user_dict_word(Path(word_uuid): Path) -> Result<()> { + let mut user_dict = USER_DICT.lock().await; + + let word_uuid = uuid::Uuid::parse_str(&word_uuid) + .map_err(|e| Error::DictionaryOperationFailed(e.into()))?; + + user_dict + .remove_word(word_uuid) + .map_err(|e| Error::DictionaryOperationFailed(e.into()))?; + + user_dict + .save(&USER_DICT_PATH) + .map_err(|e| Error::DictionaryOperationFailed(e.into()))?; + + OPEN_JTALK + .lock() + .await + .use_user_dict(&user_dict) + .map_err(|e| Error::DictionaryOperationFailed(e.into()))?; + + Ok(()) +} + +pub async fn put_user_dict_word( + Path(word_uuid): Path, + Query(payload): Query, +) -> Result<()> { + let mut user_dict = USER_DICT.lock().await; + + let word_uuid = uuid::Uuid::parse_str(&word_uuid) + .map_err(|e| Error::DictionaryOperationFailed(e.into()))?; + + let word: UserDictWord = payload.into(); + + user_dict + .update_word(word_uuid, word) + .map_err(|e| Error::DictionaryOperationFailed(e.into()))?; + + user_dict + .save(&USER_DICT_PATH) + .map_err(|e| Error::DictionaryOperationFailed(e.into()))?; + + OPEN_JTALK + .lock() + .await + .use_user_dict(&user_dict) + .map_err(|e| Error::DictionaryOperationFailed(e.into()))?; + + Ok(()) +}