-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathmain.rs
67 lines (55 loc) · 1.95 KB
/
main.rs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
use axum::{handler::post, Router, Json, AddExtensionLayer, extract::Extension};
use serde::{Serialize, Deserialize};
use serde_json::{json, Value};
use std::net::SocketAddr;
use std::sync::Arc;
use tokio::sync::Mutex;
use tch::nn::ModuleT;
use tch::vision::{resnet, imagenet};
extern crate tch;
extern crate base64;
extern crate image;
/// Shared application state handed to the request handler through an
/// `Extension` layer.
struct DnnModel {
    /// The network, stored as a boxed `ModuleT` trait object behind an async
    /// mutex so that a handler gets exclusive access while it runs inference.
    net: Mutex<Box<dyn ModuleT>>,
}
/// Entry point: builds a ResNet-18, loads its weights from `/resnet18.ot`,
/// and serves the classification endpoint (`POST /`) on `0.0.0.0:3000`.
///
/// The process exits with status 1 if the weights cannot be loaded or if the
/// server terminates with an error.
#[tokio::main]
async fn main() {
    let weights = std::path::Path::new("/resnet18.ot");
    let mut vs = tch::nn::VarStore::new(tch::Device::Cpu);

    // The network is constructed against `vs.root()` *before* `load` is
    // called, matching the original statement order.
    let net: Mutex<Box<dyn ModuleT>> = Mutex::new(Box::new(resnet::resnet18(
        &vs.root(),
        imagenet::CLASS_COUNT,
    )));

    // Fail fast: serving requests with a network whose weights were never
    // loaded would return meaningless predictions without any visible error.
    if let Err(err) = vs.load(weights) {
        eprintln!(
            "failed to load model weights from {}: {}",
            weights.display(),
            err
        );
        std::process::exit(1);
    }

    let state = Arc::new(DnnModel { net });
    let app = Router::new()
        .route("/", post(proc))
        .layer(AddExtensionLayer::new(state));

    let addr = SocketAddr::from(([0, 0, 0, 0], 3000));
    println!("listening on {}", addr);

    if let Err(err) = axum::Server::bind(&addr)
        .serve(app.into_make_service())
        .await
    {
        eprintln!("server error: {}", err);
        std::process::exit(1);
    }
}
/// JSON body accepted by the `POST /` handler.
#[derive(Deserialize)]
struct RequestJson {
    /// The image to classify, as a base64-encoded string of the raw image
    /// file bytes.
    img: String,
}
/// Serializable response shape: a list of strings under the `result` key.
///
/// NOTE(review): nothing in the visible code constructs this type; the
/// handler builds its response with `json!` instead. Confirm whether it is
/// still needed.
#[derive(Serialize)]
struct ResponseJson {
    /// One entry per reported prediction.
    result: Vec<String>,
}
async fn proc(Json(payload): Json<RequestJson>, Extension(state): Extension<Arc<DnnModel>>) -> Json<Value> {
let net = state.net.lock().await;
let img_buffer = base64::decode(&payload.img).unwrap();
let img = image::load_from_memory(&img_buffer.as_slice()).unwrap();
// to use load_image_and_resize224_from_memory next version tch-rs
let _ = img.save("/tmp.jpeg");
let img_tensor = imagenet::load_image_and_resize224("/tmp.jpeg").unwrap();
let output = net
.forward_t(&img_tensor.unsqueeze(0), false)
.softmax(-1, tch::Kind::Float);
let mut result = Vec::new();
for (probability, class) in imagenet::top(&output, 5).iter() {
result.push(format!("{:50} {:5.2}%", class, 100.0 * probability).to_string());
}
Json(json!({ "result": result }))
}