From f59e2f7afad4a06d591402fff1902f7806e09251 Mon Sep 17 00:00:00 2001 From: Dmitry Sharshakov Date: Fri, 6 May 2022 20:57:08 +0300 Subject: [PATCH 1/2] Move code in from tract repo Ported GPUTensor handling code --- Cargo.lock | 249 +++++++++++++++++++++++++-- Cargo.toml | 3 +- wonnx-tract/Cargo.toml | 42 +++++ wonnx-tract/examples/a.rs | 27 +++ wonnx-tract/shaders/sigmoid.wgsl | 31 ++++ wonnx-tract/shaders/tanh.wgsl | 30 ++++ wonnx-tract/src/lib.rs | 277 +++++++++++++++++++++++++++++++ 7 files changed, 644 insertions(+), 15 deletions(-) create mode 100644 wonnx-tract/Cargo.toml create mode 100644 wonnx-tract/examples/a.rs create mode 100644 wonnx-tract/shaders/sigmoid.wgsl create mode 100644 wonnx-tract/shaders/tanh.wgsl create mode 100644 wonnx-tract/src/lib.rs diff --git a/Cargo.lock b/Cargo.lock index 7d9f78c2..ed577894 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -305,6 +305,15 @@ dependencies = [ "zip-extensions", ] +[[package]] +name = "cast" +version = "0.2.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4c24dab4283a142afa2fdca129b80ad2c6284e073930f964c3a1293c225ee39a" +dependencies = [ + "rustc_version", +] + [[package]] name = "cc" version = "1.0.73" @@ -447,6 +456,42 @@ dependencies = [ "cfg-if", ] +[[package]] +name = "criterion" +version = "0.3.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1604dafd25fba2fe2d5895a9da139f8dc9b319a5fe5354ca137cbbce4e178d10" +dependencies = [ + "atty", + "cast", + "clap", + "criterion-plot", + "csv", + "itertools 0.10.3", + "lazy_static", + "num-traits", + "oorandom", + "plotters", + "rayon", + "regex", + "serde", + "serde_cbor", + "serde_derive", + "serde_json", + "tinytemplate", + "walkdir", +] + +[[package]] +name = "criterion-plot" +version = "0.4.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d00996de9f2f7559f7f4dc286073197f83e92256a59ed395f9aac01fe717da57" +dependencies = [ + "cast", + "itertools 0.10.3", +] + [[package]] name = "crossbeam-channel" version = "0.5.4" @@ -879,6 +924,21 @@ dependencies = [ "winapi", ] +[[package]] +name = "futures" +version = "0.3.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f73fe65f54d1e12b726f517d3e2135ca3125a437b6d998caf1962961f7172d9e" +dependencies = [ + "futures-channel", + "futures-core", + "futures-executor", + "futures-io", + "futures-sink", + "futures-task", + "futures-util", +] + [[package]] name = "futures-channel" version = "0.3.21" @@ -886,6 +946,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c3083ce4b914124575708913bca19bfe887522d6e2e6d0952943f5eac4a74010" dependencies = [ "futures-core", + "futures-sink", ] [[package]] @@ -894,12 +955,34 @@ version = "0.3.21" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0c09fd04b7e4073ac7156a9539b57a484a8ea920f79c7c675d05d289ab6110d3" +[[package]] +name = "futures-executor" +version = "0.3.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9420b90cfa29e327d0429f19be13e7ddb68fa1cccb09d65e5706b8c7a749b8a6" +dependencies = [ + "futures-core", + "futures-task", + "futures-util", +] + [[package]] name = "futures-io" version = "0.3.21" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fc4045962a5a5e935ee2fdedaa4e08284547402885ab326734432bed5d12966b" +[[package]] +name = "futures-macro" +version = "0.3.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "33c1e13800337f4d4d7a316bf45a567dbcb6ffe087f16424852d97e97a91f512" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "futures-sink" version = "0.3.21" @@ -918,8 +1001,11 @@ version = "0.3.21" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d8b7abd5d659d9b90c8cba917f6ec750a74e2dc23902ef9cd4cc8c8b22e6036a" dependencies = [ + "futures-channel", "futures-core", "futures-io", + "futures-macro", + "futures-sink", "futures-task", "memchr", "pin-project-lite", @@ -1906,6 +1992,12 @@ dependencies = [ "pkg-config", ] +[[package]] +name = "oorandom" +version = "11.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0ab1bc2a289d34bd04a330323ac98a1b4bc82c9d9fcb1e66b63caa84da26b575" + [[package]] name = "opaque-debug" version = "0.2.3" @@ -2123,6 +2215,34 @@ version = "0.3.25" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1df8c4ec4b0627e53bdf214615ad287367e482558cf84b109250b37464dc03ae" +[[package]] +name = "plotters" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32a3fd9ec30b9749ce28cd91f255d569591cdf937fe280c312143e3c4bad6f2a" +dependencies = [ + "num-traits", + "plotters-backend", + "plotters-svg", + "wasm-bindgen", + "web-sys", +] + +[[package]] +name = "plotters-backend" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d88417318da0eaf0fdcdb51a0ee6c3bed624333bff8f946733049380be67ac1c" + +[[package]] +name = "plotters-svg" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "521fa9638fa597e1dc53e9412a4f9cefb01187ee1f7413076f9e6749e2885ba9" +dependencies = [ + "plotters-backend", +] + [[package]] name = "png" version = "0.17.5" @@ -2257,6 +2377,26 @@ version = "1.0.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9145ac0af1d93c638c98c40cf7d25665f427b2a44ad0a99b1dccf3e2f25bb987" +[[package]] +name = "proptest" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e0d9cc07f18492d879586c92b485def06bc850da3118075cd45d50e9c95b0e5" +dependencies = [ + "bit-set", + "bitflags", + "byteorder", + "lazy_static", + "num-traits", + "quick-error 2.0.1", + "rand 0.8.5", + "rand_chacha 0.3.1", + "rand_xorshift", + "regex-syntax", + "rusty-fork", + "tempfile", +] + [[package]] name = "prost" version = "0.9.0" @@ -2378,6 +2518,18 @@ dependencies = [ "syn", ] +[[package]] +name = "quick-error" +version = "1.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a1d01941d82fa2ab50be1e79e6714289dd7cde78eba4c074bc5a4374f650dfe0" + +[[package]] +name = "quick-error" +version = "2.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a993555f31e5a609f617c12db6250dedcac1b0a85076912c436e6fc9b2c8e6a3" + [[package]] name = "quote" version = "1.0.18" @@ -2458,6 +2610,15 @@ dependencies = [ "rand_core 0.5.1", ] +[[package]] +name = "rand_xorshift" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d25bf25ec5ae4a3f1b92f929810509a2f53d7dca2f50b794ff57e3face536c8f" +dependencies = [ + "rand_core 0.6.3", +] + [[package]] name = "range-alloc" version = "0.1.2" @@ -2652,6 +2813,18 @@ dependencies = [ "semver", ] +[[package]] +name = "rusty-fork" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cb3dcc6e454c328bb824492db107ab7c0ae8fcffe4ad210136ef014458c1bc4f" +dependencies = [ + "fnv", + "quick-error 1.2.3", + "tempfile", + "wait-timeout", +] + [[package]] name = "ryu" version = "1.0.9" @@ -2742,6 +2915,16 @@ dependencies = [ "serde_derive", ] +[[package]] +name = "serde_cbor" +version = "0.11.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2bef2ebfde456fb76bbcf9f59315333decc4fda0b2b44b420243c11e0f5ec1f5" +dependencies = [ + "half", + "serde", +] + [[package]] name = "serde_derive" version = "1.0.136" @@ -3084,6 +3267,16 @@ version = "0.2.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "42657b1a6f4d817cda8e7a0ace261fe0cc946cf3a80314390b22cc61ae080792" +[[package]] +name = "tinytemplate" +version = "1.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "be4d6b5f19ff7664e8c98d03e2139cb510db9b0a60b55f8e8709b689d939b6bc" +dependencies = [ + "serde", + "serde_json", +] + [[package]] name = "tinyvec" version = "1.5.1" @@ -3212,9 +3405,9 @@ dependencies = [ [[package]] name = "tract-core" -version = "0.16.4" +version = "0.16.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "74f2d8ee91a9ad22f260e41fd434acb1361055d562486ae270c271991b3879f8" +checksum = "7ae63f5e21145b87b7812ac050f692ebd9fae5c31146c37e03f07b1a31ff7ab4" dependencies = [ "anyhow", "bit-set", @@ -3235,9 +3428,9 @@ dependencies = [ [[package]] name = "tract-data" -version = "0.16.4" +version = "0.16.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4f3d76e0765c8f971ab71a91b04ab71510828e782850efeb550f6ca45ccb0cfc" +checksum = "9aeb135db1d5d5824666edf10be35b01c8a84bf993cc64a5321bc13070624840" dependencies = [ "anyhow", "educe", @@ -3254,9 +3447,9 @@ dependencies = [ [[package]] name = "tract-hir" -version = "0.16.4" +version = "0.16.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b2fecd4dcef42cb54f1226ad00f4db2919a4362a08e982df8bfb0e51fd4c067f" +checksum = "dd9b6e42023e0949688b47f512e36f0d4691e6ee7ddb014bc514ea03d2d5122f" dependencies = [ "derive-new", "educe", @@ -3266,9 +3459,9 @@ dependencies = [ [[package]] name = "tract-linalg" -version = "0.16.4" +version = "0.16.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0203e1a40d1027fba6047310ee5678345b4e301379c641dd5e65136a16e9b21f" +checksum = "a98190bce172d691ee230d423e5881cbb546b61df768f521611c51d52511d660" dependencies = [ "cc", "derive-new", @@ -3290,9 +3483,9 @@ dependencies = [ [[package]] name = "tract-nnef" -version = "0.16.4" +version = "0.16.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c01c6d3c5a5e64d3964f18631d804d1cc17ba15171772da4761b2dfb6c92acfd" +checksum = "3b2cc0350faa927c1d895e06ed28cebf4e4e9437642aa5ac65942462006ab342" dependencies = [ "byteorder", "flate2", @@ -3305,9 +3498,9 @@ dependencies = [ [[package]] name = "tract-onnx" -version = "0.16.4" +version = "0.16.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "904c41c877a0ddd3c05a9bae864069dda1e61d75618e6e7a8e1887bf55bb6adf" +checksum = "9a904657135f9baeffc0f6ef6eca0a805788e82f517e956f2c632e6238c88594" dependencies = [ "bytes", "derive-new", @@ -3325,9 +3518,9 @@ dependencies = [ [[package]] name = "tract-onnx-opl" -version = "0.16.4" +version = "0.16.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "68324435532fe72b8455de8c9f2c237fd34d5e0e9f0456b12a8a41075ca11f22" +checksum = "22d3210700af398abe3279e7116cdf0fe09b7726061fc9bded590166f0c3d07a" dependencies = [ "educe", "tract-nnef", @@ -3903,6 +4096,34 @@ dependencies = [ "wonnx", ] +[[package]] +name = "wonnx-tract" +version = "0.0.1" +dependencies = [ + "bytemuck", + "cc", + "criterion", + "derive-new", + "downcast-rs", + "dyn-clone", + "educe", + "futures", + "lazy_static", + "libc", + "liquid", + "log", + "num-traits", + "paste", + "proptest", + "scan_fmt", + "smallvec", + "tract-core", + "tract-data", + "unicode-normalization", + "walkdir", + "wgpu", +] + [[package]] name = "wonnx-wasm" version = "0.2.4" diff --git a/Cargo.toml b/Cargo.toml index 7fcfd823..1e341e8e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -4,7 +4,8 @@ members = [ "wonnx-py", "wonnx-cli", "wonnx-preprocessing", - "wonnx-wasm" + "wonnx-wasm", + "wonnx-tract" ] default-members = ["wonnx", "wonnx-cli"] diff --git a/wonnx-tract/Cargo.toml b/wonnx-tract/Cargo.toml new file mode 100644 index 00000000..720c6b17 --- /dev/null +++ b/wonnx-tract/Cargo.toml @@ -0,0 +1,42 @@ +[package] +name = "wonnx-tract" +version = "0.0.1" +license = "MIT/Apache-2.0" +authors = ["Dmitry Sharshakov "] +description = "An integration crate for using GPU compiler from wonnx to accelerate tract model inference" +repository = "https://github.com/snipsco/tract" +keywords = ["TensorFlow", "NeuralNetworks"] +categories = ["science"] +autobenches = false +edition = "2018" + +[badges] +maintenance = { status = "actively-developed" } + +[dependencies] +derive-new = "0.5.9" +downcast-rs = "1.2.0" +dyn-clone = "1.0.4" +educe = "0.4.18" +lazy_static = "1.4.0" +libc = "0.2.100" +log = "0.4.14" +num-traits = "0.2.14" +tract-data = "0.16.6" +tract-core = "0.16.6" +paste = "1.0.5" +scan_fmt = "0.2.6" +wgpu = "0.12" +futures = "0.3" +bytemuck = "1.9" + +[build-dependencies] +cc = "1.0.69" +liquid = "0.24" +unicode-normalization = "0.1.19" +smallvec = "1.6.1" +walkdir = "2.3.2" + +[dev-dependencies] +criterion = "0.3.5" +proptest = "1.0.0" diff --git a/wonnx-tract/examples/a.rs b/wonnx-tract/examples/a.rs new file mode 100644 index 00000000..024180d5 --- /dev/null +++ b/wonnx-tract/examples/a.rs @@ -0,0 +1,27 @@ +use futures::executor::block_on; +use tract_core::prelude::{DatumType, Tensor}; +use tract_data::tvec; +use wonnx_tract::GpuAccel; + +fn main() { + let gpu = block_on(GpuAccel::default()).unwrap(); + + let x = 2; + let y = 2; + let z = 2; + let w = 2; + let mut data = Vec::new(); + for i in 1..(x * y * z * w + 1) { + data.push(i as f32); + } + + let inp = gpu + .import_tensor("inp".to_string(), &Tensor::from_shape(&tvec![x, y, z, w], &data).unwrap()); + let a = gpu.create_storage_tensor("a".to_string(), DatumType::F32, tvec![x, y, z, w]); + let out = gpu.create_out_tensor("out".to_string(), DatumType::F32, tvec![x, y, z, w]); + + gpu.tanh(&inp, &a); + gpu.sigmoid(&a, &out); + + println!("{:#?}", block_on(gpu.tensor_move_out(out)).dump(true)); +} diff --git a/wonnx-tract/shaders/sigmoid.wgsl b/wonnx-tract/shaders/sigmoid.wgsl new file mode 100644 index 00000000..186ba9b0 --- /dev/null +++ b/wonnx-tract/shaders/sigmoid.wgsl @@ -0,0 +1,31 @@ +struct Tensor { + shape: vec4; + strides: vec4; +}; + +struct Buffer { + data: [[stride(4)]] array; // 4 represents 4 bytes per value +}; + +[[group(0), binding(0)]] +var u_tensor: Tensor; + +[[group(0), binding(1)]] +var in: Buffer; + +[[group(0), binding(2)]] +var out: Buffer; + +[[stage(compute), workgroup_size(1)]] +fn main([[builtin(global_invocation_id)]] global_id: vec3) { + let id = global_id.x * u_tensor.strides.x + global_id.y * u_tensor.strides.y + global_id.z * u_tensor.strides.z; + var w: i32 = 0; + loop { + if (u32(w) >= u_tensor.shape.w) { + break; + } + out.data[id + u32(w) * u_tensor.strides.w] = 1.0 / (1.0 + exp(-1.0 * in.data[id + u32(w) * u_tensor.strides.w])); + w = w + 1; + } +} + diff --git a/wonnx-tract/shaders/tanh.wgsl b/wonnx-tract/shaders/tanh.wgsl new file mode 100644 index 00000000..4586037d --- /dev/null +++ b/wonnx-tract/shaders/tanh.wgsl @@ -0,0 +1,30 @@ +struct Tensor { + shape: vec4; + strides: vec4; +}; + +struct Buffer { + data: [[stride(4)]] array; // 4 represents 4 bytes per value +}; + +[[group(0), binding(0)]] +var u_tensor: Tensor; + +[[group(0), binding(1)]] +var in: Buffer; + +[[group(0), binding(2)]] +var out: Buffer; + +[[stage(compute), workgroup_size(1)]] +fn main([[builtin(global_invocation_id)]] global_id: vec3) { + let id = global_id.x * u_tensor.strides.x + global_id.y * u_tensor.strides.y + global_id.z * u_tensor.strides.z; + var w: i32 = 0; + loop { + if (u32(w) >= u_tensor.shape.w) { + break; + } + out.data[id + u32(w) * u_tensor.strides.w] = tanh(in.data[id + u32(w) * u_tensor.strides.w]); + w = w + 1; + } +} diff --git a/wonnx-tract/src/lib.rs b/wonnx-tract/src/lib.rs new file mode 100644 index 00000000..287e43e1 --- /dev/null +++ b/wonnx-tract/src/lib.rs @@ -0,0 +1,277 @@ +use std::borrow::Cow; +use std::fmt::Debug; +use tract_core::prelude::{natural_strides, DatumType, Tensor}; +use tract_data::TVec; +use wgpu::{ + util::{BufferInitDescriptor, DeviceExt}, + BindGroupDescriptor, BindGroupEntry, Buffer, BufferUsages, CommandEncoderDescriptor, + ComputePassDescriptor, ComputePipelineDescriptor, Device, DeviceDescriptor, Instance, Queue, + ShaderModule, ShaderModuleDescriptor, ShaderSource, +}; + +pub struct GPUTensor { + dt: DatumType, + shape: TVec, + strides: TVec, + len: usize, + info_uniform: Buffer, + buffer: Buffer, +} + +#[derive(Debug)] +pub struct GpuAccel { + device: Device, + queue: Queue, + sigmoid_shader: ShaderModule, + tanh_shader: ShaderModule, +} + +impl GpuAccel { + pub async fn default() -> Option { + let instance = + Instance::new(wgpu::util::backend_bits_from_env().unwrap_or_else(wgpu::Backends::all)); + let adapter = match wgpu::util::initialize_adapter_from_env_or_default( + &instance, + wgpu::util::backend_bits_from_env().unwrap_or_else(wgpu::Backends::all), + None, + ) + .await + { + Some(a) => a, + None => return None, + }; + + let (device, queue) = + match adapter.request_device(&DeviceDescriptor::default(), None).await.ok() { + Some((d, q)) => (d, q), + None => return None, + }; + + println!("Running inference on adapter: {:#?}", adapter.get_info()); + + let sigmoid_shader = device.create_shader_module(&ShaderModuleDescriptor { + label: None, + source: ShaderSource::Wgsl(Cow::Borrowed(include_str!("../shaders/sigmoid.wgsl"))), + }); + + let tanh_shader = device.create_shader_module(&ShaderModuleDescriptor { + label: None, + source: ShaderSource::Wgsl(Cow::Borrowed(include_str!("../shaders/tanh.wgsl"))), + }); + + Some(GpuAccel { device, queue, sigmoid_shader, tanh_shader }) + } + + pub fn alloc_in_buffer(&self, label: String, bytes: &Vec) -> Buffer { + self.device.create_buffer_init(&BufferInitDescriptor { + label: Some(&label), + contents: bytemuck::cast_slice(bytes), + usage: BufferUsages::STORAGE, + }) + } + + fn create_tensor_info_uniform( + &self, + shape: &TVec, + strides: &TVec, + label: String, + ) -> Buffer { + let mut tensor_info = vec![]; + tensor_info.push(*shape.get(0).unwrap_or(&1) as u32); + tensor_info.push(*shape.get(1).unwrap_or(&1) as u32); + tensor_info.push(*shape.get(2).unwrap_or(&1) as u32); + tensor_info.push(*shape.get(3).unwrap_or(&1) as u32); + tensor_info.push(*strides.get(0).unwrap_or(&0) as u32); + tensor_info.push(*strides.get(1).unwrap_or(&0) as u32); + tensor_info.push(*strides.get(2).unwrap_or(&0) as u32); + tensor_info.push(*strides.get(3).unwrap_or(&0) as u32); + self.device.create_buffer_init(&BufferInitDescriptor { + label: Some(&(label + "_info")), + contents: bytemuck::cast_slice(&tensor_info), + usage: BufferUsages::UNIFORM, + }) + } + + pub fn import_tensor(&self, label: String, t: &Tensor) -> GPUTensor { + let shape: TVec = t.shape().into(); + let strides: TVec = t.strides().iter().map(|x| *x as usize).collect(); + + unsafe { + GPUTensor { + dt: t.datum_type(), + shape: shape.clone(), + strides: strides.clone(), + len: t.len(), + info_uniform: self.create_tensor_info_uniform(&shape, &strides, label.clone()), + buffer: self.alloc_in_buffer(label, &Vec::from(t.as_bytes())), + } + } + } + + pub fn alloc_storage_buffer(&self, label: String, size: u64, output: bool) -> Buffer { + self.device.create_buffer(&wgpu::BufferDescriptor { + label: Some(&label), + size, + usage: if output { + BufferUsages::STORAGE | BufferUsages::MAP_READ | BufferUsages::COPY_SRC + } else { + BufferUsages::STORAGE + }, + mapped_at_creation: false, + }) + } + + fn create_generic_storage_tensor( + &self, + label: String, + dt: DatumType, + shape: TVec, + output: bool, + ) -> GPUTensor { + let strides = natural_strides(&shape); + let len = if shape.len() == 0 { + 1 + } else { + *strides.get(0).unwrap() as usize * shape.get(0).unwrap() + }; + + let unsigned_strides: TVec = strides.iter().map(|x| *x as usize).collect(); + + GPUTensor { + dt, + shape: shape.clone(), + strides: strides.iter().map(|x| *x as usize).collect(), + len, + info_uniform: self.create_tensor_info_uniform(&shape, &unsigned_strides, label.clone()), + buffer: self.alloc_storage_buffer(label, len as u64 * dt.size_of() as u64, output), + } + } + + pub fn create_storage_tensor( + &self, + label: String, + dt: DatumType, + shape: TVec, + ) -> GPUTensor { + self.create_generic_storage_tensor(label, dt, shape, false) + } + + pub fn create_out_tensor(&self, label: String, dt: DatumType, shape: TVec) -> GPUTensor { + self.create_generic_storage_tensor(label, dt, shape, true) + } + + pub async fn buffer_move_out(&self, buf: Buffer) -> Vec { + let buffer_slice = buf.slice(..); + let buffer_future = buffer_slice.map_async(wgpu::MapMode::Read); + + self.device.poll(wgpu::Maintain::Wait); + + if let Ok(()) = buffer_future.await { + let data = buffer_slice.get_mapped_range(); + let result = bytemuck::cast_slice(&data).to_vec(); + + drop(data); + buf.unmap(); + + return result; + } else { + panic!("Failed to move buffer {:?} out from GPU memory", buf); + } + } + + pub async fn tensor_move_out(&self, t: GPUTensor) -> Tensor { + unsafe { + Tensor::from_raw_dt(t.dt, &t.shape, &self.buffer_move_out::(t.buffer).await) + .unwrap() + } + } + + pub fn sigmoid(&self, in_tensor: &GPUTensor, out_tensor: &GPUTensor) { + if in_tensor.dt != out_tensor.dt || in_tensor.dt != DatumType::F32 { + panic!("Sigmoid only supports F32 tensors"); + } + if in_tensor.shape != out_tensor.shape { + panic!("Trying to do sigmoid between different tensor shapes"); + } + + let bind_group = vec![ + BindGroupEntry { binding: 0, resource: out_tensor.info_uniform.as_entire_binding() }, + BindGroupEntry { binding: 1, resource: in_tensor.buffer.as_entire_binding() }, + BindGroupEntry { binding: 2, resource: out_tensor.buffer.as_entire_binding() }, + ]; + + self.execute_shader( + &"sigmoid".to_string(), + &self.sigmoid_shader, + bind_group, + *out_tensor.shape.get(0).unwrap() as u32, + *out_tensor.shape.get(1).unwrap_or(&1) as u32, + *out_tensor.shape.get(2).unwrap_or(&1) as u32, + ); + } + + pub fn tanh(&self, in_tensor: &GPUTensor, out_tensor: &GPUTensor) { + if in_tensor.dt != out_tensor.dt || in_tensor.dt != DatumType::F32 { + panic!("Tanh only supports F32 tensors"); + } + if in_tensor.shape != out_tensor.shape { + panic!("Trying to do tanh between different tensor shapes"); + } + + let bind_group = vec![ + BindGroupEntry { binding: 0, resource: out_tensor.info_uniform.as_entire_binding() }, + BindGroupEntry { binding: 1, resource: in_tensor.buffer.as_entire_binding() }, + BindGroupEntry { binding: 2, resource: out_tensor.buffer.as_entire_binding() }, + ]; + + self.execute_shader( + &"tanh".to_string(), + &self.tanh_shader, + bind_group, + *out_tensor.shape.get(0).unwrap() as u32, + *out_tensor.shape.get(1).unwrap_or(&1) as u32, + *out_tensor.shape.get(2).unwrap_or(&1) as u32, + ); + } + + pub fn execute_shader( + &self, + label: &String, + shader: &ShaderModule, + bind_group_entries: Vec>, + wg_x: u32, + wg_y: u32, + wg_z: u32, + ) { + let compute_pipeline = self.device.create_compute_pipeline(&ComputePipelineDescriptor { + label: Some(&(label.clone() + "_pipeline")), + layout: None, + module: shader, + entry_point: "main", + }); + + let bind_group_layout = compute_pipeline.get_bind_group_layout(0); + let bind_group = self.device.create_bind_group(&BindGroupDescriptor { + label: Some(&(label.clone() + "_bind_group")), + layout: &bind_group_layout, + entries: &bind_group_entries, + }); + + let mut encoder = self.device.create_command_encoder(&CommandEncoderDescriptor { + label: Some(&(label.clone() + "_encoder")), + }); + { + let mut cpass = encoder.begin_compute_pass(&ComputePassDescriptor { + label: Some(&(label.clone() + "_compute_pass")), + }); + cpass.set_pipeline(&compute_pipeline); + cpass.set_bind_group(0, &bind_group, &[]); + // Number of cells to run, the (x,y,z) size of item being processed + cpass.dispatch(wg_x, wg_y, wg_z); + } + + self.queue.submit(Some(encoder.finish())); + + self.device.poll(wgpu::Maintain::Wait); + } +} From 8bbc65e551b898b71481a5a2f8cc359af844329c Mon Sep 17 00:00:00 2001 From: Dmitry Sharshakov Date: Thu, 19 May 2022 10:15:02 +0300 Subject: [PATCH 2/2] Import and dump model data in the example --- Cargo.lock | 41 +++++++++++++++++++++++++++++++++++++++ wonnx-tract/Cargo.toml | 1 + wonnx-tract/examples/a.rs | 22 +++++++++++++++++++-- 3 files changed, 62 insertions(+), 2 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index ed577894..2829eed3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3526,6 +3526,46 @@ dependencies = [ "tract-nnef", ] +[[package]] +name = "tract-pulse" +version = "0.16.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "111327b44f8f54612fdda06d8cc27b845593d022dbaffba97bbe3e143ef50e99" +dependencies = [ + "downcast-rs", + "lazy_static", + "tract-pulse-opl", +] + +[[package]] +name = "tract-pulse-opl" +version = "0.16.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e4ac9848dd6e8cc0597f9aceabbd21c310f8fa6acf09e814de34a926c091f83a" +dependencies = [ + "downcast-rs", + "lazy_static", + "tract-nnef", +] + +[[package]] +name = "tract-tensorflow" +version = "0.16.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c34df1924abf6f8c87435cdff09f4a7ae330b8f7652c3cb1017fd0fd26782d10" +dependencies = [ + "bytes", + "derive-new", + "educe", + "log", + "mapr", + "prost", + "prost-build", + "prost-types", + "tract-hir", + "tract-pulse", +] + [[package]] name = "try-lock" version = "0.2.3" @@ -4119,6 +4159,7 @@ dependencies = [ "smallvec", "tract-core", "tract-data", + "tract-tensorflow", "unicode-normalization", "walkdir", "wgpu", diff --git a/wonnx-tract/Cargo.toml b/wonnx-tract/Cargo.toml index 720c6b17..7b1d18ac 100644 --- a/wonnx-tract/Cargo.toml +++ b/wonnx-tract/Cargo.toml @@ -24,6 +24,7 @@ log = "0.4.14" num-traits = "0.2.14" tract-data = "0.16.6" tract-core = "0.16.6" +tract-tensorflow = "0.16.6" paste = "1.0.5" scan_fmt = "0.2.6" wgpu = "0.12" diff --git a/wonnx-tract/examples/a.rs b/wonnx-tract/examples/a.rs index 024180d5..f9e61a12 100644 --- a/wonnx-tract/examples/a.rs +++ b/wonnx-tract/examples/a.rs @@ -1,6 +1,7 @@ use futures::executor::block_on; use tract_core::prelude::{DatumType, Tensor}; use tract_data::tvec; +use tract_tensorflow::prelude::*; use wonnx_tract::GpuAccel; fn main() { @@ -15,8 +16,10 @@ fn main() { data.push(i as f32); } - let inp = gpu - .import_tensor("inp".to_string(), &Tensor::from_shape(&tvec![x, y, z, w], &data).unwrap()); + let inp = gpu.import_tensor( + "inp".to_string(), + &Tensor::from_shape(&tvec![x, y, z, w], &data).unwrap(), + ); let a = gpu.create_storage_tensor("a".to_string(), DatumType::F32, tvec![x, y, z, w]); let out = gpu.create_out_tensor("out".to_string(), DatumType::F32, tvec![x, y, z, w]); @@ -24,4 +27,19 @@ fn main() { gpu.sigmoid(&a, &out); println!("{:#?}", block_on(gpu.tensor_move_out(out)).dump(true)); + + let model = tract_tensorflow::tensorflow() + .model_for_path("mobilenet_v2_1.4_224_frozen.pb") + .unwrap() + .with_input_fact( + 0, + InferenceFact::dt_shape(f32::datum_type(), tvec!(1, 224, 224, 3)), + ) + .unwrap() + .into_typed() + .unwrap() + .into_decluttered() + .unwrap(); + // GPU state + println!("{:#?}", SimplePlan::new(model.clone())); }