Skip to content

Commit

Permalink
Sketch the yolo wasm example. (huggingface#546)
Browse files Browse the repository at this point in the history
* Sketch the yolo wasm example.

* Web ui.

* Get the web ui to work.

* UI tweaks.

* More UI tweaks.

* Use the natural width/height.

* Add a link to the hf space in the readme.
  • Loading branch information
LaurentMazare authored Aug 22, 2023
1 parent 44420d8 commit 20ce3e9
Show file tree
Hide file tree
Showing 13 changed files with 1,260 additions and 5 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ flamegraph.svg
trace-*.json

candle-wasm-examples/*/*.bin
candle-wasm-examples/*/*.jpeg
candle-wasm-examples/*/*.wav
candle-wasm-examples/*/*.safetensors
candle-wasm-examples/*/package-lock.json
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ members = [
"candle-transformers",
"candle-wasm-examples/llama2-c",
"candle-wasm-examples/whisper",
"candle-wasm-examples/yolo",
]
exclude = [
"candle-flash-attn",
Expand Down
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
Candle is a minimalist ML framework for Rust with a focus on performance (including GPU support)
and ease of use. Try our online demos:
[whisper](https://huggingface.co/spaces/lmz/candle-whisper),
[LLaMA2](https://huggingface.co/spaces/lmz/candle-llama2).
[LLaMA2](https://huggingface.co/spaces/lmz/candle-llama2),
[yolo](https://huggingface.co/spaces/lmz/candle-yolo).

```rust
let a = Tensor::randn(0f32, 1., (2, 3), &Device::Cpu)?;
Expand Down
4 changes: 0 additions & 4 deletions candle-examples/examples/yolo-v8/main.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
#![allow(dead_code)]

#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

Expand Down Expand Up @@ -156,7 +154,6 @@ struct C2f {
cv1: ConvBlock,
cv2: ConvBlock,
bottleneck: Vec<Bottleneck>,
c: usize,
}

impl C2f {
Expand All @@ -173,7 +170,6 @@ impl C2f {
cv1,
cv2,
bottleneck,
c,
})
}
}
Expand Down
57 changes: 57 additions & 0 deletions candle-wasm-examples/yolo/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
[package]
name = "candle-wasm-example-yolo"
version.workspace = true
edition.workspace = true
description.workspace = true
repository.workspace = true
keywords.workspace = true
categories.workspace = true
license.workspace = true

[dependencies]
candle = { path = "../../candle-core", version = "0.1.2", package = "candle-core" }
candle-nn = { path = "../../candle-nn", version = "0.1.2" }
num-traits = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
image = { workspace = true }

# App crates.
anyhow = { workspace = true }
byteorder = { workspace = true }
log = { workspace = true }
rand = { workspace = true }
safetensors = { workspace = true }

# Wasm specific crates.
console_error_panic_hook = "0.1.7"
getrandom = { version = "0.2", features = ["js"] }
gloo = "0.8"
js-sys = "0.3.64"
wasm-bindgen = "0.2.87"
wasm-bindgen-futures = "0.4.37"
wasm-logger = "0.2"
yew-agent = "0.2.0"
yew = { version = "0.20.0", features = ["csr"] }

[dependencies.web-sys]
version = "0.3.64"
features = [
'Blob',
'CanvasRenderingContext2d',
'Document',
'Element',
'HtmlElement',
'HtmlCanvasElement',
'HtmlImageElement',
'ImageData',
'Node',
'Window',
'Request',
'RequestCache',
'RequestInit',
'RequestMode',
'Response',
'Performance',
'TextMetrics',
]
17 changes: 17 additions & 0 deletions candle-wasm-examples/yolo/index.html
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="utf-8" />
<title>Welcome to Candle!</title>

<link data-trunk rel="copy-file" href="yolo.safetensors" />
<link data-trunk rel="copy-file" href="bike.jpeg" />
<link data-trunk rel="rust" href="Cargo.toml" data-bin="app" data-type="main" />
<link data-trunk rel="rust" href="Cargo.toml" data-bin="worker" data-type="worker" />

<link rel="stylesheet" href="https://fonts.googleapis.com/css?family=Roboto:300,300italic,700,700italic">
<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/normalize/8.0.1/normalize.css">
<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/milligram/1.4.1/milligram.css">
</head>
<body></body>
</html>
268 changes: 268 additions & 0 deletions candle-wasm-examples/yolo/src/app.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,268 @@
use crate::console_log;
use crate::worker::{ModelData, Worker, WorkerInput, WorkerOutput};
use wasm_bindgen::prelude::*;
use wasm_bindgen_futures::JsFuture;
use yew::{html, Component, Context, Html};
use yew_agent::{Bridge, Bridged};

async fn fetch_url(url: &str) -> Result<Vec<u8>, JsValue> {
use web_sys::{Request, RequestCache, RequestInit, RequestMode, Response};
let window = web_sys::window().ok_or("window")?;
let mut opts = RequestInit::new();
let opts = opts
.method("GET")
.mode(RequestMode::Cors)
.cache(RequestCache::NoCache);

let request = Request::new_with_str_and_init(url, opts)?;

let resp_value = JsFuture::from(window.fetch_with_request(&request)).await?;

// `resp_value` is a `Response` object.
assert!(resp_value.is_instance_of::<Response>());
let resp: Response = resp_value.dyn_into()?;
let data = JsFuture::from(resp.blob()?).await?;
let blob = web_sys::Blob::from(data);
let array_buffer = JsFuture::from(blob.array_buffer()).await?;
let data = js_sys::Uint8Array::new(&array_buffer).to_vec();
Ok(data)
}

pub enum Msg {
Refresh,
Run,
UpdateStatus(String),
SetModel(ModelData),
WorkerInMsg(WorkerInput),
WorkerOutMsg(Result<WorkerOutput, String>),
}

pub struct CurrentDecode {
start_time: Option<f64>,
}

pub struct App {
status: String,
loaded: bool,
generated: String,
current_decode: Option<CurrentDecode>,
worker: Box<dyn Bridge<Worker>>,
}

async fn model_data_load() -> Result<ModelData, JsValue> {
let weights = fetch_url("yolo.safetensors").await?;
console_log!("loaded weights {}", weights.len());
Ok(ModelData { weights })
}

fn performance_now() -> Option<f64> {
let window = web_sys::window()?;
let performance = window.performance()?;
Some(performance.now() / 1000.)
}

fn draw_bboxes(bboxes: Vec<Vec<crate::model::Bbox>>) -> Result<(), JsValue> {
let document = web_sys::window().unwrap().document().unwrap();
let canvas = match document.get_element_by_id("canvas") {
Some(canvas) => canvas,
None => return Err("no canvas".into()),
};
let canvas: web_sys::HtmlCanvasElement = canvas.dyn_into::<web_sys::HtmlCanvasElement>()?;

let context = canvas
.get_context("2d")?
.ok_or("no 2d")?
.dyn_into::<web_sys::CanvasRenderingContext2d>()?;

let image_html_element = document.get_element_by_id("bike-img");
let image_html_element = match image_html_element {
Some(data) => data,
None => return Err("no bike-img".into()),
};
let image_html_element = image_html_element.dyn_into::<web_sys::HtmlImageElement>()?;
canvas.set_width(image_html_element.natural_width());
canvas.set_height(image_html_element.natural_height());
context.draw_image_with_html_image_element(&image_html_element, 0., 0.)?;
context.set_stroke_style(&JsValue::from("#0dff9a"));
for (class_index, bboxes_for_class) in bboxes.iter().enumerate() {
for b in bboxes_for_class.iter() {
let name = crate::coco_classes::NAMES[class_index];
context.stroke_rect(
b.xmin as f64,
b.ymin as f64,
(b.xmax - b.xmin) as f64,
(b.ymax - b.ymin) as f64,
);
if let Ok(metrics) = context.measure_text(name) {
let width = metrics.width();
context.set_fill_style(&"#3c8566".into());
context.fill_rect(b.xmin as f64 - 2., b.ymin as f64 - 12., width + 4., 14.);
context.set_fill_style(&"#e3fff3".into());
context.fill_text(name, b.xmin as f64, b.ymin as f64 - 2.)?
}
}
}
Ok(())
}

impl Component for App {
type Message = Msg;
type Properties = ();

fn create(ctx: &Context<Self>) -> Self {
let status = "loading weights".to_string();
let cb = {
let link = ctx.link().clone();
move |e| link.send_message(Self::Message::WorkerOutMsg(e))
};
let worker = Worker::bridge(std::rc::Rc::new(cb));
Self {
status,
generated: String::new(),
current_decode: None,
worker,
loaded: false,
}
}

fn rendered(&mut self, ctx: &Context<Self>, first_render: bool) {
if first_render {
ctx.link().send_future(async {
match model_data_load().await {
Err(err) => {
let status = format!("{err:?}");
Msg::UpdateStatus(status)
}
Ok(model_data) => Msg::SetModel(model_data),
}
});
}
}

fn update(&mut self, ctx: &Context<Self>, msg: Self::Message) -> bool {
match msg {
Msg::SetModel(md) => {
self.status = "weights loaded succesfully!".to_string();
self.loaded = true;
console_log!("loaded weights");
self.worker.send(WorkerInput::ModelData(md));
true
}
Msg::Run => {
if self.current_decode.is_some() {
self.status = "already processing some image at the moment".to_string()
} else {
let start_time = performance_now();
self.current_decode = Some(CurrentDecode { start_time });
self.status = "processing...".to_string();
self.generated.clear();
ctx.link().send_future(async {
match fetch_url("bike.jpeg").await {
Err(err) => {
let status = format!("{err:?}");
Msg::UpdateStatus(status)
}
Ok(image_data) => Msg::WorkerInMsg(WorkerInput::Run(image_data)),
}
});
}
true
}
Msg::WorkerOutMsg(output) => {
match output {
Ok(WorkerOutput::WeightsLoaded) => self.status = "weights loaded!".to_string(),
Ok(WorkerOutput::ProcessingDone(Err(err))) => {
self.status = format!("error in worker process: {err}");
self.current_decode = None
}
Ok(WorkerOutput::ProcessingDone(Ok(bboxes))) => {
let mut content = Vec::new();
for (class_index, bboxes_for_class) in bboxes.iter().enumerate() {
for b in bboxes_for_class.iter() {
content.push(format!(
"bbox {}: xs {:.0}-{:.0} ys {:.0}-{:.0}",
crate::coco_classes::NAMES[class_index],
b.xmin,
b.xmax,
b.ymin,
b.ymax
))
}
}
self.generated = content.join("\n");
let dt = self.current_decode.as_ref().and_then(|current_decode| {
current_decode.start_time.and_then(|start_time| {
performance_now().map(|stop_time| stop_time - start_time)
})
});
self.status = match dt {
None => "processing succeeded!".to_string(),
Some(dt) => format!("processing succeeded in {:.2}s", dt,),
};
self.current_decode = None;
if let Err(err) = draw_bboxes(bboxes) {
self.status = format!("{err:?}")
}
}
Err(err) => {
self.status = format!("error in worker {err:?}");
}
}
true
}
Msg::WorkerInMsg(inp) => {
self.worker.send(inp);
true
}
Msg::UpdateStatus(status) => {
self.status = status;
true
}
Msg::Refresh => true,
}
}

fn view(&self, ctx: &Context<Self>) -> Html {
html! {
<div style="margin: 2%;">
<div><p>{"Running an object detection model in the browser using rust/wasm with "}
<a href="https://github.com/huggingface/candle" target="_blank">{"candle!"}</a>
</p>
<p>{"Once the weights have loaded, click on the run button to process an image."}</p>
<p><img id="bike-img" src="bike.jpeg"/></p>
<p>{"Source: "}<a href="https://commons.wikimedia.org/wiki/File:V%C3%A9lo_parade_-_V%C3%A9lorution_-_bike_critical_mass.JPG">{"wikimedia"}</a></p>
</div>
{
if self.loaded{
html!(<button class="button" onclick={ctx.link().callback(move |_| Msg::Run)}> { "run" }</button>)
}else{
html! { <progress id="progress-bar" aria-label="Loading weights..."></progress> }
}
}
<br/ >
<h3>
{&self.status}
</h3>
{
if self.current_decode.is_some() {
html! { <progress id="progress-bar" aria-label="generating…"></progress> }
} else {
html! {}
}
}
<div>
<canvas id="canvas" height="150" width="150"></canvas>
</div>
<blockquote>
<p> { self.generated.chars().map(|c|
if c == '\r' || c == '\n' {
html! { <br/> }
} else {
html! { {c} }
}).collect::<Html>()
} </p>
</blockquote>
</div>
}
}
}
5 changes: 5 additions & 0 deletions candle-wasm-examples/yolo/src/bin/app.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
fn main() {
wasm_logger::init(wasm_logger::Config::new(log::Level::Trace));
console_error_panic_hook::set_once();
yew::Renderer::<candle_wasm_example_yolo::App>::new().render();
}
Loading

0 comments on commit 20ce3e9

Please sign in to comment.