From f25f82ecf212ccc2677e4002bb2723b9a185440f Mon Sep 17 00:00:00 2001
From: Phodal Huang
Date: Mon, 27 Nov 2023 11:43:12 +0800
Subject: [PATCH] fix: fix lost bin issue

---
 enfer_grpc/src/bin/client.rs |  65 +++++++++++++++
 enfer_grpc/src/bin/server.rs | 151 ++++++++++++++++++++++++++++++++++
 2 files changed, 216 insertions(+)
 create mode 100644 enfer_grpc/src/bin/client.rs
 create mode 100644 enfer_grpc/src/bin/server.rs

diff --git a/enfer_grpc/src/bin/client.rs b/enfer_grpc/src/bin/client.rs
new file mode 100644
index 0000000..2e7e2be
--- /dev/null
+++ b/enfer_grpc/src/bin/client.rs
@@ -0,0 +1,65 @@
+use std::io::{Cursor, Read};
+
+use async_stream::stream;
+use tokenizer::tokenizer_client::TokenizerClient;
+use tokenizer::EncodeRequest;
+
+pub mod tokenizer {
+    tonic::include_proto!("tokenizer");
+}
+
+// The model and tokenizer artifacts are embedded at compile time so the
+// client can upload them to the server without any runtime file access.
+const MODEL: &[u8] = include_bytes!("../../../model/model.onnx");
+const TOKENIZER_JSON: &[u8] = include_bytes!("../../../model/tokenizer.json");
+
+/// Chunk size used when streaming artifacts to the server.
+const CHUNK_SIZE: usize = 1024 * 8;
+
+#[tokio::main]
+async fn main() -> Result<(), Box<dyn std::error::Error>> {
+    let mut client = TokenizerClient::connect("http://[::1]:50051").await?;
+
+    // Stream the tokenizer JSON to the server in fixed-size chunks.
+    let mut cursor = Cursor::new(TOKENIZER_JSON);
+    let mut buf = [0u8; CHUNK_SIZE];
+
+    let response = client
+        .set_tokenizer_json(stream! {
+            while let Ok(n) = cursor.read(&mut buf) {
+                if n == 0 {
+                    break;
+                }
+                yield tokenizer::TokenizerJson { json: buf[..n].to_vec() };
+            }
+        })
+        .await?;
+    println!("tokenizer RESPONSE={:?}", response);
+
+    // Stream the ONNX model the same way. `buf` is a Copy array, so it is
+    // copied (not moved) into each stream.
+    let mut cursor = Cursor::new(MODEL);
+    let response = client
+        .set_model(stream! {
+            while let Ok(n) = cursor.read(&mut buf) {
+                if n == 0 {
+                    break;
+                }
+                yield tokenizer::Model { model: buf[..n].to_vec() };
+            }
+        })
+        .await?;
+    println!("model RESPONSE={:?}", response);
+
+    // Ask the server to build its Semantic engine from the uploaded bytes.
+    // Fix: the original dropped the transport error (missing `?`).
+    let response = client.init_model(()).await?;
+    println!("init RESPONSE={:?}", response);
+
+    // Fix: the original constructed this request but never sent it, and
+    // printed the init_model response instead of an encode response.
+    let request = tonic::Request::new(EncodeRequest { text: "Tonic".into() });
+    let response = client.encode(request).await?;
+    println!("encode RESPONSE={:?}", response);
+
+    Ok(())
+}
diff --git a/enfer_grpc/src/bin/server.rs b/enfer_grpc/src/bin/server.rs
new file mode 100644
index 0000000..5afcf49
--- /dev/null
+++ b/enfer_grpc/src/bin/server.rs
@@ -0,0 +1,151 @@
+use std::pin::Pin;
+use std::sync::Arc;
+
+use clap::Parser;
+use tokio::sync::Mutex;
+use tonic::codegen::tokio_stream::StreamExt;
+use tonic::{transport::Server, Request, Response, Status, Streaming};
+
+use tokenizer::tokenizer_server::{Tokenizer, TokenizerServer};
+use tokenizer::{EncodeReply, EncodeRequest, GeneralResponse, Model, TokenizerJson};
+
+use inference_core::embedding::Semantic;
+use inference_grpc::inference_args::InferenceArgs;
+
+pub mod tokenizer {
+    tonic::include_proto!("tokenizer");
+}
+
+/// Service state: artifact bytes are uploaded first, then `init_model`
+/// turns them into a ready `Semantic` engine.
+#[derive(Default)]
+pub struct MyTokenizer {
+    // `None` until `init_model` has run successfully.
+    // NOTE(review): generics were lost in transit; reconstructed as
+    // `Arc<Semantic>` — confirm against `Semantic::initialize`'s return type.
+    sema: Mutex<Option<Arc<Semantic>>>,
+
+    // Raw tokenizer.json bytes accumulated from the client stream.
+    tokenizer_json: Mutex<Vec<u8>>,
+    // Raw ONNX model bytes accumulated from the client stream.
+    model: Mutex<Vec<u8>>,
+}
+
+#[tonic::async_trait]
+impl Tokenizer for MyTokenizer {
+    /// Receives tokenizer.json as a chunked byte stream and buffers it.
+    async fn set_tokenizer_json(
+        &self,
+        request: Request<Streaming<TokenizerJson>>,
+    ) -> Result<Response<GeneralResponse>, Status> {
+        let mut t = self.tokenizer_json.lock().await;
+        t.clear();
+
+        let mut stream = request.into_inner();
+        while let Some(json) = stream.next().await {
+            let json = match json {
+                Ok(j) => j,
+                // Report stream errors in-band instead of aborting the RPC.
+                Err(e) => {
+                    return Ok(Response::new(GeneralResponse {
+                        success: false,
+                        error: Some(format!("json error: {}", e)),
+                    }))
+                }
+            };
+            t.extend(json.json);
+        }
+
+        Ok(Response::new(GeneralResponse {
+            success: true,
+            error: None,
+        }))
+    }
+
+    /// Receives the ONNX model as a chunked byte stream and buffers it.
+    async fn set_model(
+        &self,
+        request: Request<Streaming<Model>>,
+    ) -> Result<Response<GeneralResponse>, Status> {
+        let mut m = self.model.lock().await;
+        m.clear();
+
+        let mut stream = request.into_inner();
+        while let Some(chunk) = stream.next().await {
+            let chunk = match chunk {
+                Ok(c) => c,
+                Err(e) => {
+                    return Ok(Response::new(GeneralResponse {
+                        success: false,
+                        error: Some(format!("model error: {}", e)),
+                    }))
+                }
+            };
+            m.extend(chunk.model);
+        }
+
+        Ok(Response::new(GeneralResponse {
+            success: true,
+            error: None,
+        }))
+    }
+
+    /// Builds the `Semantic` engine from the previously-uploaded bytes.
+    async fn init_model(&self, _: Request<()>) -> Result<Response<GeneralResponse>, Status> {
+        let model = self.model.lock().await;
+        let tokenizer_json = self.tokenizer_json.lock().await;
+
+        // `initialize` takes owned buffers, so the uploads stay cached and
+        // `init_model` can be called again later.
+        let sema = match Semantic::initialize(model.clone(), tokenizer_json.clone()).await {
+            Ok(s) => s,
+            Err(e) => {
+                return Ok(Response::new(GeneralResponse {
+                    success: false,
+                    // Fix: original message had a typo ("sma init failed").
+                    error: Some(format!("semantic init failed: {}", e)),
+                }))
+            }
+        };
+
+        // Scope the lock so it is released right after the swap.
+        {
+            let mut s = self.sema.lock().await;
+            s.replace(sema);
+        }
+
+        Ok(Response::new(GeneralResponse {
+            success: true,
+            error: None,
+        }))
+    }
+
+    /// Placeholder implementation: echoes a greeting. Real encoding via
+    /// `self.sema` is still TODO.
+    async fn encode(&self, request: Request<EncodeRequest>) -> Result<Response<EncodeReply>, Status> {
+        println!("Got a request from {:?}", request.remote_addr());
+
+        let reply = EncodeReply {
+            text: format!("Hello {}!", request.into_inner().text),
+        };
+
+        Ok(Response::new(reply))
+    }
+}
+
+#[tokio::main]
+async fn main() -> Result<(), Box<dyn std::error::Error>> {
+    let args = InferenceArgs::parse();
+    // Fix: propagate an address-parse failure instead of unwrapping on
+    // user-supplied input.
+    let addr = format!("[::1]:{}", args.port.unwrap_or_else(|| "50051".to_string())).parse()?;
+    let greeter = MyTokenizer::default();
+
+    println!("TokenizerServer listening on {}", addr);
+
+    Server::builder()
+        .add_service(TokenizerServer::new(greeter))
+        .serve(addr)
+        .await?;
+
+    Ok(())
+}