forked from huggingface/candle
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Sketch the yolo wasm example. (huggingface#546)
* Sketch the yolo wasm example. * Web ui. * Get the web ui to work. * UI tweaks. * More UI tweaks. * Use the natural width/height. * Add a link to the hf space in the readme.
- Loading branch information
1 parent
44420d8
commit 20ce3e9
Showing
13 changed files
with
1,260 additions
and
5 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
[package] | ||
name = "candle-wasm-example-yolo" | ||
version.workspace = true | ||
edition.workspace = true | ||
description.workspace = true | ||
repository.workspace = true | ||
keywords.workspace = true | ||
categories.workspace = true | ||
license.workspace = true | ||
|
||
[dependencies] | ||
candle = { path = "../../candle-core", version = "0.1.2", package = "candle-core" } | ||
candle-nn = { path = "../../candle-nn", version = "0.1.2" } | ||
num-traits = { workspace = true } | ||
serde = { workspace = true } | ||
serde_json = { workspace = true } | ||
image = { workspace = true } | ||
|
||
# App crates. | ||
anyhow = { workspace = true } | ||
byteorder = { workspace = true } | ||
log = { workspace = true } | ||
rand = { workspace = true } | ||
safetensors = { workspace = true } | ||
|
||
# Wasm specific crates. | ||
console_error_panic_hook = "0.1.7" | ||
getrandom = { version = "0.2", features = ["js"] } | ||
gloo = "0.8" | ||
js-sys = "0.3.64" | ||
wasm-bindgen = "0.2.87" | ||
wasm-bindgen-futures = "0.4.37" | ||
wasm-logger = "0.2" | ||
yew-agent = "0.2.0" | ||
yew = { version = "0.20.0", features = ["csr"] } | ||
|
||
[dependencies.web-sys] | ||
version = "0.3.64" | ||
features = [ | ||
'Blob', | ||
'CanvasRenderingContext2d', | ||
'Document', | ||
'Element', | ||
'HtmlElement', | ||
'HtmlCanvasElement', | ||
'HtmlImageElement', | ||
'ImageData', | ||
'Node', | ||
'Window', | ||
'Request', | ||
'RequestCache', | ||
'RequestInit', | ||
'RequestMode', | ||
'Response', | ||
'Performance', | ||
'TextMetrics', | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
<!DOCTYPE html> | ||
<html lang="en"> | ||
<head> | ||
<meta charset="utf-8" /> | ||
<title>Welcome to Candle!</title> | ||
|
||
<link data-trunk rel="copy-file" href="yolo.safetensors" /> | ||
<link data-trunk rel="copy-file" href="bike.jpeg" /> | ||
<link data-trunk rel="rust" href="Cargo.toml" data-bin="app" data-type="main" /> | ||
<link data-trunk rel="rust" href="Cargo.toml" data-bin="worker" data-type="worker" /> | ||
|
||
<link rel="stylesheet" href="https://fonts.googleapis.com/css?family=Roboto:300,300italic,700,700italic"> | ||
<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/normalize/8.0.1/normalize.css"> | ||
<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/milligram/1.4.1/milligram.css"> | ||
</head> | ||
<body></body> | ||
</html> |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,268 @@ | ||
use crate::console_log; | ||
use crate::worker::{ModelData, Worker, WorkerInput, WorkerOutput}; | ||
use wasm_bindgen::prelude::*; | ||
use wasm_bindgen_futures::JsFuture; | ||
use yew::{html, Component, Context, Html}; | ||
use yew_agent::{Bridge, Bridged}; | ||
|
||
async fn fetch_url(url: &str) -> Result<Vec<u8>, JsValue> { | ||
use web_sys::{Request, RequestCache, RequestInit, RequestMode, Response}; | ||
let window = web_sys::window().ok_or("window")?; | ||
let mut opts = RequestInit::new(); | ||
let opts = opts | ||
.method("GET") | ||
.mode(RequestMode::Cors) | ||
.cache(RequestCache::NoCache); | ||
|
||
let request = Request::new_with_str_and_init(url, opts)?; | ||
|
||
let resp_value = JsFuture::from(window.fetch_with_request(&request)).await?; | ||
|
||
// `resp_value` is a `Response` object. | ||
assert!(resp_value.is_instance_of::<Response>()); | ||
let resp: Response = resp_value.dyn_into()?; | ||
let data = JsFuture::from(resp.blob()?).await?; | ||
let blob = web_sys::Blob::from(data); | ||
let array_buffer = JsFuture::from(blob.array_buffer()).await?; | ||
let data = js_sys::Uint8Array::new(&array_buffer).to_vec(); | ||
Ok(data) | ||
} | ||
|
||
pub enum Msg { | ||
Refresh, | ||
Run, | ||
UpdateStatus(String), | ||
SetModel(ModelData), | ||
WorkerInMsg(WorkerInput), | ||
WorkerOutMsg(Result<WorkerOutput, String>), | ||
} | ||
|
||
pub struct CurrentDecode { | ||
start_time: Option<f64>, | ||
} | ||
|
||
pub struct App { | ||
status: String, | ||
loaded: bool, | ||
generated: String, | ||
current_decode: Option<CurrentDecode>, | ||
worker: Box<dyn Bridge<Worker>>, | ||
} | ||
|
||
async fn model_data_load() -> Result<ModelData, JsValue> { | ||
let weights = fetch_url("yolo.safetensors").await?; | ||
console_log!("loaded weights {}", weights.len()); | ||
Ok(ModelData { weights }) | ||
} | ||
|
||
fn performance_now() -> Option<f64> { | ||
let window = web_sys::window()?; | ||
let performance = window.performance()?; | ||
Some(performance.now() / 1000.) | ||
} | ||
|
||
fn draw_bboxes(bboxes: Vec<Vec<crate::model::Bbox>>) -> Result<(), JsValue> { | ||
let document = web_sys::window().unwrap().document().unwrap(); | ||
let canvas = match document.get_element_by_id("canvas") { | ||
Some(canvas) => canvas, | ||
None => return Err("no canvas".into()), | ||
}; | ||
let canvas: web_sys::HtmlCanvasElement = canvas.dyn_into::<web_sys::HtmlCanvasElement>()?; | ||
|
||
let context = canvas | ||
.get_context("2d")? | ||
.ok_or("no 2d")? | ||
.dyn_into::<web_sys::CanvasRenderingContext2d>()?; | ||
|
||
let image_html_element = document.get_element_by_id("bike-img"); | ||
let image_html_element = match image_html_element { | ||
Some(data) => data, | ||
None => return Err("no bike-img".into()), | ||
}; | ||
let image_html_element = image_html_element.dyn_into::<web_sys::HtmlImageElement>()?; | ||
canvas.set_width(image_html_element.natural_width()); | ||
canvas.set_height(image_html_element.natural_height()); | ||
context.draw_image_with_html_image_element(&image_html_element, 0., 0.)?; | ||
context.set_stroke_style(&JsValue::from("#0dff9a")); | ||
for (class_index, bboxes_for_class) in bboxes.iter().enumerate() { | ||
for b in bboxes_for_class.iter() { | ||
let name = crate::coco_classes::NAMES[class_index]; | ||
context.stroke_rect( | ||
b.xmin as f64, | ||
b.ymin as f64, | ||
(b.xmax - b.xmin) as f64, | ||
(b.ymax - b.ymin) as f64, | ||
); | ||
if let Ok(metrics) = context.measure_text(name) { | ||
let width = metrics.width(); | ||
context.set_fill_style(&"#3c8566".into()); | ||
context.fill_rect(b.xmin as f64 - 2., b.ymin as f64 - 12., width + 4., 14.); | ||
context.set_fill_style(&"#e3fff3".into()); | ||
context.fill_text(name, b.xmin as f64, b.ymin as f64 - 2.)? | ||
} | ||
} | ||
} | ||
Ok(()) | ||
} | ||
|
||
impl Component for App { | ||
type Message = Msg; | ||
type Properties = (); | ||
|
||
fn create(ctx: &Context<Self>) -> Self { | ||
let status = "loading weights".to_string(); | ||
let cb = { | ||
let link = ctx.link().clone(); | ||
move |e| link.send_message(Self::Message::WorkerOutMsg(e)) | ||
}; | ||
let worker = Worker::bridge(std::rc::Rc::new(cb)); | ||
Self { | ||
status, | ||
generated: String::new(), | ||
current_decode: None, | ||
worker, | ||
loaded: false, | ||
} | ||
} | ||
|
||
fn rendered(&mut self, ctx: &Context<Self>, first_render: bool) { | ||
if first_render { | ||
ctx.link().send_future(async { | ||
match model_data_load().await { | ||
Err(err) => { | ||
let status = format!("{err:?}"); | ||
Msg::UpdateStatus(status) | ||
} | ||
Ok(model_data) => Msg::SetModel(model_data), | ||
} | ||
}); | ||
} | ||
} | ||
|
||
fn update(&mut self, ctx: &Context<Self>, msg: Self::Message) -> bool { | ||
match msg { | ||
Msg::SetModel(md) => { | ||
self.status = "weights loaded succesfully!".to_string(); | ||
self.loaded = true; | ||
console_log!("loaded weights"); | ||
self.worker.send(WorkerInput::ModelData(md)); | ||
true | ||
} | ||
Msg::Run => { | ||
if self.current_decode.is_some() { | ||
self.status = "already processing some image at the moment".to_string() | ||
} else { | ||
let start_time = performance_now(); | ||
self.current_decode = Some(CurrentDecode { start_time }); | ||
self.status = "processing...".to_string(); | ||
self.generated.clear(); | ||
ctx.link().send_future(async { | ||
match fetch_url("bike.jpeg").await { | ||
Err(err) => { | ||
let status = format!("{err:?}"); | ||
Msg::UpdateStatus(status) | ||
} | ||
Ok(image_data) => Msg::WorkerInMsg(WorkerInput::Run(image_data)), | ||
} | ||
}); | ||
} | ||
true | ||
} | ||
Msg::WorkerOutMsg(output) => { | ||
match output { | ||
Ok(WorkerOutput::WeightsLoaded) => self.status = "weights loaded!".to_string(), | ||
Ok(WorkerOutput::ProcessingDone(Err(err))) => { | ||
self.status = format!("error in worker process: {err}"); | ||
self.current_decode = None | ||
} | ||
Ok(WorkerOutput::ProcessingDone(Ok(bboxes))) => { | ||
let mut content = Vec::new(); | ||
for (class_index, bboxes_for_class) in bboxes.iter().enumerate() { | ||
for b in bboxes_for_class.iter() { | ||
content.push(format!( | ||
"bbox {}: xs {:.0}-{:.0} ys {:.0}-{:.0}", | ||
crate::coco_classes::NAMES[class_index], | ||
b.xmin, | ||
b.xmax, | ||
b.ymin, | ||
b.ymax | ||
)) | ||
} | ||
} | ||
self.generated = content.join("\n"); | ||
let dt = self.current_decode.as_ref().and_then(|current_decode| { | ||
current_decode.start_time.and_then(|start_time| { | ||
performance_now().map(|stop_time| stop_time - start_time) | ||
}) | ||
}); | ||
self.status = match dt { | ||
None => "processing succeeded!".to_string(), | ||
Some(dt) => format!("processing succeeded in {:.2}s", dt,), | ||
}; | ||
self.current_decode = None; | ||
if let Err(err) = draw_bboxes(bboxes) { | ||
self.status = format!("{err:?}") | ||
} | ||
} | ||
Err(err) => { | ||
self.status = format!("error in worker {err:?}"); | ||
} | ||
} | ||
true | ||
} | ||
Msg::WorkerInMsg(inp) => { | ||
self.worker.send(inp); | ||
true | ||
} | ||
Msg::UpdateStatus(status) => { | ||
self.status = status; | ||
true | ||
} | ||
Msg::Refresh => true, | ||
} | ||
} | ||
|
||
fn view(&self, ctx: &Context<Self>) -> Html { | ||
html! { | ||
<div style="margin: 2%;"> | ||
<div><p>{"Running an object detection model in the browser using rust/wasm with "} | ||
<a href="https://github.com/huggingface/candle" target="_blank">{"candle!"}</a> | ||
</p> | ||
<p>{"Once the weights have loaded, click on the run button to process an image."}</p> | ||
<p><img id="bike-img" src="bike.jpeg"/></p> | ||
<p>{"Source: "}<a href="https://commons.wikimedia.org/wiki/File:V%C3%A9lo_parade_-_V%C3%A9lorution_-_bike_critical_mass.JPG">{"wikimedia"}</a></p> | ||
</div> | ||
{ | ||
if self.loaded{ | ||
html!(<button class="button" onclick={ctx.link().callback(move |_| Msg::Run)}> { "run" }</button>) | ||
}else{ | ||
html! { <progress id="progress-bar" aria-label="Loading weights..."></progress> } | ||
} | ||
} | ||
<br/ > | ||
<h3> | ||
{&self.status} | ||
</h3> | ||
{ | ||
if self.current_decode.is_some() { | ||
html! { <progress id="progress-bar" aria-label="generating…"></progress> } | ||
} else { | ||
html! {} | ||
} | ||
} | ||
<div> | ||
<canvas id="canvas" height="150" width="150"></canvas> | ||
</div> | ||
<blockquote> | ||
<p> { self.generated.chars().map(|c| | ||
if c == '\r' || c == '\n' { | ||
html! { <br/> } | ||
} else { | ||
html! { {c} } | ||
}).collect::<Html>() | ||
} </p> | ||
</blockquote> | ||
</div> | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
fn main() { | ||
wasm_logger::init(wasm_logger::Config::new(log::Level::Trace)); | ||
console_error_panic_hook::set_once(); | ||
yew::Renderer::<candle_wasm_example_yolo::App>::new().render(); | ||
} |
Oops, something went wrong.