Skip to content

Commit

Permalink
feat(kernel-meta): 实现广播器
Browse files Browse the repository at this point in the history
Signed-off-by: YdrMaster <[email protected]>
  • Loading branch information
YdrMaster committed Jan 4, 2024
1 parent b4b1e1c commit 26fa5ea
Show file tree
Hide file tree
Showing 8 changed files with 153 additions and 15 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ members = [
"graph-topo",
"computation",
"stack-calculator",
"runtimes/kernel-meta",
"runtimes/cpu",
"runtimes/nvidia",
]
Expand Down
12 changes: 12 additions & 0 deletions common/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,15 @@
mod data_type;

pub use data_type::{AsDataType, DataType};

/// Data type of the values stored in a shape.
#[allow(non_camel_case_types)]
pub type udim = u32;

/// Signed counterpart of [udim], e.g. a negative value can denote the reverse direction.
#[allow(non_camel_case_types)]
pub type sdim = i32;

/// Data type for the difference between [udim] values.
#[allow(non_camel_case_types)]
pub type ddim = i16;
12 changes: 0 additions & 12 deletions computation/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,18 +12,6 @@ pub use graph::Graph;
pub use operator::*;
pub use tensor::*;

/// Data type of the values stored in a shape.
#[allow(non_camel_case_types)]
pub type udim = u32;

/// Signed counterpart of [udim], e.g. a negative value can denote the reverse direction.
#[allow(non_camel_case_types)]
pub type sdim = i32;

/// Data type for the difference between [udim] values.
#[allow(non_camel_case_types)]
pub type ddim = i16;

