backup
CGQAQ committed Nov 10, 2023
1 parent 05884e9 commit 983ff4a
Showing 2 changed files with 32 additions and 20 deletions.
inference_core/src/embed.rs (20 changes: 14 additions & 6 deletions)
@@ -1,3 +1,4 @@
+use std::mem::ManuallyDrop;
 use std::pin::Pin;
 use std::sync::Arc;
 use anyhow::anyhow;
@@ -9,17 +10,26 @@ use ort::{
 };
 
 pub struct Semantic {
-    #[allow(dead_code)]
-    model: Vec<u8>,
+    model_ref: &'static [u8],
 
     tokenizer: Arc<tokenizers::Tokenizer>,
     session: Arc<ort::InMemorySession<'static>>,
 }
 
+impl Drop for Semantic {
+    fn drop(&mut self) {
+        unsafe {
+            ManuallyDrop::drop(&mut ManuallyDrop::new(self.model_ref));
+        }
+    }
+}
+
 pub type Embedding = Vec<f32>;
 
 impl Semantic {
     pub async fn initialize(model: Vec<u8>, tokenizer_data: Vec<u8>) -> Result<Pin<Box<Semantic>>, anyhow::Error> {
+        let model_ref = model.leak();
+
         let environment = Arc::new(
             Environment::builder()
                 .with_name("Encode")
@@ -36,16 +46,14 @@ impl Semantic {
 
         let tokenizer: Arc<tokenizers::Tokenizer> = tokenizers::Tokenizer::from_bytes(tokenizer_data).map_err(|e| anyhow!("tok frombytes error: {}", e))?.into();
 
-        let data_ref: &[u8] = unsafe {&*( model.as_slice() as *const [u8] )};
-
         let semantic = Self {
-            model,
+            model_ref,
 
             tokenizer,
             session: SessionBuilder::new(&environment)?
                 .with_optimization_level(GraphOptimizationLevel::Level3)?
                 .with_intra_threads(threads)?
-                .with_model_from_memory(data_ref)
+                .with_model_from_memory(model_ref)
                 .unwrap()
                 .into(),
         };
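
The change above replaces a self-referential field (`model: Vec<u8>` plus an `unsafe` pointer cast fed to `with_model_from_memory`) with `Vec::leak`, which yields a `&'static [u8]` that the `ort::InMemorySession<'static>` can borrow for the life of the process. One caveat: dropping a shared reference is a no-op, so a `Drop` impl built around `ManuallyDrop::new(self.model_ref)` does not return the leaked allocation to the allocator. Below is a minimal sketch of a reclaimable variant using `Box::leak`/`Box::from_raw`; the `LeakedModel` type and its methods are illustrative names, not part of this commit.

```rust
/// Sketch only: leak a model buffer for a `'static` borrower, but keep
/// enough information to free it again. Not from the commit above.
struct LeakedModel {
    bytes: &'static mut [u8],
}

impl LeakedModel {
    fn new(model: Vec<u8>) -> Self {
        // `into_boxed_slice` shrinks capacity to len, so the allocation
        // can later be rebuilt from the fat slice pointer alone.
        Self { bytes: Box::leak(model.into_boxed_slice()) }
    }

    fn as_static(&self) -> &'static [u8] {
        // SAFETY contract: the returned slice must not be used after
        // this `LeakedModel` is dropped.
        let ptr: *const [u8] = &*self.bytes;
        unsafe { &*ptr }
    }
}

impl Drop for LeakedModel {
    fn drop(&mut self) {
        // SAFETY: `bytes` came from `Box::leak`, so rebuilding the Box
        // exactly once releases the allocation.
        let ptr: *mut [u8] = &mut *self.bytes;
        unsafe { drop(Box::from_raw(ptr)) };
    }
}

fn main() {
    let model = LeakedModel::new(vec![1u8, 2, 3]);
    assert_eq!(model.as_static(), &[1, 2, 3]);
    // Allocation reclaimed here; any `as_static` slice is now dangling.
    drop(model);
}
```
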
inference_grpc/src/bin/server.rs (32 changes: 18 additions & 14 deletions)
@@ -1,3 +1,4 @@
+use std::cell::RefCell;
 use std::pin::Pin;
 use tokio::sync::Mutex;
 use clap::Parser;
@@ -17,16 +18,22 @@ pub mod tokenizer {
 #[derive(Default)]
 pub struct MyTokenizer {
     sema: Mutex<Option<Pin<Box<Semantic>>>>,
+}
+
+thread_local! {
+    static TOKENIZER_DATA: RefCell<Vec<u8>> = RefCell::default();
+}
 
-    tokenzier: Mutex<Vec<u8>>,
-    model: Mutex<Vec<u8>>,
+thread_local! {
+    static MODEL: RefCell<Vec<u8>> = RefCell::default();
 }
 
+
 #[tonic::async_trait]
 impl Tokenizer for MyTokenizer {
     async fn set_tokenizer_json(&self, reqeust: Request<Streaming<TokenizerJson>>) -> Result<Response<GeneralResponse>, Status> {
-        let mut t = self.tokenzier.lock().await;
-        t.clear();
+        TOKENIZER_DATA.with_borrow_mut(|t| t.clear());
 
+
         let mut stream = reqeust.into_inner();
         while let Some(json) = stream.next().await {
@@ -38,9 +45,9 @@ impl Tokenizer for MyTokenizer {
                     error: format!("json error: {}", e).into(),
                 })),
             };
-            t.extend(json.json);
+            TOKENIZER_DATA.with_borrow_mut(|t| t.extend(json.json));
         }
 
         Ok(Response::new(GeneralResponse{
             success: true,
@@ -49,8 +56,7 @@ impl Tokenizer for MyTokenizer {
     }
 
     async fn set_model(&self, reqeust: Request<Streaming<Model>>) -> Result<Response<GeneralResponse>, Status> {
-        let mut t = self.model.lock().await;
-        t.clear();
+        MODEL.with_borrow_mut(|t| t.clear());
 
         let mut stream = reqeust.into_inner();
         while let Some(model) = stream.next().await {
@@ -62,9 +68,9 @@ impl Tokenizer for MyTokenizer {
                     error: format!("model error: {}", e).into(),
                 })),
             };
-            t.extend(model.model);
+            MODEL.with_borrow_mut(|t| t.extend(model.model));
        }
 
         Ok(Response::new(GeneralResponse{
             success: true,
@@ -74,11 +80,9 @@ impl Tokenizer for MyTokenizer {
     }
 
     async fn init_model(&self, _: tonic::Request<()>) -> Result<Response<GeneralResponse>, Status> {
-        let model = self.model.lock().await;
-        let tokenizer = self.tokenzier.lock().await;
+        let tokenizer_data = TOKENIZER_DATA.with_borrow_mut(|t| t.clone());
 
-        let sema = match Semantic::initialize(model.clone(), tokenizer.clone()).await {
+        let sema = match Semantic::initialize(MODEL.with_borrow(|it| it.clone()), tokenizer_data).await {
             Ok(t) => t,
             Err(e) => return Ok(Response::new(GeneralResponse{
                 success: false,
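
On the server side, the staging buffers move from `Mutex<Vec<u8>>` fields on `MyTokenizer` to `thread_local!` `RefCell` statics, accessed through `LocalKey::with_borrow`/`with_borrow_mut` (stable since Rust 1.73). Below is a minimal sketch of that access pattern, with illustrative names not taken from this commit; one property worth keeping in mind is that thread-local state is per OS thread, which matters when handlers run on a multi-threaded tokio runtime.

```rust
use std::cell::RefCell;

thread_local! {
    // Each OS thread gets its own buffer; `RefCell` provides interior
    // mutability without a lock, at the cost of runtime borrow checks.
    static BUFFER: RefCell<Vec<u8>> = RefCell::default();
}

fn main() {
    // `with_borrow_mut` borrows this thread's RefCell mutably for the
    // duration of the closure.
    BUFFER.with_borrow_mut(|b| b.extend([1u8, 2, 3]));
    assert_eq!(BUFFER.with_borrow(|b| b.len()), 3);

    // A different thread observes its own, independent (empty) BUFFER.
    std::thread::spawn(|| {
        assert_eq!(BUFFER.with_borrow(|b| b.len()), 0);
    })
    .join()
    .unwrap();
}
```
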
