From 280cc79b2806dcc344774c2115690c9dc15aa804 Mon Sep 17 00:00:00 2001 From: Ryo Yamashita Date: Fri, 24 Nov 2023 02:01:05 +0900 Subject: [PATCH] =?UTF-8?q?IO=E3=81=8C=E7=99=BA=E7=94=9F=E3=81=99=E3=82=8B?= =?UTF-8?q?=E3=83=A1=E3=82=BD=E3=83=83=E3=83=89=E3=82=92=E3=81=99=E3=81=B9?= =?UTF-8?q?=E3=81=A6async=E5=8C=96=E3=81=99=E3=82=8B=20(#667)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * [skip ci] 空コミット * [skip ci] IOが発生するメソッドをすべてasync化する * テストを直す * fixtureのスコープを指定 --- .../src/__internal/doctest_fixtures.rs | 2 +- crates/voicevox_core/src/engine/open_jtalk.rs | 101 +++++++++++------- crates/voicevox_core/src/synthesizer.rs | 17 +-- crates/voicevox_core/src/user_dict/dict.rs | 73 ++++++++----- crates/voicevox_core_c_api/src/c_impls.rs | 4 +- crates/voicevox_core_c_api/src/lib.rs | 47 +++----- .../voicevox_core_java_api/src/open_jtalk.rs | 16 +-- .../voicevox_core_java_api/src/user_dict.rs | 64 ++++------- .../test/test_pseudo_raii_for_synthesizer.py | 6 +- .../python/test/test_user_dict_load.py | 4 +- .../python/test/test_user_dict_manipulate.py | 4 +- .../python/voicevox_core/_rust.pyi | 23 ++-- crates/voicevox_core_python_api/src/lib.rs | 66 +++++++----- example/python/run.py | 2 +- 14 files changed, 216 insertions(+), 213 deletions(-) diff --git a/crates/voicevox_core/src/__internal/doctest_fixtures.rs b/crates/voicevox_core/src/__internal/doctest_fixtures.rs index 9df517720..dd088b218 100644 --- a/crates/voicevox_core/src/__internal/doctest_fixtures.rs +++ b/crates/voicevox_core/src/__internal/doctest_fixtures.rs @@ -6,7 +6,7 @@ pub async fn synthesizer_with_sample_voice_model( open_jtalk_dic_dir: impl AsRef, ) -> anyhow::Result { let syntesizer = Synthesizer::new( - Arc::new(OpenJtalk::new(open_jtalk_dic_dir).unwrap()), + Arc::new(OpenJtalk::new(open_jtalk_dic_dir).await.unwrap()), &InitializeOptions { acceleration_mode: AccelerationMode::Cpu, ..Default::default() diff --git a/crates/voicevox_core/src/engine/open_jtalk.rs b/crates/voicevox_core/src/engine/open_jtalk.rs index f74d4130d..fa2e91304 100644 --- a/crates/voicevox_core/src/engine/open_jtalk.rs +++ b/crates/voicevox_core/src/engine/open_jtalk.rs @@ -1,4 +1,5 @@ use std::io::Write; +use std::sync::Arc; use std::{ path::{Path, PathBuf}, sync::Mutex, @@ -21,7 +22,7 @@ pub(crate) struct OpenjtalkFunctionError { /// テキスト解析器としてのOpen JTalk。 pub struct OpenJtalk { - resources: Mutex, + resources: Arc>, dict_dir: Option, } @@ -42,58 +43,76 @@ impl OpenJtalk { mecab: ManagedResource::initialize(), njd: ManagedResource::initialize(), jpcommon: ManagedResource::initialize(), - }), + }) + .into(), dict_dir: None, } } - pub fn new(open_jtalk_dict_dir: impl AsRef) -> crate::result::Result { - let mut s = Self::new_without_dic(); - s.load(open_jtalk_dict_dir).map_err(|()| { - // FIXME: 「システム辞書を読もうとしたけど読めなかった」というエラーをちゃんと用意する - ErrorRepr::NotLoadedOpenjtalkDict - })?; - Ok(s) + + pub async fn new(open_jtalk_dict_dir: impl AsRef) -> crate::result::Result { + let open_jtalk_dict_dir = open_jtalk_dict_dir.as_ref().to_owned(); + + tokio::task::spawn_blocking(move || { + let mut s = Self::new_without_dic(); + s.load(open_jtalk_dict_dir).map_err(|()| { + // FIXME: 「システム辞書を読もうとしたけど読めなかった」というエラーをちゃんと用意する + ErrorRepr::NotLoadedOpenjtalkDict + })?; + Ok(s) + }) + .await + .unwrap() } // 先に`load`を呼ぶ必要がある。 /// ユーザー辞書を設定する。 /// /// この関数を呼び出した後にユーザー辞書を変更した場合は、再度この関数を呼ぶ必要がある。 - pub fn use_user_dict(&self, user_dict: &UserDict) -> crate::result::Result<()> { + pub async fn use_user_dict(&self, user_dict: &UserDict) -> crate::result::Result<()> { let dict_dir = self .dict_dir .as_ref() .and_then(|dict_dir| dict_dir.to_str()) - .ok_or(ErrorRepr::NotLoadedOpenjtalkDict)?; + .ok_or(ErrorRepr::NotLoadedOpenjtalkDict)? + .to_owned(); + + let resources = self.resources.clone(); + + let words = user_dict.to_mecab_format(); - // ユーザー辞書用のcsvを作成 - let mut temp_csv = NamedTempFile::new().map_err(|e| ErrorRepr::UseUserDict(e.into()))?; - temp_csv - .write_all(user_dict.to_mecab_format().as_bytes()) - .map_err(|e| ErrorRepr::UseUserDict(e.into()))?; - let temp_csv_path = temp_csv.into_temp_path(); - let temp_dict = NamedTempFile::new().map_err(|e| ErrorRepr::UseUserDict(e.into()))?; - let temp_dict_path = temp_dict.into_temp_path(); + let result = tokio::task::spawn_blocking(move || -> crate::Result<_> { + // ユーザー辞書用のcsvを作成 + let mut temp_csv = + NamedTempFile::new().map_err(|e| ErrorRepr::UseUserDict(e.into()))?; + temp_csv + .write_all(words.as_ref()) + .map_err(|e| ErrorRepr::UseUserDict(e.into()))?; + let temp_csv_path = temp_csv.into_temp_path(); + let temp_dict = NamedTempFile::new().map_err(|e| ErrorRepr::UseUserDict(e.into()))?; + let temp_dict_path = temp_dict.into_temp_path(); - // Mecabでユーザー辞書をコンパイル - // TODO: エラー(SEGV)が出るパターンを把握し、それをRust側で防ぐ。 - mecab_dict_index(&[ - "mecab-dict-index", - "-d", - dict_dir, - "-u", - temp_dict_path.to_str().unwrap(), - "-f", - "utf-8", - "-t", - "utf-8", - temp_csv_path.to_str().unwrap(), - "-q", - ]); + // Mecabでユーザー辞書をコンパイル + // TODO: エラー(SEGV)が出るパターンを把握し、それをRust側で防ぐ。 + mecab_dict_index(&[ + "mecab-dict-index", + "-d", + &dict_dir, + "-u", + temp_dict_path.to_str().unwrap(), + "-f", + "utf-8", + "-t", + "utf-8", + temp_csv_path.to_str().unwrap(), + "-q", + ]); - let Resources { mecab, .. } = &mut *self.resources.lock().unwrap(); + let Resources { mecab, .. } = &mut *resources.lock().unwrap(); - let result = mecab.load_with_userdic(Path::new(dict_dir), Some(Path::new(&temp_dict_path))); + Ok(mecab.load_with_userdic(dict_dir.as_ref(), Some(Path::new(&temp_dict_path)))) + }) + .await + .unwrap()?; if !result { return Err(ErrorRepr::UseUserDict(anyhow!("辞書のコンパイルに失敗しました")).into()); @@ -269,22 +288,24 @@ mod tests { #[rstest] #[case("", Err(OpenjtalkFunctionError { function: "Mecab_get_feature", source: None }))] #[case("こんにちは、ヒホです。", Ok(testdata_hello_hiho()))] - fn extract_fullcontext_works( + #[tokio::test] + async fn extract_fullcontext_works( #[case] text: &str, #[case] expected: std::result::Result, OpenjtalkFunctionError>, ) { - let open_jtalk = OpenJtalk::new(OPEN_JTALK_DIC_DIR).unwrap(); + let open_jtalk = OpenJtalk::new(OPEN_JTALK_DIC_DIR).await.unwrap(); let result = open_jtalk.extract_fullcontext(text); assert_debug_fmt_eq!(expected, result); } #[rstest] #[case("こんにちは、ヒホです。", Ok(testdata_hello_hiho()))] - fn extract_fullcontext_loop_works( + #[tokio::test] + async fn extract_fullcontext_loop_works( #[case] text: &str, #[case] expected: std::result::Result, OpenjtalkFunctionError>, ) { - let open_jtalk = OpenJtalk::new(OPEN_JTALK_DIC_DIR).unwrap(); + let open_jtalk = OpenJtalk::new(OPEN_JTALK_DIC_DIR).await.unwrap(); for _ in 0..10 { let result = open_jtalk.extract_fullcontext(text); assert_debug_fmt_eq!(expected, result); diff --git a/crates/voicevox_core/src/synthesizer.rs b/crates/voicevox_core/src/synthesizer.rs index 8986f33cd..bb4afb153 100644 --- a/crates/voicevox_core/src/synthesizer.rs +++ b/crates/voicevox_core/src/synthesizer.rs @@ -105,7 +105,8 @@ impl Synthesizer { /// #[cfg_attr(windows, doc = "```no_run")] // https://github.com/VOICEVOX/voicevox_core/issues/537 #[cfg_attr(not(windows), doc = "```")] - /// # fn main() -> anyhow::Result<()> { + /// # #[tokio::main] + /// # async fn main() -> anyhow::Result<()> { /// # use test_util::OPEN_JTALK_DIC_DIR; /// # /// # const ACCELERATION_MODE: AccelerationMode = AccelerationMode::Cpu; @@ -115,7 +116,7 @@ impl Synthesizer { /// use voicevox_core::{AccelerationMode, InitializeOptions, OpenJtalk, Synthesizer}; /// /// let mut syntesizer = Synthesizer::new( - /// Arc::new(OpenJtalk::new(OPEN_JTALK_DIC_DIR).unwrap()), + /// Arc::new(OpenJtalk::new(OPEN_JTALK_DIC_DIR).await.unwrap()), /// &InitializeOptions { /// acceleration_mode: ACCELERATION_MODE, /// ..Default::default() @@ -1428,7 +1429,7 @@ mod tests { #[case] expected_kana_text: &str, ) { let syntesizer = Synthesizer::new( - Arc::new(OpenJtalk::new(OPEN_JTALK_DIC_DIR).unwrap()), + Arc::new(OpenJtalk::new(OPEN_JTALK_DIC_DIR).await.unwrap()), &InitializeOptions { acceleration_mode: AccelerationMode::Cpu, ..Default::default() @@ -1496,7 +1497,7 @@ mod tests { #[case] expected_text_consonant_vowel_data: &TextConsonantVowelData, ) { let syntesizer = Synthesizer::new( - Arc::new(OpenJtalk::new(OPEN_JTALK_DIC_DIR).unwrap()), + Arc::new(OpenJtalk::new(OPEN_JTALK_DIC_DIR).await.unwrap()), &InitializeOptions { acceleration_mode: AccelerationMode::Cpu, ..Default::default() @@ -1561,7 +1562,7 @@ mod tests { #[tokio::test] async fn create_accent_phrases_works_for_japanese_commas_and_periods() { let syntesizer = Synthesizer::new( - Arc::new(OpenJtalk::new(OPEN_JTALK_DIC_DIR).unwrap()), + Arc::new(OpenJtalk::new(OPEN_JTALK_DIC_DIR).await.unwrap()), &InitializeOptions { acceleration_mode: AccelerationMode::Cpu, ..Default::default() @@ -1620,7 +1621,7 @@ mod tests { #[tokio::test] async fn mora_length_works() { let syntesizer = Synthesizer::new( - Arc::new(OpenJtalk::new(OPEN_JTALK_DIC_DIR).unwrap()), + Arc::new(OpenJtalk::new(OPEN_JTALK_DIC_DIR).await.unwrap()), &InitializeOptions { acceleration_mode: AccelerationMode::Cpu, ..Default::default() @@ -1656,7 +1657,7 @@ mod tests { #[tokio::test] async fn mora_pitch_works() { let syntesizer = Synthesizer::new( - Arc::new(OpenJtalk::new(OPEN_JTALK_DIC_DIR).unwrap()), + Arc::new(OpenJtalk::new(OPEN_JTALK_DIC_DIR).await.unwrap()), &InitializeOptions { acceleration_mode: AccelerationMode::Cpu, ..Default::default() @@ -1688,7 +1689,7 @@ mod tests { #[tokio::test] async fn mora_data_works() { let syntesizer = Synthesizer::new( - Arc::new(OpenJtalk::new(OPEN_JTALK_DIC_DIR).unwrap()), + Arc::new(OpenJtalk::new(OPEN_JTALK_DIC_DIR).await.unwrap()), &InitializeOptions { acceleration_mode: AccelerationMode::Cpu, ..Default::default() diff --git a/crates/voicevox_core/src/user_dict/dict.rs b/crates/voicevox_core/src/user_dict/dict.rs index 1a7820a78..79534d9e3 100644 --- a/crates/voicevox_core/src/user_dict/dict.rs +++ b/crates/voicevox_core/src/user_dict/dict.rs @@ -1,5 +1,3 @@ -use derive_getters::Getters; -use fs_err::File; use indexmap::IndexMap; use itertools::join; use uuid::Uuid; @@ -9,9 +7,9 @@ use crate::{error::ErrorRepr, Result}; /// ユーザー辞書。 /// 単語はJSONとの相互変換のために挿入された順序を保つ。 -#[derive(Clone, Debug, Default, Getters)] +#[derive(Debug, Default)] pub struct UserDict { - words: IndexMap, + words: std::sync::Mutex>, } impl UserDict { @@ -20,65 +18,84 @@ impl UserDict { Default::default() } + pub fn to_json(&self) -> String { + serde_json::to_string(&*self.words.lock().unwrap()).expect("should not fail") + } + + pub fn with_words(&self, f: impl FnOnce(&IndexMap) -> R) -> R { + f(&self.words.lock().unwrap()) + } + /// ユーザー辞書をファイルから読み込む。 /// /// # Errors /// /// ファイルが読めなかった、または内容が不正だった場合はエラーを返す。 - pub fn load(&mut self, store_path: &str) -> Result<()> { - let store_path = std::path::Path::new(store_path); - - let store_file = File::open(store_path).map_err(|e| ErrorRepr::LoadUserDict(e.into()))?; - - let words: IndexMap = - serde_json::from_reader(store_file).map_err(|e| ErrorRepr::LoadUserDict(e.into()))?; + pub async fn load(&self, store_path: &str) -> Result<()> { + let words = async { + let words = &fs_err::tokio::read(store_path).await?; + let words = serde_json::from_slice::>(words)?; + Ok(words) + } + .await + .map_err(ErrorRepr::LoadUserDict)?; - self.words.extend(words); + self.words.lock().unwrap().extend(words); Ok(()) } /// ユーザー辞書に単語を追加する。 - pub fn add_word(&mut self, word: UserDictWord) -> Result { + pub fn add_word(&self, word: UserDictWord) -> Result { let word_uuid = Uuid::new_v4(); - self.words.insert(word_uuid, word); + self.words.lock().unwrap().insert(word_uuid, word); Ok(word_uuid) } /// ユーザー辞書の単語を変更する。 - pub fn update_word(&mut self, word_uuid: Uuid, new_word: UserDictWord) -> Result<()> { - if !self.words.contains_key(&word_uuid) { + pub fn update_word(&self, word_uuid: Uuid, new_word: UserDictWord) -> Result<()> { + let mut words = self.words.lock().unwrap(); + if !words.contains_key(&word_uuid) { return Err(ErrorRepr::WordNotFound(word_uuid).into()); } - self.words.insert(word_uuid, new_word); + words.insert(word_uuid, new_word); Ok(()) } /// ユーザー辞書から単語を削除する。 - pub fn remove_word(&mut self, word_uuid: Uuid) -> Result { - let Some(word) = self.words.remove(&word_uuid) else { + pub fn remove_word(&self, word_uuid: Uuid) -> Result { + let Some(word) = self.words.lock().unwrap().remove(&word_uuid) else { return Err(ErrorRepr::WordNotFound(word_uuid).into()); }; Ok(word) } /// 他のユーザー辞書をインポートする。 - pub fn import(&mut self, other: &Self) -> Result<()> { - for (word_uuid, word) in &other.words { - self.words.insert(*word_uuid, word.clone()); + pub fn import(&self, other: &Self) -> Result<()> { + for (word_uuid, word) in &*other.words.lock().unwrap() { + self.words.lock().unwrap().insert(*word_uuid, word.clone()); } Ok(()) } /// ユーザー辞書を保存する。 - pub fn save(&self, store_path: &str) -> Result<()> { - let mut file = File::create(store_path).map_err(|e| ErrorRepr::SaveUserDict(e.into()))?; - serde_json::to_writer(&mut file, &self.words) - .map_err(|e| ErrorRepr::SaveUserDict(e.into()))?; - Ok(()) + pub async fn save(&self, store_path: &str) -> Result<()> { + fs_err::tokio::write( + store_path, + serde_json::to_vec(&self.words).expect("should not fail"), + ) + .await + .map_err(|e| ErrorRepr::SaveUserDict(e.into()).into()) } /// MeCabで使用する形式に変換する。 pub(crate) fn to_mecab_format(&self) -> String { - join(self.words.values().map(UserDictWord::to_mecab_format), "\n") + join( + self.words + .lock() + .unwrap() + .values() + .map(UserDictWord::to_mecab_format), + "\n", + ) } } diff --git a/crates/voicevox_core_c_api/src/c_impls.rs b/crates/voicevox_core_c_api/src/c_impls.rs index 4548444cb..74e783fdc 100644 --- a/crates/voicevox_core_c_api/src/c_impls.rs +++ b/crates/voicevox_core_c_api/src/c_impls.rs @@ -5,9 +5,9 @@ use voicevox_core::{InitializeOptions, OpenJtalk, Result, Synthesizer, VoiceMode use crate::{CApiResult, OpenJtalkRc, VoicevoxSynthesizer, VoicevoxVoiceModel}; impl OpenJtalkRc { - pub(crate) fn new(open_jtalk_dic_dir: impl AsRef) -> Result { + pub(crate) async fn new(open_jtalk_dic_dir: impl AsRef) -> Result { Ok(Self { - open_jtalk: Arc::new(OpenJtalk::new(open_jtalk_dic_dir)?), + open_jtalk: Arc::new(OpenJtalk::new(open_jtalk_dic_dir).await?), }) } } diff --git a/crates/voicevox_core_c_api/src/lib.rs b/crates/voicevox_core_c_api/src/lib.rs index 302089a95..ac72b6359 100644 --- a/crates/voicevox_core_c_api/src/lib.rs +++ b/crates/voicevox_core_c_api/src/lib.rs @@ -134,7 +134,9 @@ pub unsafe extern "C" fn voicevox_open_jtalk_rc_new( ) -> VoicevoxResultCode { into_result_code_with_error((|| { let open_jtalk_dic_dir = ensure_utf8(CStr::from_ptr(open_jtalk_dic_dir))?; - let open_jtalk = OpenJtalkRc::new(open_jtalk_dic_dir)?.into(); + let open_jtalk = RUNTIME + .block_on(OpenJtalkRc::new(open_jtalk_dic_dir))? + .into(); out_open_jtalk.as_ptr().write_unaligned(open_jtalk); Ok(()) })()) @@ -157,11 +159,7 @@ pub extern "C" fn voicevox_open_jtalk_rc_use_user_dict( user_dict: &VoicevoxUserDict, ) -> VoicevoxResultCode { into_result_code_with_error((|| { - let user_dict = user_dict.to_owned(); - { - let dict = user_dict.dict.as_ref().lock().expect("lock failed"); - open_jtalk.open_jtalk.use_user_dict(&dict)?; - } + RUNTIME.block_on(open_jtalk.open_jtalk.use_user_dict(&user_dict.dict))?; Ok(()) })()) } @@ -1036,7 +1034,7 @@ pub extern "C" fn voicevox_error_result_to_message( /// ユーザー辞書。 #[derive(Default)] pub struct VoicevoxUserDict { - dict: Arc>, + dict: Arc, } /// ユーザー辞書の単語。 @@ -1116,8 +1114,7 @@ pub unsafe extern "C" fn voicevox_user_dict_load( ) -> VoicevoxResultCode { into_result_code_with_error((|| { let dict_path = ensure_utf8(unsafe { CStr::from_ptr(dict_path) })?; - let mut dict = user_dict.dict.lock().unwrap(); - dict.load(dict_path)?; + RUNTIME.block_on(user_dict.dict.load(dict_path))?; Ok(()) })()) @@ -1146,10 +1143,7 @@ pub unsafe extern "C" fn voicevox_user_dict_add_word( ) -> VoicevoxResultCode { into_result_code_with_error((|| { let word = word.read_unaligned().try_into_word()?; - let uuid = { - let mut dict = user_dict.dict.lock().expect("lock failed"); - dict.add_word(word)? - }; + let uuid = user_dict.dict.add_word(word)?; output_word_uuid.as_ptr().copy_from(uuid.as_bytes(), 16); Ok(()) @@ -1177,10 +1171,7 @@ pub unsafe extern "C" fn voicevox_user_dict_update_word( into_result_code_with_error((|| { let word_uuid = Uuid::from_slice(word_uuid).map_err(CApiError::InvalidUuid)?; let word = word.read_unaligned().try_into_word()?; - { - let mut dict = user_dict.dict.lock().expect("lock failed"); - dict.update_word(word_uuid, word)?; - }; + user_dict.dict.update_word(word_uuid, word)?; Ok(()) })()) @@ -1203,11 +1194,7 @@ pub extern "C" fn voicevox_user_dict_remove_word( ) -> VoicevoxResultCode { into_result_code_with_error((|| { let word_uuid = Uuid::from_slice(word_uuid).map_err(CApiError::InvalidUuid)?; - { - let mut dict = user_dict.dict.lock().expect("lock failed"); - dict.remove_word(word_uuid)?; - }; - + user_dict.dict.remove_word(word_uuid)?; Ok(()) })()) } @@ -1229,8 +1216,7 @@ pub unsafe extern "C" fn voicevox_user_dict_to_json( user_dict: &VoicevoxUserDict, output_json: NonNull<*mut c_char>, ) -> VoicevoxResultCode { - let dict = user_dict.dict.lock().expect("lock failed"); - let json = serde_json::to_string(&dict.words()).expect("should be always valid"); + let json = user_dict.dict.to_json(); let json = CString::new(json).expect("\\0を含まない文字列であることが保証されている"); output_json .as_ptr() @@ -1253,12 +1239,7 @@ pub extern "C" fn voicevox_user_dict_import( other_dict: &VoicevoxUserDict, ) -> VoicevoxResultCode { into_result_code_with_error((|| { - { - let mut dict = user_dict.dict.lock().expect("lock failed"); - let other_dict = other_dict.dict.lock().expect("lock failed"); - dict.import(&other_dict)?; - }; - + user_dict.dict.import(&other_dict.dict)?; Ok(()) })()) } @@ -1279,11 +1260,7 @@ pub unsafe extern "C" fn voicevox_user_dict_save( ) -> VoicevoxResultCode { into_result_code_with_error((|| { let path = ensure_utf8(CStr::from_ptr(path))?; - { - let dict = user_dict.dict.lock().expect("lock failed"); - dict.save(path)?; - }; - + RUNTIME.block_on(user_dict.dict.save(path))?; Ok(()) })()) } diff --git a/crates/voicevox_core_java_api/src/open_jtalk.rs b/crates/voicevox_core_java_api/src/open_jtalk.rs index 020cf38f3..bbb03608c 100644 --- a/crates/voicevox_core_java_api/src/open_jtalk.rs +++ b/crates/voicevox_core_java_api/src/open_jtalk.rs @@ -1,9 +1,6 @@ -use std::{ - borrow::Cow, - sync::{Arc, Mutex}, -}; +use std::{borrow::Cow, sync::Arc}; -use crate::common::throw_if_err; +use crate::common::{throw_if_err, RUNTIME}; use jni::{ objects::{JObject, JString}, JNIEnv, @@ -19,7 +16,7 @@ unsafe extern "system" fn Java_jp_hiroshiba_voicevoxcore_OpenJtalk_rsNew<'local> let open_jtalk_dict_dir = env.get_string(&open_jtalk_dict_dir)?; let open_jtalk_dict_dir = &*Cow::from(&open_jtalk_dict_dir); - let internal = voicevox_core::OpenJtalk::new(open_jtalk_dict_dir)?; + let internal = RUNTIME.block_on(voicevox_core::OpenJtalk::new(open_jtalk_dict_dir))?; env.set_rust_field(&this, "handle", Arc::new(internal))?; Ok(()) @@ -38,13 +35,10 @@ unsafe extern "system" fn Java_jp_hiroshiba_voicevoxcore_OpenJtalk_rsUseUserDict .clone(); let user_dict = env - .get_rust_field::<_, _, Arc>>(&user_dict, "handle")? + .get_rust_field::<_, _, Arc>(&user_dict, "handle")? .clone(); - { - let user_dict = user_dict.lock().unwrap(); - internal.use_user_dict(&user_dict)? - } + RUNTIME.block_on(internal.use_user_dict(&user_dict))?; Ok(()) }) diff --git a/crates/voicevox_core_java_api/src/user_dict.rs b/crates/voicevox_core_java_api/src/user_dict.rs index e85085a34..abc90253f 100644 --- a/crates/voicevox_core_java_api/src/user_dict.rs +++ b/crates/voicevox_core_java_api/src/user_dict.rs @@ -1,10 +1,7 @@ use jni::objects::JClass; -use std::{ - borrow::Cow, - sync::{Arc, Mutex}, -}; +use std::{borrow::Cow, sync::Arc}; -use crate::common::{throw_if_err, JavaApiError}; +use crate::common::{throw_if_err, JavaApiError, RUNTIME}; use jni::{ objects::{JObject, JString}, sys::jobject, @@ -19,7 +16,7 @@ unsafe extern "system" fn Java_jp_hiroshiba_voicevoxcore_UserDict_rsNew<'local>( throw_if_err(env, (), |env| { let internal = voicevox_core::UserDict::new(); - env.set_rust_field(&this, "handle", Arc::new(Mutex::new(internal)))?; + env.set_rust_field(&this, "handle", Arc::new(internal))?; Ok(()) }) @@ -33,7 +30,7 @@ unsafe extern "system" fn Java_jp_hiroshiba_voicevoxcore_UserDict_rsAddWord<'loc ) -> jobject { throw_if_err(env, std::ptr::null_mut(), |env| { let internal = env - .get_rust_field::<_, _, Arc>>(&this, "handle")? + .get_rust_field::<_, _, Arc>(&this, "handle")? .clone(); let word_json = env.get_string(&word_json)?; @@ -42,12 +39,7 @@ unsafe extern "system" fn Java_jp_hiroshiba_voicevoxcore_UserDict_rsAddWord<'loc let word: voicevox_core::UserDictWord = serde_json::from_str(word_json).map_err(JavaApiError::DeJson)?; - let uuid = { - let mut internal = internal.lock().unwrap(); - internal.add_word(word)? - }; - - let uuid = uuid.hyphenated().to_string(); + let uuid = internal.add_word(word)?.hyphenated().to_string(); let uuid = env.new_string(uuid)?; Ok(uuid.into_raw()) @@ -63,7 +55,7 @@ unsafe extern "system" fn Java_jp_hiroshiba_voicevoxcore_UserDict_rsUpdateWord<' ) { throw_if_err(env, (), |env| { let internal = env - .get_rust_field::<_, _, Arc>>(&this, "handle")? + .get_rust_field::<_, _, Arc>(&this, "handle")? .clone(); let uuid = env.get_string(&uuid)?; @@ -74,10 +66,7 @@ unsafe extern "system" fn Java_jp_hiroshiba_voicevoxcore_UserDict_rsUpdateWord<' let word: voicevox_core::UserDictWord = serde_json::from_str(word_json).map_err(JavaApiError::DeJson)?; - { - let mut internal = internal.lock().unwrap(); - internal.update_word(uuid, word)?; - }; + internal.update_word(uuid, word)?; Ok(()) }) @@ -91,16 +80,13 @@ unsafe extern "system" fn Java_jp_hiroshiba_voicevoxcore_UserDict_rsRemoveWord<' ) { throw_if_err(env, (), |env| { let internal = env - .get_rust_field::<_, _, Arc>>(&this, "handle")? + .get_rust_field::<_, _, Arc>(&this, "handle")? .clone(); let uuid = env.get_string(&uuid)?; let uuid = Cow::from(&uuid).parse()?; - { - let mut internal = internal.lock().unwrap(); - internal.remove_word(uuid)?; - }; + internal.remove_word(uuid)?; Ok(()) }) @@ -114,17 +100,13 @@ unsafe extern "system" fn Java_jp_hiroshiba_voicevoxcore_UserDict_rsImportDict<' ) { throw_if_err(env, (), |env| { let internal = env - .get_rust_field::<_, _, Arc>>(&this, "handle")? + .get_rust_field::<_, _, Arc>(&this, "handle")? .clone(); let other_dict = env - .get_rust_field::<_, _, Arc>>(&other_dict, "handle")? + .get_rust_field::<_, _, Arc>(&other_dict, "handle")? .clone(); - { - let mut internal = internal.lock().unwrap(); - let other_dict = other_dict.lock().unwrap(); - internal.import(&other_dict)?; - } + internal.import(&other_dict)?; Ok(()) }) @@ -138,16 +120,13 @@ unsafe extern "system" fn Java_jp_hiroshiba_voicevoxcore_UserDict_rsLoad<'local> ) { throw_if_err(env, (), |env| { let internal = env - .get_rust_field::<_, _, Arc>>(&this, "handle")? + .get_rust_field::<_, _, Arc>(&this, "handle")? .clone(); let path = env.get_string(&path)?; let path = &Cow::from(&path); - { - let mut internal = internal.lock().unwrap(); - internal.load(path)?; - }; + RUNTIME.block_on(internal.load(path))?; Ok(()) }) @@ -161,16 +140,13 @@ unsafe extern "system" fn Java_jp_hiroshiba_voicevoxcore_UserDict_rsSave<'local> ) { throw_if_err(env, (), |env| { let internal = env - .get_rust_field::<_, _, Arc>>(&this, "handle")? + .get_rust_field::<_, _, Arc>(&this, "handle")? .clone(); let path = env.get_string(&path)?; let path = &Cow::from(&path); - { - let internal = internal.lock().unwrap(); - internal.save(path)?; - }; + RUNTIME.block_on(internal.save(path))?; Ok(()) }) @@ -183,14 +159,10 @@ unsafe extern "system" fn Java_jp_hiroshiba_voicevoxcore_UserDict_rsGetWords<'lo ) -> jobject { throw_if_err(env, std::ptr::null_mut(), |env| { let internal = env - .get_rust_field::<_, _, Arc>>(&this, "handle")? + .get_rust_field::<_, _, Arc>(&this, "handle")? .clone(); - let words = { - let internal = internal.lock().unwrap(); - serde_json::to_string(internal.words()).expect("should not fail") - }; - + let words = internal.to_json(); let words = env.new_string(words)?; Ok(words.into_raw()) diff --git a/crates/voicevox_core_python_api/python/test/test_pseudo_raii_for_synthesizer.py b/crates/voicevox_core_python_api/python/test/test_pseudo_raii_for_synthesizer.py index 9859b85d7..a40c9c160 100644 --- a/crates/voicevox_core_python_api/python/test/test_pseudo_raii_for_synthesizer.py +++ b/crates/voicevox_core_python_api/python/test/test_pseudo_raii_for_synthesizer.py @@ -40,6 +40,6 @@ async def synthesizer(open_jtalk: OpenJtalk) -> Synthesizer: return Synthesizer(open_jtalk) -@pytest.fixture(scope="module") -def open_jtalk() -> OpenJtalk: - return OpenJtalk(conftest.open_jtalk_dic_dir) +@pytest_asyncio.fixture(scope="function") +async def open_jtalk() -> OpenJtalk: + return await OpenJtalk.new(conftest.open_jtalk_dic_dir) diff --git a/crates/voicevox_core_python_api/python/test/test_user_dict_load.py b/crates/voicevox_core_python_api/python/test/test_user_dict_load.py index ba009f37d..572046496 100644 --- a/crates/voicevox_core_python_api/python/test/test_user_dict_load.py +++ b/crates/voicevox_core_python_api/python/test/test_user_dict_load.py @@ -10,7 +10,7 @@ @pytest.mark.asyncio async def test_user_dict_load() -> None: - open_jtalk = voicevox_core.OpenJtalk(conftest.open_jtalk_dic_dir) + open_jtalk = await voicevox_core.OpenJtalk.new(conftest.open_jtalk_dic_dir) model = await voicevox_core.VoiceModel.from_path(conftest.model_dir) synthesizer = voicevox_core.Synthesizer(open_jtalk) @@ -29,7 +29,7 @@ async def test_user_dict_load() -> None: ) assert isinstance(uuid, UUID) - open_jtalk.use_user_dict(temp_dict) + await open_jtalk.use_user_dict(temp_dict) audio_query_with_dict = await synthesizer.audio_query( "this_word_should_not_exist_in_default_dictionary", style_id=0 diff --git a/crates/voicevox_core_python_api/python/test/test_user_dict_manipulate.py b/crates/voicevox_core_python_api/python/test/test_user_dict_manipulate.py index c283b48f1..1ba37465f 100644 --- a/crates/voicevox_core_python_api/python/test/test_user_dict_manipulate.py +++ b/crates/voicevox_core_python_api/python/test/test_user_dict_manipulate.py @@ -59,8 +59,8 @@ async def test_user_dict_load() -> None: ) temp_path_fd, temp_path = tempfile.mkstemp() os.close(temp_path_fd) - dict_c.save(temp_path) - dict_a.load(temp_path) + await dict_c.save(temp_path) + await dict_a.load(temp_path) assert uuid_a in dict_a.words assert uuid_c in dict_a.words diff --git a/crates/voicevox_core_python_api/python/voicevox_core/_rust.pyi b/crates/voicevox_core_python_api/python/voicevox_core/_rust.pyi index 5288fcbde..b09f8425f 100644 --- a/crates/voicevox_core_python_api/python/voicevox_core/_rust.pyi +++ b/crates/voicevox_core_python_api/python/voicevox_core/_rust.pyi @@ -56,15 +56,20 @@ class VoiceModel: class OpenJtalk: """ テキスト解析器としてのOpen JTalk。 - - Parameters - ---------- - open_jtalk_dict_dir - Open JTalkの辞書ディレクトリ。 """ - def __init__(self, open_jtalk_dict_dir: Union[Path, str]) -> None: ... - def use_user_dict(self, user_dict: UserDict) -> None: + @staticmethod + async def new(open_jtalk_dict_dir: Union[Path, str]) -> "OpenJtalk": + """ + ``OpenJTalk`` を生成する。 + + Parameters + ---------- + open_jtalk_dict_dir + Open JTalkの辞書ディレクトリ。 + """ + ... + async def use_user_dict(self, user_dict: UserDict) -> None: """ ユーザー辞書を設定する。 @@ -357,7 +362,7 @@ class UserDict: """このオプジェクトの :class:`dict` としての表現。""" ... def __init__(self) -> None: ... - def load(self, path: str) -> None: + async def load(self, path: str) -> None: """ファイルに保存されたユーザー辞書を読み込む。 Parameters @@ -366,7 +371,7 @@ class UserDict: ユーザー辞書のパス。 """ ... - def save(self, path: str) -> None: + async def save(self, path: str) -> None: """ ユーザー辞書をファイルに保存する。 diff --git a/crates/voicevox_core_python_api/src/lib.rs b/crates/voicevox_core_python_api/src/lib.rs index 1531b463e..e96a3d8c4 100644 --- a/crates/voicevox_core_python_api/src/lib.rs +++ b/crates/voicevox_core_python_api/src/lib.rs @@ -119,22 +119,26 @@ struct OpenJtalk { #[pymethods] impl OpenJtalk { - #[new] + #[allow(clippy::new_ret_no_self)] + #[staticmethod] fn new( #[pyo3(from_py_with = "from_utf8_path")] open_jtalk_dict_dir: String, py: Python<'_>, - ) -> PyResult { - Ok(Self { - open_jtalk: Arc::new( - voicevox_core::OpenJtalk::new(open_jtalk_dict_dir).into_py_result(py)?, - ), + ) -> PyResult<&PyAny> { + pyo3_asyncio::tokio::future_into_py(py, async move { + let open_jtalk = voicevox_core::OpenJtalk::new(open_jtalk_dict_dir).await; + let open_jtalk = Python::with_gil(|py| open_jtalk.into_py_result(py))?.into(); + Ok(Self { open_jtalk }) }) } - fn use_user_dict(&self, user_dict: UserDict, py: Python<'_>) -> PyResult<()> { - self.open_jtalk - .use_user_dict(&user_dict.dict) - .into_py_result(py) + fn use_user_dict<'py>(&self, user_dict: UserDict, py: Python<'py>) -> PyResult<&'py PyAny> { + let this = self.open_jtalk.clone(); + + pyo3_asyncio::tokio::future_into_py(py, async move { + let result = this.use_user_dict(&user_dict.dict).await; + Python::with_gil(|py| result.into_py_result(py)) + }) } } @@ -526,7 +530,7 @@ fn _to_zenkaku(text: &str) -> PyResult { #[pyclass] #[derive(Default, Debug, Clone)] struct UserDict { - dict: voicevox_core::UserDict, + dict: Arc, } #[pymethods] @@ -536,12 +540,24 @@ impl UserDict { Self::default() } - fn load(&mut self, path: &str, py: Python<'_>) -> PyResult<()> { - self.dict.load(path).into_py_result(py) + fn load<'py>(&self, path: &str, py: Python<'py>) -> PyResult<&'py PyAny> { + let this = self.dict.clone(); + let path = path.to_owned(); + + pyo3_asyncio::tokio::future_into_py(py, async move { + let result = this.load(&path).await; + Python::with_gil(|py| result.into_py_result(py)) + }) } - fn save(&self, path: &str, py: Python<'_>) -> PyResult<()> { - self.dict.save(path).into_py_result(py) + fn save<'py>(&self, path: &str, py: Python<'py>) -> PyResult<&'py PyAny> { + let this = self.dict.clone(); + let path = path.to_owned(); + + pyo3_asyncio::tokio::future_into_py(py, async move { + let result = this.save(&path).await; + Python::with_gil(|py| result.into_py_result(py)) + }) } fn add_word( @@ -580,16 +596,16 @@ impl UserDict { #[getter] fn words<'py>(&self, py: Python<'py>) -> PyResult<&'py PyDict> { - let words = self - .dict - .words() - .iter() - .map(|(&uuid, word)| { - let uuid = to_py_uuid(py, uuid)?; - let word = to_py_user_dict_word(py, word)?; - Ok((uuid, word)) - }) - .collect::>>()?; + let words = self.dict.with_words(|words| { + words + .iter() + .map(|(&uuid, word)| { + let uuid = to_py_uuid(py, uuid)?; + let word = to_py_user_dict_word(py, word)?; + Ok((uuid, word)) + }) + .collect::>>() + })?; Ok(words.into_py_dict(py)) } } diff --git a/example/python/run.py b/example/python/run.py index 27b787b00..fb39715e9 100644 --- a/example/python/run.py +++ b/example/python/run.py @@ -37,7 +37,7 @@ async def main() -> None: logger.info("%s", f"Initializing ({acceleration_mode=}, {open_jtalk_dict_dir=})") synthesizer = Synthesizer( - OpenJtalk(open_jtalk_dict_dir), acceleration_mode=acceleration_mode + await OpenJtalk.new(open_jtalk_dict_dir), acceleration_mode=acceleration_mode ) logger.debug("%s", f"{synthesizer.metas=}")