Skip to content

Commit

Permalink
Merge pull request #17 from charles-r-earp/kernel-builder-type-state
Browse files Browse the repository at this point in the history
Kernel builder type state
  • Loading branch information
charles-r-earp authored Feb 19, 2024
2 parents e086048 + ca09e94 commit 0d431b8
Show file tree
Hide file tree
Showing 3 changed files with 187 additions and 44 deletions.
98 changes: 80 additions & 18 deletions krnl-macros/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1441,22 +1441,58 @@ fn kernel_impl(item_tokens: TokenStream2) -> Result<TokenStream2> {
}
};
let host_array_length_checks = kernel_meta.host_array_length_checks();
let kernel_builder_specialize_fn = if !kernel_desc.spec_descs.is_empty() {
let specialize = !kernel_desc.spec_descs.is_empty();
let specialized = [format_ident!("S")];
let specialized = if specialize {
specialized.as_ref()
} else {
&[]
};
let kernel_builder_phantom_data = if specialize {
quote! { S }
} else {
quote! { () }
};
let kernel_builder_build_generics = if specialize {
quote! {
<Specialized<true>>
}
} else {
TokenStream2::new()
};
let kernel_builder_specialize_fn = if specialize {
let spec_def_args = kernel_meta.spec_def_args();
let spec_args = kernel_meta.spec_args();
quote! {
/// Specializes the kernel.
#[allow(clippy::too_many_arguments)]
pub fn specialize(mut self, #spec_def_args) -> Self {
let inner = self.inner.specialize(&[#(#spec_args.into()),*]);
Self {
inner,
pub fn specialize(mut self, #spec_def_args) -> KernelBuilder<Specialized<true>> {
KernelBuilder {
inner: self.inner.specialize(&[#(#spec_args.into()),*]),
_m: PhantomData,
}
}
}
} else {
TokenStream2::new()
};
let needs_groups = !kernel_meta.itemwise;
let with_groups = [format_ident!("G")];
let with_groups = if needs_groups {
with_groups.as_ref()
} else {
&[]
};
let kernel_phantom_data = if needs_groups {
quote! { G }
} else {
quote! { () }
};
let kernel_dispatch_generics = if needs_groups {
quote! { <WithGroups<true>> }
} else {
TokenStream2::new()
};
let input_docs = {
let input_tokens_string = prettyplease::unparse(&syn::parse2(quote! {
#[kernel]
Expand Down Expand Up @@ -1494,10 +1530,21 @@ fn kernel_impl(item_tokens: TokenStream2) -> Result<TokenStream2> {
buffer::{Slice, SliceMut},
device::Device,
scalar::ScalarType,
kernel::__private::{Kernel as KernelBase, KernelBuilder as KernelBuilderBase, KernelDesc, SliceDesc, SpecDesc, PushDesc, Safety, validate_kernel},
kernel::__private::{
Kernel as KernelBase,
KernelBuilder as KernelBuilderBase,
Specialized,
WithGroups,
KernelDesc,
SliceDesc,
SpecDesc,
PushDesc,
Safety,
validate_kernel
},
anyhow::format_err,
};
use ::std::sync::OnceLock;
use ::std::{sync::OnceLock, marker::PhantomData};
#[cfg(not(krnlc))]
#[doc(hidden)]
use __krnl::macros::__krnl_cache;
Expand All @@ -1509,9 +1556,11 @@ fn kernel_impl(item_tokens: TokenStream2) -> Result<TokenStream2> {
/// Builder for creating a [`Kernel`].
///
/// See [`builder()`](builder).
pub struct KernelBuilder {
pub struct KernelBuilder #(<#specialized = Specialized<false>>)* {
#[doc(hidden)]
inner: KernelBuilderBase,
#[doc(hidden)]
_m: PhantomData<#kernel_builder_phantom_data>,
}

/// Creates a builder.
Expand All @@ -1533,64 +1582,77 @@ fn kernel_impl(item_tokens: TokenStream2) -> Result<TokenStream2> {
match builder {
Ok(inner) => Ok(KernelBuilder {
inner: inner.clone(),
_m: PhantomData,
}),
Err(err) => Err(format_err!("{err}")),
}
}

impl KernelBuilder {
impl #(<#specialized>)* KernelBuilder #(<#specialized>)* {
/// Threads per group.
///
/// Defaults to [`DeviceInfo::default_threads()`](DeviceInfo::default_threads).
pub fn with_threads(self, threads: u32) -> Self {
Self {
inner: self.inner.with_threads(threads),
_m: PhantomData,
}
}
#kernel_builder_specialize_fn
}

impl KernelBuilder #kernel_builder_build_generics {
/// Builds the kernel for `device`.
///
/// The kernel is cached, so subsequent calls to `.build()` with identical
/// builders (i.e. the same threads and spec constants) may avoid recompiling.
///
/// # Errors
/// - `device` doesn't have required features.
/// - The kernel requires [specialization](kernel#specialization), but `.specialize(..)` was not called.
/// - The kernel is not supported on `device`.
/// - [`DeviceLost`].
pub fn build(&self, device: Device) -> Result<Kernel> {
let inner = self.inner.build(device)?;
Ok(Kernel { inner })
Ok(Kernel {
inner: self.inner.build(device)?,
_m: PhantomData,
})
}
}

/// Kernel.
pub struct Kernel {
pub struct Kernel #(<#with_groups = WithGroups<false>>)* {
#[doc(hidden)]
inner: KernelBase,
#[doc(hidden)]
_m: PhantomData<#kernel_phantom_data>,
}

impl Kernel {
impl #(<#with_groups>)* Kernel #(<#with_groups>)* {
/// Threads per group.
pub fn threads(&self) -> u32 {
self.inner.threads()
}
/// Global threads to dispatch.
///
/// Implicitly declares groups by rounding up to the next multiple of threads.
pub fn with_global_threads(self, global_threads: u32) -> Self {
Self {
pub fn with_global_threads(self, global_threads: u32) -> Kernel #kernel_dispatch_generics {
Kernel {
inner: self.inner.with_global_threads(global_threads),
_m: PhantomData,
}
}
/// Groups to dispatch.
///
/// For item kernels, if not provided, is inferred based on item arguments.
pub fn with_groups(self, groups: u32) -> Self {
Self {
pub fn with_groups(self, groups: u32) -> Kernel #kernel_dispatch_generics {
Kernel {
inner: self.inner.with_groups(groups),
_m: PhantomData,
}
}
}

impl Kernel #kernel_dispatch_generics {
/// Dispatches the kernel.
///
/// - Waits for immutable access to slice arguments.
Expand Down
48 changes: 22 additions & 26 deletions src/kernel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -377,19 +377,22 @@ pub mod saxpy {
}
/// Kernel.
pub struct Kernel { /* .. */ }
pub struct Kernel<G = WithGroups<false>> { /* .. */ }
impl Kernel {
impl<G> Kernel<G> {
/// Threads per group.
pub fn threads(&self) -> u32;
/// Global threads to dispatch.
///
/// Implicitly declares groups by rounding up to the next multiple of threads.
pub fn with_global_threads(self, global_threads: u32) -> Self;
pub fn with_global_threads(self, global_threads: u32) -> Kernel<WithGroups<true>>;
/// Groups to dispatch.
///
/// For item kernels, if not provided, is inferred based on item arguments.
pub fn with_groups(self, groups: u32) -> Self;
pub fn with_groups(self, groups: u32) -> Kernel<WithGroups<true>>;
}
impl Kernel<WithGroups<true>> {
/// Dispatches the kernel.
///
/// - Waits for immutable access to slice arguments.
Expand Down Expand Up @@ -946,6 +949,8 @@ pub mod __private {
Ok(output)
}

/// Type-level marker recording whether a generated `KernelBuilder` has been specialized.
///
/// The enum has no variants, so it can never be instantiated; it is only used as a
/// generic argument. For kernels with spec constants, the `#[kernel]` macro makes
/// `KernelBuilder` default to `Specialized<false>`, has `.specialize(..)` return
/// `KernelBuilder<Specialized<true>>`, and implements `.build(..)` only for
/// `Specialized<true>`.
pub enum Specialized<const S: bool> {}

#[cfg_attr(not(feature = "device"), allow(dead_code))]
#[derive(Clone)]
pub struct KernelBuilder {
Expand Down Expand Up @@ -990,7 +995,8 @@ pub mod __private {
}
}
pub fn specialize(self, spec_consts: &[ScalarElem]) -> Self {
assert_eq!(spec_consts.len(), self.desc.spec_descs.len());
debug_assert_eq!(spec_consts.len(), self.desc.spec_descs.len());
#[cfg(debug_assertions)]
for (spec_const, spec_desc) in
spec_consts.iter().copied().zip(self.desc.spec_descs.iter())
{
Expand Down Expand Up @@ -1021,26 +1027,13 @@ pub mod __private {
if threads > max_threads {
bail!("Kernel {name} threads {threads} is greater than max_threads {max_threads}!");
}
let spec_bytes = {
if !self.desc.spec_descs.is_empty() && self.spec_consts.is_empty() {
bail!("Kernel `{name}` must be specialized!");
}
debug_assert_eq!(self.spec_consts.len(), desc.spec_descs.len());
#[cfg(debug_assertions)]
{
for (spec_const, spec_desc) in
self.spec_consts.iter().zip(desc.spec_descs.iter())
{
assert_eq!(spec_const.scalar_type(), spec_desc.scalar_type);
}
}
self.spec_consts
.iter()
.flat_map(|x| x.as_bytes())
.copied()
.chain(threads.to_ne_bytes())
.collect()
};
let spec_bytes = self
.spec_consts
.iter()
.flat_map(|x| x.as_bytes())
.copied()
.chain(threads.to_ne_bytes())
.collect();
let key = KernelKey {
id: self.id,
spec_bytes,
Expand All @@ -1060,6 +1053,8 @@ pub mod __private {
}
}

/// Type-level marker recording whether a generated `Kernel` has had its groups provided.
///
/// The enum has no variants, so it can never be instantiated; it is only used as a
/// generic argument. For kernels that are not itemwise, the `#[kernel]` macro makes
/// `Kernel` default to `WithGroups<false>`, has `.with_global_threads(..)` and
/// `.with_groups(..)` return `Kernel<WithGroups<true>>`, and implements the dispatch
/// method only for `WithGroups<true>`.
pub enum WithGroups<const G: bool> {}

#[derive(Clone)]
pub struct Kernel {
#[cfg(feature = "device")]
Expand Down Expand Up @@ -1167,7 +1162,8 @@ pub mod __private {
let groups = items / threads + u32::from(items % threads != 0);
groups.min(max_groups)
} else {
bail!("Kernel `{kernel_name}` global_threads or groups not provided!");
#[cfg(debug_assertions)]
unreachable!("groups not provided!");
};
let debug_printf_panic = if info.debug_printf() {
Some(Arc::new(AtomicBool::default()))
Expand Down
85 changes: 85 additions & 0 deletions tests/krnlc-tests/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,91 @@ use dry::macro_for;
use krnl::macros::module;
use paste::paste;

/**
Doctests for the kernel builder's specialization type state.

This enum is never used at runtime; it only carries the doctests below.

Building a kernel that has a spec constant after calling `.specialize(..)` must compile:
```no_run
use krnl::macros::module;
#[module]
#[krnl(no_build)]
mod kernels {
use krnl::{macros::kernel, device::Device, anyhow::Result};
#[kernel]
fn specialization<const X: i32>() {}
fn test_specialization(device: Device) -> Result<()> {
specialization::builder()?.specialize(1).build(device)?;
Ok(())
}
}
```
Calling `.build(..)` on the same kernel without `.specialize(..)` must be rejected at compile time:
```compile_fail
use krnl::macros::module;
#[module]
#[krnl(no_build)]
mod kernels {
use krnl::{macros::kernel, device::Device, anyhow::Result};
#[kernel]
fn specialization<const X: i32>() {}
fn test_specialization(device: Device) -> Result<()> {
specialization::builder()?.build(device)?;
Ok(())
}
}
```
*/
#[allow(dead_code)]
enum Specialization {}

/**
Doctests for the kernel's groups type state.

This enum is never used at runtime; it only carries the doctests below.

Dispatching a kernel without item arguments after `.with_groups(..)`, and dispatching
an item kernel without declaring groups, must both compile:
```no_run
use krnl::macros::module;
#[module]
#[krnl(no_build)]
mod kernels {
use krnl::{macros::kernel, device::Device, buffer::SliceMut, anyhow::Result};
#[kernel]
fn with_groups() {}
fn test_with_groups(device: Device) -> Result<()> {
with_groups::builder()?.build(device)?.with_groups(1).dispatch()
}
#[kernel]
fn with_groups_item(
#[item] y: &mut u32,
) {}
fn test_with_groups_item(y: SliceMut<u32>) -> Result<()> {
with_groups_item::builder()?.build(y.device())?.dispatch(y)
}
}
```
Dispatching a kernel without item arguments before groups are declared must be rejected
at compile time:
```compile_fail
use krnl::macros::module;
#[module]
#[krnl(no_build)]
mod kernels {
use krnl::{macros::kernel, device::Device, anyhow::Result};
#[kernel]
fn with_groups() {}
fn test_with_groups(device: Device) -> Result<()> {
with_groups::builder()?.build(device)?.dispatch()
}
}
```
*/
#[allow(dead_code)]
enum WithGroups {}

#[module]
pub mod kernels {
use dry::macro_for;
Expand Down

0 comments on commit 0d431b8

Please sign in to comment.