Skip to content

Commit

Permalink
feat(runtimes/nvidia): 开始调用 cuda driver 构图 api
Browse files Browse the repository at this point in the history
Signed-off-by: YdrMaster <[email protected]>
  • Loading branch information
YdrMaster committed Jan 2, 2024
1 parent fd02d68 commit a8eb0f0
Show file tree
Hide file tree
Showing 7 changed files with 136 additions and 47 deletions.
17 changes: 9 additions & 8 deletions runtimes/nvidia/src/driver/graph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@ impl Drop for Graph {
}

impl Graph {
pub fn new(&self) -> Graph {
pub fn new() -> Self {
let mut graph: cuda::CUgraph = null_mut();
cuda::invoke!(cuGraphCreate(&mut graph, 0));
Graph {
Self {
graph,
first_node: null_mut(),
last_node: null_mut(),
Expand Down Expand Up @@ -140,12 +140,13 @@ fn params_memcpy3d(
len: usize,
ty: MemcpyType,
) -> cuda::CUDA_MEMCPY3D {
use cuda::CUmemorytype::*;
let mut ans = cuda::CUDA_MEMCPY3D {
srcXInBytes: 0,
srcY: 0,
srcZ: 0,
srcLOD: 0,
srcMemoryType: cuda::CUmemorytype_enum::CU_MEMORYTYPE_DEVICE,
srcMemoryType: CU_MEMORYTYPE_DEVICE,
srcHost: null_mut(),
srcDevice: 0,
srcArray: null_mut(),
Expand All @@ -156,7 +157,7 @@ fn params_memcpy3d(
dstY: 0,
dstZ: 0,
dstLOD: 0,
dstMemoryType: cuda::CUmemorytype_enum::CU_MEMORYTYPE_DEVICE,
dstMemoryType: CU_MEMORYTYPE_DEVICE,
dstHost: null_mut(),
dstDevice: 0,
dstArray: null_mut(),
Expand All @@ -173,19 +174,19 @@ fn params_memcpy3d(
ans.dstDevice = dst as _;
}
MemcpyType::H2H => {
ans.srcMemoryType = cuda::CUmemorytype_enum::CU_MEMORYTYPE_HOST;
ans.srcMemoryType = CU_MEMORYTYPE_HOST;
ans.srcHost = src as _;
ans.dstMemoryType = cuda::CUmemorytype_enum::CU_MEMORYTYPE_HOST;
ans.dstMemoryType = CU_MEMORYTYPE_HOST;
ans.dstHost = dst as _;
}
MemcpyType::H2D => {
ans.srcMemoryType = cuda::CUmemorytype_enum::CU_MEMORYTYPE_HOST;
ans.srcMemoryType = CU_MEMORYTYPE_HOST;
ans.srcHost = src as _;
ans.dstDevice = dst as _;
}
MemcpyType::D2H => {
ans.srcDevice = src as _;
ans.dstMemoryType = cuda::CUmemorytype_enum::CU_MEMORYTYPE_HOST;
ans.dstMemoryType = CU_MEMORYTYPE_HOST;
ans.dstHost = dst as _;
}
};
Expand Down
10 changes: 8 additions & 2 deletions runtimes/nvidia/src/driver/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,11 @@ mod graph;
mod memory;
mod stream;

/// Initializes the CUDA driver by invoking `cuInit(0)` through `bindings::invoke!`.
///
/// NOTE(review): how `invoke!` reacts to a failing status code is not visible here — confirm.
#[inline(always)]
pub(crate) fn init() {
bindings::invoke!(cuInit(0));
}

/// Lets a wrapper type hand out a value of type `T` derived from `&self`.
trait AsRaw<T> {
/// Returns the `T` associated with `self`.
///
/// NOTE(review): presumably this is the underlying raw driver handle; the obligations
/// that make the call `unsafe` are not stated here — document them under `# Safety`.
unsafe fn as_raw(&self) -> T;
}
Expand All @@ -29,6 +34,7 @@ trait WithCtx {
unsafe fn ctx(&self) -> bindings::CUcontext;
}

pub(crate) use graph::Graph;
pub(crate) use context::{Context, ContextGuard};
pub(crate) use device::devices;
pub(crate) use graph::{ExecutableGraph, Graph};
pub(crate) use memory::Blob;
pub(crate) use stream::Stream;
58 changes: 53 additions & 5 deletions runtimes/nvidia/src/graph.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,63 @@
use crate::driver;
use crate::driver::{self, ContextGuard};
use graph_topo::GraphTopo;
use stack_calculator::{flat, unidir, RealtimeCalculator};
use std::sync::Arc;

pub struct Graph {
graph: driver::Graph,
ctx: Arc<driver::Context>,
graph: driver::ExecutableGraph,
topology: GraphTopo,
edges: Vec<MemOffset>,
static_mem: driver::Blob,
stack: driver::Blob,
}

enum MemOffset {
Static(usize),
Stack(usize),
/// Newtype over `usize` holding the memory offset recorded for one graph edge
/// (the runtime `Graph` stores one per edge in its `edges` vector).
///
/// NOTE(review): which blob the offset is relative to is not visible here — confirm.
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
#[repr(transparent)]
struct MemOffset(usize);

impl MemOffset {
/// Sentinel offset (`usize::MAX`); `runtime_graph` pre-fills the edge table with it.
const INVALID: MemOffset = MemOffset(usize::MAX);
}

impl Graph {
/// Builds a runtime graph for `src` on the device at index `dev`.
///
/// Takes `driver::devices()[dev]`, obtains its context, and calls
/// `runtime_graph(src)` inside that context's `apply` closure.
///
/// NOTE(review): `driver::devices()[dev]` is an unchecked index; presumably it panics
/// when `dev` is out of range — confirm and document under `# Panics`.
#[inline]
pub fn new(src: &computation::Graph, dev: usize) -> Self {
driver::devices()[dev]
.context()
.apply(|ctx| ctx.runtime_graph(src))
}

/// Launches the stored executable graph on the stream returned by `ctx.stream()`,
/// with the graph's own context applied.
#[inline]
pub fn run(&self) {
self.ctx.apply(|ctx| {
let stream = ctx.stream();
// NOTE(review): the invariant that makes `launch_on` sound is not visible here;
// replace this note with a `// SAFETY:` justification once confirmed.
unsafe { self.graph.launch_on(&stream) }
})
}
}

impl ContextGuard<'_> {
/// Builds a runtime [`Graph`] for the computation graph `src` within this context.
///
/// As written, it creates an empty driver graph and instantiates it, clones the source
/// topology, pre-fills one `MemOffset::INVALID` per source edge, and allocates the
/// static and stack blobs sized by the two calculators' `peak()` values.
///
/// NOTE(review): `alloc` is never called on either calculator before `peak()` is read,
/// so both sizes come from default-constructed calculators — this looks like work in
/// progress; the `mut` bindings below are currently unused as well.
pub fn runtime_graph(&self, src: &computation::Graph) -> Graph {
// Work on the inner value held in tuple field 0 of `computation::Graph`.
let src = &src.0;

let mut flat = flat::RealtimeCalculator::default();
let mut unidir = unidir::RealtimeCalculator::default();

// One slot per source edge, all initially marked invalid.
let mut edges = vec![MemOffset::INVALID; src.edges.len()];

// NOTE(review): this runs while a context guard (`self`) already exists; confirm
// whether driver initialization is still required at this point.
driver::init();
let graph = driver::Graph::new();

let mut static_mem = self.malloc(flat.peak());

Graph {
ctx: self.clone_ctx(),
graph: graph.instantiate(self),
topology: src.topology.clone(),
edges,
static_mem,
stack: self.malloc(unidir.peak()),
}
}
}
4 changes: 2 additions & 2 deletions runtimes/nvidia/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
#![cfg(detected_cuda)]

use graph_topo::GraphTopo;

mod driver;
mod graph;

pub use graph::Graph;
52 changes: 36 additions & 16 deletions stack-calculator/src/flat.rs
Original file line number Diff line number Diff line change
@@ -1,37 +1,57 @@
use crate::{align, Calculator};
use std::{alloc::Layout, collections::HashSet};
//! 平铺对象的栈计算器,包括一个实时的版本和一个非实时的版本。
use crate::RealtimeCalculator as _;
use std::{alloc::Layout, collections::HashSet, ops::Range};

/// 平铺对象的栈计算器。
pub struct FlatCalculator;
pub struct Calculator;

impl Calculator for FlatCalculator {
impl crate::Calculator for Calculator {
fn calculate(
self,
topology: &graph_topo::GraphTopo,
manager: &mut impl crate::Manager,
) -> usize {
let global_outputs = HashSet::<usize>::from_iter(topology.global_outputs());

let mut ans = 0;
let mut rt_cal = RealtimeCalculator::default();
for (i, _inputs, outputs) in topology {
for i in outputs {
if !global_outputs.contains(&i) {
manager.set_tensor_offset(i, put_obj(&mut ans, manager.tensor_layout(i)));
manager.set_tensor_offset(i, rt_cal.alloc(manager.tensor_layout(i)).start);
}
}
manager.set_workspace_offset(i, put_obj(&mut ans, manager.workspace_layout(i)));
manager.set_workspace_offset(i, rt_cal.alloc(manager.workspace_layout(i)).start);
}
ans
rt_cal.peak()
}
}

#[inline(always)]
fn put_obj(size: &mut usize, obj: Layout) -> usize {
if obj.size() == 0 {
*size
} else {
let offset = align(*size, obj.align());
*size = offset + obj.size();
offset
/// Real-time stack calculator that lays objects out flat, one after another.
#[derive(Default, Debug)]
pub struct RealtimeCalculator {
// End offset of the most recent non-empty allocation; this is also what `peak` reports.
pos: usize,
}

impl crate::RealtimeCalculator for RealtimeCalculator {
    /// Places `obj` after everything allocated so far and returns the byte range it occupies.
    ///
    /// Zero-sized objects consume nothing and always receive the empty range `0..0`.
    fn alloc(&mut self, obj: Layout) -> Range<usize> {
        match obj.size() {
            0 => 0..0,
            size => {
                // The start is the current position adjusted by `crate::align` to `obj`'s alignment.
                let begin = crate::align(self.pos, obj.align());
                let end = begin + size;
                self.pos = end;
                begin..end
            }
        }
    }

    /// Does nothing: `alloc` only ever advances the position, so freed space is not reused.
    #[inline]
    fn free(&mut self, _range: Range<usize>) {}

    /// Returns the end offset reached so far, which is the total space required.
    #[inline]
    fn peak(&self) -> usize {
        self.pos
    }
}
21 changes: 15 additions & 6 deletions stack-calculator/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,21 +10,30 @@
#![deny(warnings, missing_docs)]

mod flat;
mod unidir;
pub mod flat;
pub mod unidir;

use graph_topo::GraphTopo;
use std::alloc::Layout;

pub use flat::FlatCalculator;
pub use unidir::UnidirCalculator;
use std::{alloc::Layout, ops::Range};

/// Stack calculator.
pub trait Calculator {
/// Interacts with `manager` to compute, from the given graph topology, the offset of
/// every object on the stack, and returns the required stack capacity.
fn calculate(self, topology: &GraphTopo, manager: &mut impl Manager) -> usize;
}

/// Real-time stack calculator.
pub trait RealtimeCalculator {
/// Allocates space that satisfies the requirements of `obj`.
fn alloc(&mut self, obj: Layout) -> Range<usize>;

/// Frees the space within `range`.
fn free(&mut self, range: Range<usize>);

/// Returns the historical peak of the stack space.
fn peak(&self) -> usize;
}

/// 栈计算管理器。
pub trait Manager {
/// 获取张量的数量。
Expand Down
21 changes: 13 additions & 8 deletions stack-calculator/src/unidir.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
use crate::{align, Calculator};
//! 单向扩容的栈计算器,包括一个实时的版本和一个非实时的版本。
use crate::RealtimeCalculator as _;
use std::{
alloc::Layout,
cmp::Ordering,
Expand All @@ -7,9 +9,9 @@ use std::{
};

/// 单向扩容的栈计算器。
pub struct UnidirCalculator;
pub struct Calculator;

impl Calculator for UnidirCalculator {
impl crate::Calculator for Calculator {
fn calculate(
self,
topology: &graph_topo::GraphTopo,
Expand Down Expand Up @@ -59,8 +61,9 @@ impl Calculator for UnidirCalculator {
}
}

/// 实时的单向扩容栈计算器。
#[derive(Default, Debug)]
struct RealtimeCalculator {
pub struct RealtimeCalculator {
used: usize,
peak: usize,

Expand All @@ -69,7 +72,7 @@ struct RealtimeCalculator {
free_tail_head: HashMap<usize, usize>,
}

impl RealtimeCalculator {
impl crate::RealtimeCalculator for RealtimeCalculator {
fn alloc(&mut self, obj: Layout) -> Range<usize> {
if obj.size() == 0 {
return 0..0;
Expand All @@ -79,7 +82,7 @@ impl RealtimeCalculator {
if let Some(&HeadTail(Range { start, end })) = self
.free_headtails
.range(HeadTail(0..obj.size())..)
.find(|&HeadTail(r)| r.end - align(r.start, obj.align()) >= obj.size())
.find(|&HeadTail(r)| r.end - crate::align(r.start, obj.align()) >= obj.size())
{
self.free_headtails.remove(&HeadTail(start..end));
self.free_head_tail.remove(&start);
Expand Down Expand Up @@ -126,10 +129,12 @@ impl RealtimeCalculator {
}

#[inline]
const fn peak(&self) -> usize {
fn peak(&self) -> usize {
self.peak
}
}

impl RealtimeCalculator {
#[inline]
fn insert(&mut self, start: usize, end: usize) {
if end > start {
Expand All @@ -141,7 +146,7 @@ impl RealtimeCalculator {

#[inline(always)]
const fn head_tail(start: usize, obj: Layout) -> (usize, usize) {
let head = align(start, obj.align());
let head = crate::align(start, obj.align());
(head, head + obj.size())
}
}
Expand Down

0 comments on commit a8eb0f0

Please sign in to comment.