diff --git a/.gitignore b/.gitignore index 1585ed119b..7269ec8aa4 100644 --- a/.gitignore +++ b/.gitignore @@ -26,6 +26,7 @@ flamegraph.svg trace-*.json candle-wasm-examples/*/*.bin +candle-wasm-examples/*/*.jpeg candle-wasm-examples/*/*.wav candle-wasm-examples/*/*.safetensors candle-wasm-examples/*/package-lock.json diff --git a/Cargo.toml b/Cargo.toml index aa606534b6..7957c03841 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -8,6 +8,7 @@ members = [ "candle-transformers", "candle-wasm-examples/llama2-c", "candle-wasm-examples/whisper", + "candle-wasm-examples/yolo", ] exclude = [ "candle-flash-attn", diff --git a/README.md b/README.md index d2a54120a5..76e27a4002 100644 --- a/README.md +++ b/README.md @@ -7,7 +7,8 @@ Candle is a minimalist ML framework for Rust with a focus on performance (including GPU support) and ease of use. Try our online demos: [whisper](https://huggingface.co/spaces/lmz/candle-whisper), -[LLaMA2](https://huggingface.co/spaces/lmz/candle-llama2). +[LLaMA2](https://huggingface.co/spaces/lmz/candle-llama2), +[yolo](https://huggingface.co/spaces/lmz/candle-yolo). ```rust let a = Tensor::randn(0f32, 1., (2, 3), &Device::Cpu)?; diff --git a/candle-examples/examples/yolo-v8/main.rs b/candle-examples/examples/yolo-v8/main.rs index 45ff3d2d2b..5db642e0c2 100644 --- a/candle-examples/examples/yolo-v8/main.rs +++ b/candle-examples/examples/yolo-v8/main.rs @@ -1,5 +1,3 @@ -#![allow(dead_code)] - #[cfg(feature = "mkl")] extern crate intel_mkl_src; @@ -156,7 +154,6 @@ struct C2f { cv1: ConvBlock, cv2: ConvBlock, bottleneck: Vec, - c: usize, } impl C2f { @@ -173,7 +170,6 @@ impl C2f { cv1, cv2, bottleneck, - c, }) } } diff --git a/candle-wasm-examples/yolo/Cargo.toml b/candle-wasm-examples/yolo/Cargo.toml new file mode 100644 index 0000000000..ef9498ee4a --- /dev/null +++ b/candle-wasm-examples/yolo/Cargo.toml @@ -0,0 +1,57 @@ +[package] +name = "candle-wasm-example-yolo" +version.workspace = true +edition.workspace = true +description.workspace = true +repository.workspace = true +keywords.workspace = true +categories.workspace = true +license.workspace = true + +[dependencies] +candle = { path = "../../candle-core", version = "0.1.2", package = "candle-core" } +candle-nn = { path = "../../candle-nn", version = "0.1.2" } +num-traits = { workspace = true } +serde = { workspace = true } +serde_json = { workspace = true } +image = { workspace = true } + +# App crates. +anyhow = { workspace = true } +byteorder = { workspace = true } +log = { workspace = true } +rand = { workspace = true } +safetensors = { workspace = true } + +# Wasm specific crates. +console_error_panic_hook = "0.1.7" +getrandom = { version = "0.2", features = ["js"] } +gloo = "0.8" +js-sys = "0.3.64" +wasm-bindgen = "0.2.87" +wasm-bindgen-futures = "0.4.37" +wasm-logger = "0.2" +yew-agent = "0.2.0" +yew = { version = "0.20.0", features = ["csr"] } + +[dependencies.web-sys] +version = "0.3.64" +features = [ + 'Blob', + 'CanvasRenderingContext2d', + 'Document', + 'Element', + 'HtmlElement', + 'HtmlCanvasElement', + 'HtmlImageElement', + 'ImageData', + 'Node', + 'Window', + 'Request', + 'RequestCache', + 'RequestInit', + 'RequestMode', + 'Response', + 'Performance', + 'TextMetrics', +] diff --git a/candle-wasm-examples/yolo/index.html b/candle-wasm-examples/yolo/index.html new file mode 100644 index 0000000000..c64051ee9b --- /dev/null +++ b/candle-wasm-examples/yolo/index.html @@ -0,0 +1,17 @@ + + + + + Welcome to Candle! + + + + + + + + + + + + diff --git a/candle-wasm-examples/yolo/src/app.rs b/candle-wasm-examples/yolo/src/app.rs new file mode 100644 index 0000000000..bd999b6c9e --- /dev/null +++ b/candle-wasm-examples/yolo/src/app.rs @@ -0,0 +1,268 @@ +use crate::console_log; +use crate::worker::{ModelData, Worker, WorkerInput, WorkerOutput}; +use wasm_bindgen::prelude::*; +use wasm_bindgen_futures::JsFuture; +use yew::{html, Component, Context, Html}; +use yew_agent::{Bridge, Bridged}; + +async fn fetch_url(url: &str) -> Result, JsValue> { + use web_sys::{Request, RequestCache, RequestInit, RequestMode, Response}; + let window = web_sys::window().ok_or("window")?; + let mut opts = RequestInit::new(); + let opts = opts + .method("GET") + .mode(RequestMode::Cors) + .cache(RequestCache::NoCache); + + let request = Request::new_with_str_and_init(url, opts)?; + + let resp_value = JsFuture::from(window.fetch_with_request(&request)).await?; + + // `resp_value` is a `Response` object. + assert!(resp_value.is_instance_of::()); + let resp: Response = resp_value.dyn_into()?; + let data = JsFuture::from(resp.blob()?).await?; + let blob = web_sys::Blob::from(data); + let array_buffer = JsFuture::from(blob.array_buffer()).await?; + let data = js_sys::Uint8Array::new(&array_buffer).to_vec(); + Ok(data) +} + +pub enum Msg { + Refresh, + Run, + UpdateStatus(String), + SetModel(ModelData), + WorkerInMsg(WorkerInput), + WorkerOutMsg(Result), +} + +pub struct CurrentDecode { + start_time: Option, +} + +pub struct App { + status: String, + loaded: bool, + generated: String, + current_decode: Option, + worker: Box>, +} + +async fn model_data_load() -> Result { + let weights = fetch_url("yolo.safetensors").await?; + console_log!("loaded weights {}", weights.len()); + Ok(ModelData { weights }) +} + +fn performance_now() -> Option { + let window = web_sys::window()?; + let performance = window.performance()?; + Some(performance.now() / 1000.) +} + +fn draw_bboxes(bboxes: Vec>) -> Result<(), JsValue> { + let document = web_sys::window().unwrap().document().unwrap(); + let canvas = match document.get_element_by_id("canvas") { + Some(canvas) => canvas, + None => return Err("no canvas".into()), + }; + let canvas: web_sys::HtmlCanvasElement = canvas.dyn_into::()?; + + let context = canvas + .get_context("2d")? + .ok_or("no 2d")? + .dyn_into::()?; + + let image_html_element = document.get_element_by_id("bike-img"); + let image_html_element = match image_html_element { + Some(data) => data, + None => return Err("no bike-img".into()), + }; + let image_html_element = image_html_element.dyn_into::()?; + canvas.set_width(image_html_element.natural_width()); + canvas.set_height(image_html_element.natural_height()); + context.draw_image_with_html_image_element(&image_html_element, 0., 0.)?; + context.set_stroke_style(&JsValue::from("#0dff9a")); + for (class_index, bboxes_for_class) in bboxes.iter().enumerate() { + for b in bboxes_for_class.iter() { + let name = crate::coco_classes::NAMES[class_index]; + context.stroke_rect( + b.xmin as f64, + b.ymin as f64, + (b.xmax - b.xmin) as f64, + (b.ymax - b.ymin) as f64, + ); + if let Ok(metrics) = context.measure_text(name) { + let width = metrics.width(); + context.set_fill_style(&"#3c8566".into()); + context.fill_rect(b.xmin as f64 - 2., b.ymin as f64 - 12., width + 4., 14.); + context.set_fill_style(&"#e3fff3".into()); + context.fill_text(name, b.xmin as f64, b.ymin as f64 - 2.)? + } + } + } + Ok(()) +} + +impl Component for App { + type Message = Msg; + type Properties = (); + + fn create(ctx: &Context) -> Self { + let status = "loading weights".to_string(); + let cb = { + let link = ctx.link().clone(); + move |e| link.send_message(Self::Message::WorkerOutMsg(e)) + }; + let worker = Worker::bridge(std::rc::Rc::new(cb)); + Self { + status, + generated: String::new(), + current_decode: None, + worker, + loaded: false, + } + } + + fn rendered(&mut self, ctx: &Context, first_render: bool) { + if first_render { + ctx.link().send_future(async { + match model_data_load().await { + Err(err) => { + let status = format!("{err:?}"); + Msg::UpdateStatus(status) + } + Ok(model_data) => Msg::SetModel(model_data), + } + }); + } + } + + fn update(&mut self, ctx: &Context, msg: Self::Message) -> bool { + match msg { + Msg::SetModel(md) => { + self.status = "weights loaded succesfully!".to_string(); + self.loaded = true; + console_log!("loaded weights"); + self.worker.send(WorkerInput::ModelData(md)); + true + } + Msg::Run => { + if self.current_decode.is_some() { + self.status = "already processing some image at the moment".to_string() + } else { + let start_time = performance_now(); + self.current_decode = Some(CurrentDecode { start_time }); + self.status = "processing...".to_string(); + self.generated.clear(); + ctx.link().send_future(async { + match fetch_url("bike.jpeg").await { + Err(err) => { + let status = format!("{err:?}"); + Msg::UpdateStatus(status) + } + Ok(image_data) => Msg::WorkerInMsg(WorkerInput::Run(image_data)), + } + }); + } + true + } + Msg::WorkerOutMsg(output) => { + match output { + Ok(WorkerOutput::WeightsLoaded) => self.status = "weights loaded!".to_string(), + Ok(WorkerOutput::ProcessingDone(Err(err))) => { + self.status = format!("error in worker process: {err}"); + self.current_decode = None + } + Ok(WorkerOutput::ProcessingDone(Ok(bboxes))) => { + let mut content = Vec::new(); + for (class_index, bboxes_for_class) in bboxes.iter().enumerate() { + for b in bboxes_for_class.iter() { + content.push(format!( + "bbox {}: xs {:.0}-{:.0} ys {:.0}-{:.0}", + crate::coco_classes::NAMES[class_index], + b.xmin, + b.xmax, + b.ymin, + b.ymax + )) + } + } + self.generated = content.join("\n"); + let dt = self.current_decode.as_ref().and_then(|current_decode| { + current_decode.start_time.and_then(|start_time| { + performance_now().map(|stop_time| stop_time - start_time) + }) + }); + self.status = match dt { + None => "processing succeeded!".to_string(), + Some(dt) => format!("processing succeeded in {:.2}s", dt,), + }; + self.current_decode = None; + if let Err(err) = draw_bboxes(bboxes) { + self.status = format!("{err:?}") + } + } + Err(err) => { + self.status = format!("error in worker {err:?}"); + } + } + true + } + Msg::WorkerInMsg(inp) => { + self.worker.send(inp); + true + } + Msg::UpdateStatus(status) => { + self.status = status; + true + } + Msg::Refresh => true, + } + } + + fn view(&self, ctx: &Context) -> Html { + html! { +
+

{"Running an object detection model in the browser using rust/wasm with "} + {"candle!"} +

+

{"Once the weights have loaded, click on the run button to process an image."}

+

+

{"Source: "}{"wikimedia"}

+
+ { + if self.loaded{ + html!() + }else{ + html! { } + } + } +
+

+ {&self.status} +

+ { + if self.current_decode.is_some() { + html! { } + } else { + html! {} + } + } +
+ +
+
+

{ self.generated.chars().map(|c| + if c == '\r' || c == '\n' { + html! {
} + } else { + html! { {c} } + }).collect::() + }

+
+
+ } + } +} diff --git a/candle-wasm-examples/yolo/src/bin/app.rs b/candle-wasm-examples/yolo/src/bin/app.rs new file mode 100644 index 0000000000..fbfffd8040 --- /dev/null +++ b/candle-wasm-examples/yolo/src/bin/app.rs @@ -0,0 +1,5 @@ +fn main() { + wasm_logger::init(wasm_logger::Config::new(log::Level::Trace)); + console_error_panic_hook::set_once(); + yew::Renderer::::new().render(); +} diff --git a/candle-wasm-examples/yolo/src/bin/worker.rs b/candle-wasm-examples/yolo/src/bin/worker.rs new file mode 100644 index 0000000000..d1613e6827 --- /dev/null +++ b/candle-wasm-examples/yolo/src/bin/worker.rs @@ -0,0 +1,5 @@ +use yew_agent::PublicWorker; +fn main() { + console_error_panic_hook::set_once(); + candle_wasm_example_yolo::Worker::register(); +} diff --git a/candle-wasm-examples/yolo/src/coco_classes.rs b/candle-wasm-examples/yolo/src/coco_classes.rs new file mode 100644 index 0000000000..0075352492 --- /dev/null +++ b/candle-wasm-examples/yolo/src/coco_classes.rs @@ -0,0 +1,82 @@ +pub const NAMES: [&str; 80] = [ + "person", + "bicycle", + "car", + "motorbike", + "aeroplane", + "bus", + "train", + "truck", + "boat", + "traffic light", + "fire hydrant", + "stop sign", + "parking meter", + "bench", + "bird", + "cat", + "dog", + "horse", + "sheep", + "cow", + "elephant", + "bear", + "zebra", + "giraffe", + "backpack", + "umbrella", + "handbag", + "tie", + "suitcase", + "frisbee", + "skis", + "snowboard", + "sports ball", + "kite", + "baseball bat", + "baseball glove", + "skateboard", + "surfboard", + "tennis racket", + "bottle", + "wine glass", + "cup", + "fork", + "knife", + "spoon", + "bowl", + "banana", + "apple", + "sandwich", + "orange", + "broccoli", + "carrot", + "hot dog", + "pizza", + "donut", + "cake", + "chair", + "sofa", + "pottedplant", + "bed", + "diningtable", + "toilet", + "tvmonitor", + "laptop", + "mouse", + "remote", + "keyboard", + "cell phone", + "microwave", + "oven", + "toaster", + "sink", + "refrigerator", + "book", + "clock", + "vase", + "scissors", + "teddy bear", + "hair drier", + "toothbrush", +]; diff --git a/candle-wasm-examples/yolo/src/lib.rs b/candle-wasm-examples/yolo/src/lib.rs new file mode 100644 index 0000000000..76af1d631c --- /dev/null +++ b/candle-wasm-examples/yolo/src/lib.rs @@ -0,0 +1,6 @@ +mod app; +mod coco_classes; +mod model; +mod worker; +pub use app::App; +pub use worker::Worker; diff --git a/candle-wasm-examples/yolo/src/model.rs b/candle-wasm-examples/yolo/src/model.rs new file mode 100644 index 0000000000..184045f0a8 --- /dev/null +++ b/candle-wasm-examples/yolo/src/model.rs @@ -0,0 +1,684 @@ +#![allow(dead_code)] +use candle::{DType, IndexOp, Result, Tensor, D}; +use candle_nn::{ + batch_norm, conv2d, conv2d_no_bias, BatchNorm, Conv2d, Conv2dConfig, Module, VarBuilder, +}; +use image::DynamicImage; + +const CONFIDENCE_THRESHOLD: f32 = 0.5; +const NMS_THRESHOLD: f32 = 0.4; + +// Model architecture from https://github.com/ultralytics/ultralytics/issues/189 +// https://github.com/tinygrad/tinygrad/blob/master/examples/yolov8.py + +#[derive(Clone, Copy, PartialEq, Debug)] +pub struct Multiples { + depth: f64, + width: f64, + ratio: f64, +} + +impl Multiples { + pub fn n() -> Self { + Self { + depth: 0.33, + width: 0.25, + ratio: 2.0, + } + } + pub fn s() -> Self { + Self { + depth: 0.33, + width: 0.50, + ratio: 2.0, + } + } + pub fn m() -> Self { + Self { + depth: 0.67, + width: 0.75, + ratio: 1.5, + } + } + pub fn l() -> Self { + Self { + depth: 1.00, + width: 1.00, + ratio: 1.0, + } + } + pub fn x() -> Self { + Self { + depth: 1.00, + width: 1.25, + ratio: 1.0, + } + } + + fn filters(&self) -> (usize, usize, usize) { + let f1 = (256. * self.width) as usize; + let f2 = (512. * self.width) as usize; + let f3 = (512. * self.width * self.ratio) as usize; + (f1, f2, f3) + } +} + +#[derive(Debug)] +struct Upsample { + scale_factor: usize, +} + +impl Upsample { + fn new(scale_factor: usize) -> Result { + Ok(Upsample { scale_factor }) + } +} + +impl Module for Upsample { + fn forward(&self, xs: &Tensor) -> candle::Result { + let (_b_size, _channels, h, w) = xs.dims4()?; + xs.upsample_nearest2d(self.scale_factor * h, self.scale_factor * w) + } +} + +#[derive(Debug)] +struct ConvBlock { + conv: Conv2d, + bn: BatchNorm, +} + +impl ConvBlock { + fn load( + vb: VarBuilder, + c1: usize, + c2: usize, + k: usize, + stride: usize, + padding: Option, + ) -> Result { + let padding = padding.unwrap_or(k / 2); + let cfg = Conv2dConfig { padding, stride }; + let conv = conv2d_no_bias(c1, c2, k, cfg, vb.pp("conv"))?; + let bn = batch_norm(c2, 1e-3, vb.pp("bn"))?; + Ok(Self { conv, bn }) + } +} + +impl Module for ConvBlock { + fn forward(&self, xs: &Tensor) -> Result { + let xs = self.conv.forward(xs)?; + let xs = self.bn.forward(&xs)?; + candle_nn::ops::silu(&xs) + } +} + +#[derive(Debug)] +struct Bottleneck { + cv1: ConvBlock, + cv2: ConvBlock, + residual: bool, +} + +impl Bottleneck { + fn load(vb: VarBuilder, c1: usize, c2: usize, shortcut: bool) -> Result { + let channel_factor = 1.; + let c_ = (c2 as f64 * channel_factor) as usize; + let cv1 = ConvBlock::load(vb.pp("cv1"), c1, c_, 3, 1, None)?; + let cv2 = ConvBlock::load(vb.pp("cv2"), c_, c2, 3, 1, None)?; + let residual = c1 == c2 && shortcut; + Ok(Self { cv1, cv2, residual }) + } +} + +impl Module for Bottleneck { + fn forward(&self, xs: &Tensor) -> Result { + let ys = self.cv2.forward(&self.cv1.forward(xs)?)?; + if self.residual { + xs + ys + } else { + Ok(ys) + } + } +} + +#[derive(Debug)] +struct C2f { + cv1: ConvBlock, + cv2: ConvBlock, + bottleneck: Vec, +} + +impl C2f { + fn load(vb: VarBuilder, c1: usize, c2: usize, n: usize, shortcut: bool) -> Result { + let c = (c2 as f64 * 0.5) as usize; + let cv1 = ConvBlock::load(vb.pp("cv1"), c1, 2 * c, 1, 1, None)?; + let cv2 = ConvBlock::load(vb.pp("cv2"), (2 + n) * c, c2, 1, 1, None)?; + let mut bottleneck = Vec::with_capacity(n); + for idx in 0..n { + let b = Bottleneck::load(vb.pp(&format!("bottleneck.{idx}")), c, c, shortcut)?; + bottleneck.push(b) + } + Ok(Self { + cv1, + cv2, + bottleneck, + }) + } +} + +impl Module for C2f { + fn forward(&self, xs: &Tensor) -> Result { + let ys = self.cv1.forward(xs)?; + let mut ys = ys.chunk(2, 1)?; + for m in self.bottleneck.iter() { + ys.push(m.forward(ys.last().unwrap())?) + } + let zs = Tensor::cat(ys.as_slice(), 1)?; + self.cv2.forward(&zs) + } +} + +#[derive(Debug)] +struct Sppf { + cv1: ConvBlock, + cv2: ConvBlock, + k: usize, +} + +impl Sppf { + fn load(vb: VarBuilder, c1: usize, c2: usize, k: usize) -> Result { + let c_ = c1 / 2; + let cv1 = ConvBlock::load(vb.pp("cv1"), c1, c_, 1, 1, None)?; + let cv2 = ConvBlock::load(vb.pp("cv2"), c_ * 4, c2, 1, 1, None)?; + Ok(Self { cv1, cv2, k }) + } +} + +impl Module for Sppf { + fn forward(&self, xs: &Tensor) -> Result { + let (_, _, _, _) = xs.dims4()?; + let xs = self.cv1.forward(xs)?; + let xs2 = xs + .pad_with_zeros(2, self.k / 2, self.k / 2)? + .pad_with_zeros(3, self.k / 2, self.k / 2)? + .max_pool2d((self.k, self.k), (1, 1))?; + let xs3 = xs2 + .pad_with_zeros(2, self.k / 2, self.k / 2)? + .pad_with_zeros(3, self.k / 2, self.k / 2)? + .max_pool2d((self.k, self.k), (1, 1))?; + let xs4 = xs3 + .pad_with_zeros(2, self.k / 2, self.k / 2)? + .pad_with_zeros(3, self.k / 2, self.k / 2)? + .max_pool2d((self.k, self.k), (1, 1))?; + self.cv2.forward(&Tensor::cat(&[&xs, &xs2, &xs3, &xs4], 1)?) + } +} + +#[derive(Debug)] +struct Dfl { + conv: Conv2d, + num_classes: usize, +} + +impl Dfl { + fn load(vb: VarBuilder, num_classes: usize) -> Result { + let conv = conv2d_no_bias(num_classes, 1, 1, Default::default(), vb.pp("conv"))?; + Ok(Self { conv, num_classes }) + } +} + +impl Module for Dfl { + fn forward(&self, xs: &Tensor) -> Result { + let (b_sz, _channels, anchors) = xs.dims3()?; + let xs = xs + .reshape((b_sz, 4, self.num_classes, anchors))? + .transpose(2, 1)?; + let xs = candle_nn::ops::softmax(&xs, 1)?; + self.conv.forward(&xs)?.reshape((b_sz, 4, anchors)) + } +} + +#[derive(Debug)] +struct DarkNet { + b1_0: ConvBlock, + b1_1: ConvBlock, + b2_0: C2f, + b2_1: ConvBlock, + b2_2: C2f, + b3_0: ConvBlock, + b3_1: C2f, + b4_0: ConvBlock, + b4_1: C2f, + b5: Sppf, +} + +impl DarkNet { + fn load(vb: VarBuilder, m: Multiples) -> Result { + let (w, r, d) = (m.width, m.ratio, m.depth); + let b1_0 = ConvBlock::load(vb.pp("b1.0"), 3, (64. * w) as usize, 3, 2, Some(1))?; + let b1_1 = ConvBlock::load( + vb.pp("b1.1"), + (64. * w) as usize, + (128. * w) as usize, + 3, + 2, + Some(1), + )?; + let b2_0 = C2f::load( + vb.pp("b2.0"), + (128. * w) as usize, + (128. * w) as usize, + (3. * d).round() as usize, + true, + )?; + let b2_1 = ConvBlock::load( + vb.pp("b2.1"), + (128. * w) as usize, + (256. * w) as usize, + 3, + 2, + Some(1), + )?; + let b2_2 = C2f::load( + vb.pp("b2.2"), + (256. * w) as usize, + (256. * w) as usize, + (6. * d).round() as usize, + true, + )?; + let b3_0 = ConvBlock::load( + vb.pp("b3.0"), + (256. * w) as usize, + (512. * w) as usize, + 3, + 2, + Some(1), + )?; + let b3_1 = C2f::load( + vb.pp("b3.1"), + (512. * w) as usize, + (512. * w) as usize, + (6. * d).round() as usize, + true, + )?; + let b4_0 = ConvBlock::load( + vb.pp("b4.0"), + (512. * w) as usize, + (512. * w * r) as usize, + 3, + 2, + Some(1), + )?; + let b4_1 = C2f::load( + vb.pp("b4.1"), + (512. * w * r) as usize, + (512. * w * r) as usize, + (3. * d).round() as usize, + true, + )?; + let b5 = Sppf::load( + vb.pp("b5.0"), + (512. * w * r) as usize, + (512. * w * r) as usize, + 5, + )?; + Ok(Self { + b1_0, + b1_1, + b2_0, + b2_1, + b2_2, + b3_0, + b3_1, + b4_0, + b4_1, + b5, + }) + } + + fn forward(&self, xs: &Tensor) -> Result<(Tensor, Tensor, Tensor)> { + let x1 = self.b1_1.forward(&self.b1_0.forward(xs)?)?; + let x2 = self + .b2_2 + .forward(&self.b2_1.forward(&self.b2_0.forward(&x1)?)?)?; + let x3 = self.b3_1.forward(&self.b3_0.forward(&x2)?)?; + let x4 = self.b4_1.forward(&self.b4_0.forward(&x3)?)?; + let x5 = self.b5.forward(&x4)?; + Ok((x2, x3, x5)) + } +} + +#[derive(Debug)] +struct YoloV8Neck { + up: Upsample, + n1: C2f, + n2: C2f, + n3: ConvBlock, + n4: C2f, + n5: ConvBlock, + n6: C2f, +} + +impl YoloV8Neck { + fn load(vb: VarBuilder, m: Multiples) -> Result { + let up = Upsample::new(2)?; + let (w, r, d) = (m.width, m.ratio, m.depth); + let n = (3. * d).round() as usize; + let n1 = C2f::load( + vb.pp("n1"), + (512. * w * (1. + r)) as usize, + (512. * w) as usize, + n, + false, + )?; + let n2 = C2f::load( + vb.pp("n2"), + (768. * w) as usize, + (256. * w) as usize, + n, + false, + )?; + let n3 = ConvBlock::load( + vb.pp("n3"), + (256. * w) as usize, + (256. * w) as usize, + 3, + 2, + Some(1), + )?; + let n4 = C2f::load( + vb.pp("n4"), + (768. * w) as usize, + (512. * w) as usize, + n, + false, + )?; + let n5 = ConvBlock::load( + vb.pp("n5"), + (512. * w) as usize, + (512. * w) as usize, + 3, + 2, + Some(1), + )?; + let n6 = C2f::load( + vb.pp("n6"), + (512. * w * (1. + r)) as usize, + (512. * w * r) as usize, + n, + false, + )?; + Ok(Self { + up, + n1, + n2, + n3, + n4, + n5, + n6, + }) + } + + fn forward(&self, p3: &Tensor, p4: &Tensor, p5: &Tensor) -> Result<(Tensor, Tensor, Tensor)> { + let x = self + .n1 + .forward(&Tensor::cat(&[&self.up.forward(p5)?, p4], 1)?)?; + let head_1 = self + .n2 + .forward(&Tensor::cat(&[&self.up.forward(&x)?, p3], 1)?)?; + let head_2 = self + .n4 + .forward(&Tensor::cat(&[&self.n3.forward(&head_1)?, &x], 1)?)?; + let head_3 = self + .n6 + .forward(&Tensor::cat(&[&self.n5.forward(&head_2)?, p5], 1)?)?; + Ok((head_1, head_2, head_3)) + } +} + +#[derive(Debug)] +struct DetectionHead { + dfl: Dfl, + cv2: [(ConvBlock, ConvBlock, Conv2d); 3], + cv3: [(ConvBlock, ConvBlock, Conv2d); 3], + ch: usize, + no: usize, +} + +fn make_anchors( + xs0: &Tensor, + xs1: &Tensor, + xs2: &Tensor, + (s0, s1, s2): (usize, usize, usize), + grid_cell_offset: f64, +) -> Result<(Tensor, Tensor)> { + let dev = xs0.device(); + let mut anchor_points = vec![]; + let mut stride_tensor = vec![]; + for (xs, stride) in [(xs0, s0), (xs1, s1), (xs2, s2)] { + // xs is only used to extract the h and w dimensions. + let (_, _, h, w) = xs.dims4()?; + let sx = (Tensor::arange(0, w as u32, dev)?.to_dtype(DType::F32)? + grid_cell_offset)?; + let sy = (Tensor::arange(0, h as u32, dev)?.to_dtype(DType::F32)? + grid_cell_offset)?; + let sx = sx + .reshape((1, sx.elem_count()))? + .repeat((h, 1))? + .flatten_all()?; + let sy = sy + .reshape((sy.elem_count(), 1))? + .repeat((1, w))? + .flatten_all()?; + anchor_points.push(Tensor::stack(&[&sx, &sy], D::Minus1)?); + stride_tensor.push((Tensor::ones(h * w, DType::F32, dev)? * stride as f64)?); + } + let anchor_points = Tensor::cat(anchor_points.as_slice(), 0)?; + let stride_tensor = Tensor::cat(stride_tensor.as_slice(), 0)?.unsqueeze(1)?; + Ok((anchor_points, stride_tensor)) +} +fn dist2bbox(distance: &Tensor, anchor_points: &Tensor) -> Result { + let chunks = distance.chunk(2, 1)?; + let lt = &chunks[0]; + let rb = &chunks[1]; + let x1y1 = anchor_points.sub(lt)?; + let x2y2 = anchor_points.add(rb)?; + let c_xy = ((&x1y1 + &x2y2)? * 0.5)?; + let wh = (&x2y2 - &x1y1)?; + Tensor::cat(&[c_xy, wh], 1) +} + +impl DetectionHead { + fn load(vb: VarBuilder, nc: usize, filters: (usize, usize, usize)) -> Result { + let ch = 16; + let dfl = Dfl::load(vb.pp("dfl"), ch)?; + let c1 = usize::max(filters.0, nc); + let c2 = usize::max(filters.0 / 4, ch * 4); + let cv3 = [ + Self::load_cv3(vb.pp("cv3.0"), c1, nc, filters.0)?, + Self::load_cv3(vb.pp("cv3.1"), c1, nc, filters.1)?, + Self::load_cv3(vb.pp("cv3.2"), c1, nc, filters.2)?, + ]; + let cv2 = [ + Self::load_cv2(vb.pp("cv2.0"), c2, ch, filters.0)?, + Self::load_cv2(vb.pp("cv2.1"), c2, ch, filters.1)?, + Self::load_cv2(vb.pp("cv2.2"), c2, ch, filters.2)?, + ]; + let no = nc + ch * 4; + Ok(Self { + dfl, + cv2, + cv3, + ch, + no, + }) + } + + fn load_cv3( + vb: VarBuilder, + c1: usize, + nc: usize, + filter: usize, + ) -> Result<(ConvBlock, ConvBlock, Conv2d)> { + let block0 = ConvBlock::load(vb.pp("0"), filter, c1, 3, 1, None)?; + let block1 = ConvBlock::load(vb.pp("1"), c1, c1, 3, 1, None)?; + let conv = conv2d(c1, nc, 1, Default::default(), vb.pp("2"))?; + Ok((block0, block1, conv)) + } + + fn load_cv2( + vb: VarBuilder, + c2: usize, + ch: usize, + filter: usize, + ) -> Result<(ConvBlock, ConvBlock, Conv2d)> { + let block0 = ConvBlock::load(vb.pp("0"), filter, c2, 3, 1, None)?; + let block1 = ConvBlock::load(vb.pp("1"), c2, c2, 3, 1, None)?; + let conv = conv2d(c2, 4 * ch, 1, Default::default(), vb.pp("2"))?; + Ok((block0, block1, conv)) + } + + fn forward(&self, xs0: &Tensor, xs1: &Tensor, xs2: &Tensor) -> Result { + let forward_cv = |xs, i: usize| { + let xs_2 = self.cv2[i].0.forward(xs)?; + let xs_2 = self.cv2[i].1.forward(&xs_2)?; + let xs_2 = self.cv2[i].2.forward(&xs_2)?; + + let xs_3 = self.cv3[i].0.forward(xs)?; + let xs_3 = self.cv3[i].1.forward(&xs_3)?; + let xs_3 = self.cv3[i].2.forward(&xs_3)?; + Tensor::cat(&[&xs_2, &xs_3], 1) + }; + let xs0 = forward_cv(xs0, 0)?; + let xs1 = forward_cv(xs1, 1)?; + let xs2 = forward_cv(xs2, 2)?; + + let (anchors, strides) = make_anchors(&xs0, &xs1, &xs2, (8, 16, 32), 0.5)?; + let anchors = anchors.transpose(0, 1)?; + let strides = strides.transpose(0, 1)?; + + let reshape = |xs: &Tensor| { + let d = xs.dim(0)?; + let el = xs.elem_count(); + xs.reshape((d, self.no, el / (d * self.no))) + }; + let ys0 = reshape(&xs0)?; + let ys1 = reshape(&xs1)?; + let ys2 = reshape(&xs2)?; + + let x_cat = Tensor::cat(&[ys0, ys1, ys2], 2)?; + let box_ = x_cat.i((.., ..self.ch * 4))?; + let cls = x_cat.i((.., self.ch * 4..))?; + + let dbox = dist2bbox(&self.dfl.forward(&box_)?, &anchors.unsqueeze(0)?)?; + let dbox = dbox.broadcast_mul(&strides)?; + Tensor::cat(&[dbox, candle_nn::ops::sigmoid(&cls)?], 1) + } +} + +#[derive(Debug)] +pub struct YoloV8 { + net: DarkNet, + fpn: YoloV8Neck, + head: DetectionHead, +} + +impl YoloV8 { + pub fn load(vb: VarBuilder, m: Multiples, num_classes: usize) -> Result { + let net = DarkNet::load(vb.pp("net"), m)?; + let fpn = YoloV8Neck::load(vb.pp("fpn"), m)?; + let head = DetectionHead::load(vb.pp("head"), num_classes, m.filters())?; + Ok(Self { net, fpn, head }) + } +} + +impl Module for YoloV8 { + fn forward(&self, xs: &Tensor) -> Result { + let (xs1, xs2, xs3) = self.net.forward(xs)?; + let (xs1, xs2, xs3) = self.fpn.forward(&xs1, &xs2, &xs3)?; + self.head.forward(&xs1, &xs2, &xs3) + } +} + +#[derive(Debug, Clone, Copy, serde::Serialize, serde::Deserialize)] +pub struct Bbox { + pub xmin: f32, + pub ymin: f32, + pub xmax: f32, + pub ymax: f32, + pub confidence: f32, +} + +// Intersection over union of two bounding boxes. +fn iou(b1: &Bbox, b2: &Bbox) -> f32 { + let b1_area = (b1.xmax - b1.xmin + 1.) * (b1.ymax - b1.ymin + 1.); + let b2_area = (b2.xmax - b2.xmin + 1.) * (b2.ymax - b2.ymin + 1.); + let i_xmin = b1.xmin.max(b2.xmin); + let i_xmax = b1.xmax.min(b2.xmax); + let i_ymin = b1.ymin.max(b2.ymin); + let i_ymax = b1.ymax.min(b2.ymax); + let i_area = (i_xmax - i_xmin + 1.).max(0.) * (i_ymax - i_ymin + 1.).max(0.); + i_area / (b1_area + b2_area - i_area) +} + +pub fn report(pred: &Tensor, img: DynamicImage, w: usize, h: usize) -> Result>> { + let (pred_size, npreds) = pred.dims2()?; + let nclasses = pred_size - 4; + // The bounding boxes grouped by (maximum) class index. + let mut bboxes: Vec> = (0..nclasses).map(|_| vec![]).collect(); + // Extract the bounding boxes for which confidence is above the threshold. + for index in 0..npreds { + let pred = Vec::::try_from(pred.i((.., index))?)?; + let confidence = *pred[4..].iter().max_by(|x, y| x.total_cmp(y)).unwrap(); + if confidence > CONFIDENCE_THRESHOLD { + let mut class_index = 0; + for i in 0..nclasses { + if pred[4 + i] > pred[4 + class_index] { + class_index = i + } + } + if pred[class_index + 4] > 0. { + let bbox = Bbox { + xmin: pred[0] - pred[2] / 2., + ymin: pred[1] - pred[3] / 2., + xmax: pred[0] + pred[2] / 2., + ymax: pred[1] + pred[3] / 2., + confidence, + }; + bboxes[class_index].push(bbox) + } + } + } + // Perform non-maximum suppression. + for bboxes_for_class in bboxes.iter_mut() { + bboxes_for_class.sort_by(|b1, b2| b2.confidence.partial_cmp(&b1.confidence).unwrap()); + let mut current_index = 0; + for index in 0..bboxes_for_class.len() { + let mut drop = false; + for prev_index in 0..current_index { + let iou = iou(&bboxes_for_class[prev_index], &bboxes_for_class[index]); + if iou > NMS_THRESHOLD { + drop = true; + break; + } + } + if !drop { + bboxes_for_class.swap(current_index, index); + current_index += 1; + } + } + bboxes_for_class.truncate(current_index); + } + // Annotate the original image and print boxes information. + let (initial_h, initial_w) = (img.height() as f32, img.width() as f32); + let w_ratio = initial_w / w as f32; + let h_ratio = initial_h / h as f32; + for (class_index, bboxes_for_class) in bboxes.iter_mut().enumerate() { + for b in bboxes_for_class.iter_mut() { + crate::console_log!("{}: {:?}", crate::coco_classes::NAMES[class_index], b); + b.xmin = (b.xmin * w_ratio).clamp(0., initial_w - 1.); + b.ymin = (b.ymin * h_ratio).clamp(0., initial_h - 1.); + b.xmax = (b.xmax * w_ratio).clamp(0., initial_w - 1.); + b.ymax = (b.ymax * h_ratio).clamp(0., initial_h - 1.); + } + } + Ok(bboxes) +} diff --git a/candle-wasm-examples/yolo/src/worker.rs b/candle-wasm-examples/yolo/src/worker.rs new file mode 100644 index 0000000000..cc029cf822 --- /dev/null +++ b/candle-wasm-examples/yolo/src/worker.rs @@ -0,0 +1,132 @@ +use crate::model::{report, Bbox, Multiples, YoloV8}; +use candle::{DType, Device, Result, Tensor}; +use candle_nn::{Module, VarBuilder}; +use serde::{Deserialize, Serialize}; +use wasm_bindgen::prelude::*; +use yew_agent::{HandlerId, Public, WorkerLink}; + +#[wasm_bindgen] +extern "C" { + // Use `js_namespace` here to bind `console.log(..)` instead of just + // `log(..)` + #[wasm_bindgen(js_namespace = console)] + pub fn log(s: &str); +} + +#[macro_export] +macro_rules! console_log { + // Note that this is using the `log` function imported above during + // `bare_bones` + ($($t:tt)*) => ($crate::worker::log(&format_args!($($t)*).to_string())) +} + +// Communication to the worker happens through bincode, the model weights and configs are fetched +// on the main thread and transfered via the following structure. +#[derive(Serialize, Deserialize)] +pub struct ModelData { + pub weights: Vec, +} + +struct Model { + model: YoloV8, +} + +impl Model { + fn run( + &self, + _link: &WorkerLink, + _id: HandlerId, + image_data: Vec, + ) -> Result>> { + console_log!("image data: {}", image_data.len()); + let image_data = std::io::Cursor::new(image_data); + let original_image = image::io::Reader::new(image_data) + .with_guessed_format()? + .decode() + .map_err(candle::Error::wrap)?; + let image = { + let data = original_image + .resize_exact(640, 640, image::imageops::FilterType::Triangle) + .to_rgb8() + .into_raw(); + Tensor::from_vec(data, (640, 640, 3), &Device::Cpu)?.permute((2, 0, 1))? + }; + let image = (image.unsqueeze(0)?.to_dtype(DType::F32)? * (1. / 255.))?; + let predictions = self.model.forward(&image)?.squeeze(0)?; + console_log!("generated predictions {predictions:?}"); + let bboxes = report(&predictions, original_image, 640, 640)?; + Ok(bboxes) + } +} + +impl Model { + fn load(md: ModelData) -> Result { + let dev = &Device::Cpu; + let weights = safetensors::tensor::SafeTensors::deserialize(&md.weights)?; + let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, dev); + let model = YoloV8::load(vb, Multiples::s(), 80)?; + Ok(Self { model }) + } +} + +pub struct Worker { + link: WorkerLink, + model: Option, +} + +#[derive(Serialize, Deserialize)] +pub enum WorkerInput { + ModelData(ModelData), + Run(Vec), +} + +#[derive(Serialize, Deserialize)] +pub enum WorkerOutput { + ProcessingDone(std::result::Result>, String>), + WeightsLoaded, +} + +impl yew_agent::Worker for Worker { + type Input = WorkerInput; + type Message = (); + type Output = std::result::Result; + type Reach = Public; + + fn create(link: WorkerLink) -> Self { + Self { link, model: None } + } + + fn update(&mut self, _msg: Self::Message) { + // no messaging + } + + fn handle_input(&mut self, msg: Self::Input, id: HandlerId) { + let output = match msg { + WorkerInput::ModelData(md) => match Model::load(md) { + Ok(model) => { + self.model = Some(model); + Ok(WorkerOutput::WeightsLoaded) + } + Err(err) => Err(format!("model creation error {err:?}")), + }, + WorkerInput::Run(image_data) => match &mut self.model { + None => Err("model has not been set yet".to_string()), + Some(model) => { + let result = model + .run(&self.link, id, image_data) + .map_err(|e| e.to_string()); + Ok(WorkerOutput::ProcessingDone(result)) + } + }, + }; + self.link.respond(id, output); + } + + fn name_of_resource() -> &'static str { + "worker.js" + } + + fn resource_path_is_relative() -> bool { + true + } +}