/// 加载图。
pub fn load_graph(path: impl AsRef<Path>, name: impl AsRef<OsStr>) -> Graph {
let path = path.as_ref();
Expand Down
2 changes: 1 addition & 1 deletion computation/src/operator.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::{ddim, sdim, udim};
use common::{ddim, sdim, udim};
use smallvec::SmallVec;

/// 算子。
Expand Down
3 changes: 1 addition & 2 deletions computation/src/tensor.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
use crate::udim;
use common::DataType;
use common::{udim, DataType};
use std::{alloc::Layout, fmt, str::FromStr, sync::Arc};

/// 张量。
Expand Down
14 changes: 14 additions & 0 deletions runtimes/kernel-meta/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
[package]
name = "kernel-meta"
version = "0.0.0"
edition = "2021"
authors = ["YdrMaster <[email protected]>"]

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
common = { path = "../../common" }
graph-topo = { path = "../../graph-topo" }
computation = { path = "../../computation" }

bitvec = "1.0"
117 changes: 117 additions & 0 deletions runtimes/kernel-meta/src/broadcaster.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
use bitvec::vec::BitVec;
use common::udim;
use graph_topo::ucount;

/// Broadcaster: an optimized representation of how any number of input shapes
/// broadcast against one another.
#[derive(Clone, Debug)]
pub struct Broadcaster {
/// Strides of all inputs for every dimension, flattened into a single vector.
///
/// The vector is read as `n` consecutive groups of `m` entries, where
///
/// - `n` is the number of tensor dimensions after compression;
/// - `m` is `inputs_count + 1`.
///
/// Within a group, `locate` uses the last entry as the divisor applied to the
/// output element index and the preceding `inputs_count` entries as the
/// per-input strides. An empty vector means no broadcasting is needed.
strides: Vec<udim>,
/// Number of input tensors.
inputs_count: ucount,
/// Number of elements in the output tensor.
output_size: udim,
}

impl Broadcaster {
    /// Builds a broadcaster from the shapes of all inputs.
    ///
    /// Shapes are right-aligned and consumed from the last dimension towards
    /// the first. Dimensions in which every input has size 1 are dropped, and
    /// consecutive dimensions in which the same set of inputs has a size other
    /// than 1 are merged into a single group. Each group stores one stride per
    /// input (0 when that input is broadcast along the group) followed by the
    /// stride of the output.
    ///
    /// If a single group remains and every stride in it is 1, all inputs map
    /// one-to-one onto the output and no strides are kept at all.
    ///
    /// # Panics
    ///
    /// Panics if, in some right-aligned dimension, two inputs have different
    /// sizes and neither of them is 1.
    pub fn from_inputs_shape(mut inputs: Vec<&[udim]>) -> Self {
        let inputs_count = inputs.len();
        // Bit `i` is set when input `i` has a size other than 1 in the current group.
        let mut state = BitVec::<usize>::repeat(false, inputs_count);
        // Per input: product of the sizes it has contributed so far, i.e. the
        // stride it will get in the next group it takes part in.
        let mut factors = vec![1; inputs_count];
        let mut output_size = 1;
        // Built innermost-group-first with every group stored back to front;
        // a single `reverse` at the end puts everything in reading order.
        let mut strides = Vec::new();

        loop {
            let mut next = BitVec::<usize>::repeat(false, inputs_count);
            let shape = match inputs
                .iter_mut()
                // Number the inputs.
                .enumerate()
                // Pop the last dimension of every input that still has one.
                .filter_map(|(i, input)| {
                    input.split_last().map(|(&dim, head)| {
                        *input = head;
                        next.set(i, dim != 1);
                        dim
                    })
                })
                // Combine the popped sizes into the size of the output dimension.
                .fold(None, |acc, dim| match acc {
                    Some(1) | None => Some(dim),
                    Some(shape) => {
                        assert!(dim == 1 || dim == shape);
                        Some(shape)
                    }
                }) {
                // Every input has size 1 here: the dimension carries no information.
                Some(1) => continue,
                Some(shape) => shape,
                // All shapes are exhausted.
                None => break,
            };
            if next != state {
                // The set of participating inputs changed: open a new group.
                state = next;
                strides.resize(strides.len() + inputs_count + 1, 0);

                // Fill the per-input strides from the back of the new group so
                // that they end up in input order after the final `reverse`.
                for ((active, factor), stride) in state
                    .iter()
                    .zip(factors.iter_mut())
                    .zip(strides.iter_mut().rev())
                {
                    if *active {
                        *stride = *factor;
                        *factor *= shape;
                    }
                }
                // The first slot of the new group becomes its last slot after
                // the final `reverse`, which is where `locate` reads the output
                // stride from. Leaving it at 0 would make `locate` divide by zero.
                let output_slot = strides.len() - (inputs_count + 1);
                strides[output_slot] = output_size;
            } else {
                // Same participants as the previous dimension: merge into the
                // current group by only growing the running products.
                for (active, factor) in state.iter().zip(factors.iter_mut()) {
                    if *active {
                        *factor *= shape;
                    }
                }
            }
            output_size *= shape;
        }
        if strides.len() == inputs_count + 1 && strides.iter().all(|&x| x == 1) {
            // One group in which every input walks in lock-step with the output.
            strides.clear();
        } else {
            strides.reverse();
        }

        Self {
            strides,
            // NOTE(review): this cast truncates silently if the number of inputs
            // does not fit in `ucount`.
            inputs_count: inputs_count as _,
            output_size,
        }
    }

    /// Computes, for the output element with linear index `k`, the linear
    /// index of the corresponding element in every input.
    ///
    /// `ans` must hold exactly one slot per input; its previous contents are
    /// overwritten. When no broadcasting is needed every input index equals `k`.
    pub fn locate(&self, mut k: udim, ans: &mut [udim]) {
        debug_assert_eq!(ans.len(), self.inputs_count as usize);

        if self.strides.is_empty() {
            // No strides are stored in the one-to-one case.
            ans.fill(k);
            return;
        }

        // Indices are accumulated below, so stale values must not leak in.
        ans.fill(0);
        let each = self.inputs_count as usize + 1;
        for group in self.strides.chunks_exact(each) {
            let (&div, dims) = group
                .split_last()
                .expect("every group holds `inputs_count + 1` entries");

            // Peel off the coordinate along this group, keep the remainder
            // for the inner groups.
            let quot = k / div;
            k %= div;
            for (ans, dim) in ans.iter_mut().zip(dims) {
                *ans += dim * quot;
            }
        }
    }

    /// Number of elements in the output tensor.
    #[inline]
    pub fn output_size(&self) -> udim {
        self.output_size
    }

    /// Returns whether this broadcaster describes an actual broadcast, i.e.
    /// whether any strides are stored.
    #[inline]
    pub fn need_broadcast(&self) -> bool {
        !self.strides.is_empty()
    }
}
7 changes: 7 additions & 0 deletions runtimes/kernel-meta/src/lib.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
//! Optimized representations of kernel semantics.
#![deny(warnings, missing_docs)]

mod broadcaster;

pub use broadcaster::Broadcaster;

0 comments on commit 26fa5ea

Please sign in to comment.