diff --git a/inference_core/src/embed.rs b/inference_core/src/embed.rs
index 80a1c5f..6481ffe 100644
--- a/inference_core/src/embed.rs
+++ b/inference_core/src/embed.rs
@@ -1,3 +1,4 @@
+use std::mem::ManuallyDrop;
 use std::pin::Pin;
 use std::sync::Arc;
 use anyhow::anyhow;
@@ -9,17 +10,26 @@ use ort::{
 };
 
 pub struct Semantic {
-    #[allow(dead_code)]
-    model: Vec<u8>,
+    model_ref: &'static [u8],
     tokenizer: Arc<tokenizers::Tokenizer>,
     session: Arc<InMemorySession<'static>>,
 }
 
+impl Drop for Semantic {
+    fn drop(&mut self) {
+        unsafe {
+            ManuallyDrop::drop(&mut ManuallyDrop::new(self.model_ref));
+        }
+    }
+}
+
 pub type Embedding = Vec<f32>;
 
 impl Semantic {
     pub async fn initialize(model: Vec<u8>, tokenizer_data: Vec<u8>) -> Result<Pin<Arc<Semantic>>, anyhow::Error> {
+        let model_ref = model.leak();
+
         let environment = Arc::new(
             Environment::builder()
                 .with_name("Encode")
@@ -36,16 +46,14 @@ impl Semantic {
         let tokenizer: Arc<tokenizers::Tokenizer> = tokenizers::Tokenizer::from_bytes(tokenizer_data).map_err(|e| anyhow!("tok frombytes error: {}", e))?.into();
 
-        let data_ref: &[u8] = unsafe {&*( model.as_slice() as *const [u8] )};
-
         let semantic = Self {
-            model,
+            model_ref,
             tokenizer,
             session: SessionBuilder::new(&environment)?
                 .with_optimization_level(GraphOptimizationLevel::Level3)?
                 .with_intra_threads(threads)?
-                .with_model_from_memory(data_ref)
+                .with_model_from_memory(model_ref)
                 .unwrap()
                 .into(),
         };
diff --git a/inference_grpc/src/bin/server.rs b/inference_grpc/src/bin/server.rs
index e791007..efdd212 100644
--- a/inference_grpc/src/bin/server.rs
+++ b/inference_grpc/src/bin/server.rs
@@ -1,3 +1,4 @@
+use std::cell::RefCell;
 use std::pin::Pin;
 use tokio::sync::Mutex;
 use clap::Parser;
@@ -17,16 +18,22 @@ pub mod tokenizer {
 }
 
 #[derive(Default)]
 pub struct MyTokenizer {
     sema: Mutex<Option<Pin<Arc<Semantic>>>>,
+}
+
+thread_local! {
+    static TOKENIZER_DATA: RefCell<Vec<u8>> = RefCell::default();
+}
 
-    tokenzier: Mutex<Vec<u8>>,
-    model: Mutex<Vec<u8>>,
+thread_local! {
+    static MODEL: RefCell<Vec<u8>> = RefCell::default();
 }
 
+
 #[tonic::async_trait]
 impl Tokenizer for MyTokenizer {
     async fn set_tokenizer_json(&self, reqeust: Request<Streaming<TokenizerJson>>) -> Result<Response<GeneralResponse>, Status> {
-        let mut t = self.tokenzier.lock().await;
-        t.clear();
+        TOKENIZER_DATA.with_borrow_mut(|t| t.clear());
+
         let mut stream = reqeust.into_inner();
 
         while let Some(json) = stream.next().await {
@@ -38,9 +45,9 @@ impl Tokenizer for MyTokenizer {
                     error: format!("json error: {}", e).into(),
                 })),
             };
-            t.extend(json.json);
+            TOKENIZER_DATA.with_borrow_mut(|t| t.extend(json.json));
         }
-
+
         Ok(Response::new(GeneralResponse{
             success: true,
@@ -49,8 +56,7 @@
     }
 
     async fn set_model(&self, reqeust: Request<Streaming<Model>>) -> Result<Response<GeneralResponse>, Status> {
-        let mut t = self.model.lock().await;
-        t.clear();
+        MODEL.with_borrow_mut(|t| t.clear());
 
         let mut stream = reqeust.into_inner();
         while let Some(model) = stream.next().await {
@@ -62,9 +68,9 @@
                     error: format!("model error: {}", e).into(),
                 })),
             };
-            t.extend(model.model);
+            MODEL.with_borrow_mut(|t| t.extend(model.model));
         }
-
+
         Ok(Response::new(GeneralResponse{
             success: true,
@@ -74,11 +80,9 @@
     }
 
     async fn init_model(&self, _: tonic::Request<()>) -> Result<Response<GeneralResponse>, Status> {
-        let model = self.model.lock().await;
-        let tokenizer = self.tokenzier.lock().await;
-
+        let tokenizer_data = TOKENIZER_DATA.with_borrow_mut(|t| t.clone());
 
-        let sema = match Semantic::initialize(model.clone(), tokenizer.clone()).await {
+        let sema = match Semantic::initialize(MODEL.with_borrow(|it| it.clone()), tokenizer_data).await {
             Ok(t) => t,
             Err(e) => return Ok(Response::new(GeneralResponse{
                 success: false,
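
Note on the leak/reclaim pattern in embed.rs: `model.leak()` turns the model buffer into a `'static` slice, which satisfies the lifetime the stored session needs, but the added `Drop` impl does not undo the leak. `ManuallyDrop::drop(&mut ManuallyDrop::new(self.model_ref))` drops a freshly created wrapper around the reference itself, and dropping a `&[u8]` frees nothing, so the bytes stay leaked. A minimal sketch of a reclaim that does free the allocation (names hypothetical; it assumes nothing else, including the ort session built from these bytes, can still touch them once the owner drops):

    struct LeakedModel {
        bytes: &'static [u8],
    }

    impl LeakedModel {
        fn new(buf: Vec<u8>) -> Self {
            // into_boxed_slice drops excess capacity, so the leaked slice's
            // length matches the allocation and Box::from_raw below is sound.
            Self { bytes: Box::leak(buf.into_boxed_slice()) }
        }
    }

    impl Drop for LeakedModel {
        fn drop(&mut self) {
            // SAFETY: `bytes` came from Box::leak in `new`, and no other
            // reference to it survives this point.
            unsafe {
                drop(Box::from_raw(self.bytes as *const [u8] as *mut [u8]));
            }
        }
    }

In `Semantic` itself this reclaim would only be safe if ONNX Runtime copies the model while building the session: the `session` field drops after the `Drop::drop` body runs, so a session that still borrowed the bytes would briefly point at freed memory.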
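
Note on the thread-local storage in server.rs: `thread_local!` gives every thread its own independent copy of `MODEL` and `TOKENIZER_DATA`, and tonic handlers typically run on tokio's multi-threaded runtime, where `set_model`, `set_tokenizer_json`, and `init_model` may each land on a different worker thread; bytes streamed in by one call can then be invisible to the next. A minimal sketch of a thread-agnostic alternative (names hypothetical), keeping the buffer in a process-wide static:

    use std::sync::Mutex;

    static MODEL_BYTES: Mutex<Vec<u8>> = Mutex::new(Vec::new());

    // Append one streamed chunk; callable from any worker thread.
    fn append_model_chunk(chunk: &[u8]) {
        MODEL_BYTES.lock().unwrap().extend_from_slice(chunk);
    }

    // Hand the accumulated bytes to initialization, leaving the buffer empty.
    fn take_model_bytes() -> Vec<u8> {
        std::mem::take(&mut *MODEL_BYTES.lock().unwrap())
    }

A `std::sync::Mutex` rather than `tokio::sync::Mutex` suffices here because the lock is never held across an `.await`.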