Skip to content

Allow instructions to explicitly specify StorageClasses #236

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 4 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion crates/rustc_codegen_spirv/src/abi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

use crate::attr::{AggregatedSpirvAttributes, IntrinsicType};
use crate::codegen_cx::CodegenCx;
use crate::spirv_type::SpirvType;
use crate::spirv_type::{SpirvType, StorageClassKind};
use itertools::Itertools;
use rspirv::spirv::{Dim, ImageFormat, StorageClass, Word};
use rustc_data_structures::fx::FxHashMap;
Expand Down Expand Up @@ -339,6 +339,7 @@ impl<'tcx> RecursivePointeeCache<'tcx> {
PointeeDefState::Defining => {
let id = SpirvType::Pointer {
pointee: pointee_spv,
storage_class: StorageClassKind::Inferred, // TODO(jwollen): Do we need to cache by storage class?
}
.def(span, cx);
entry.insert(PointeeDefState::Defined(id));
Expand All @@ -350,6 +351,7 @@ impl<'tcx> RecursivePointeeCache<'tcx> {
entry.insert(PointeeDefState::Defined(id));
SpirvType::Pointer {
pointee: pointee_spv,
storage_class: StorageClassKind::Inferred, // TODO(jwollen): Do we need to cache by storage class?
}
.def_with_id(cx, span, id)
}
Expand Down
46 changes: 33 additions & 13 deletions crates/rustc_codegen_spirv/src/builder/builder_methods.rs
Original file line number Diff line number Diff line change
Expand Up @@ -407,14 +407,15 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
size: Size,
) -> Option<(SpirvValue, <Self as BackendTypes>::Type)> {
let ptr = ptr.strip_ptrcasts();
let mut leaf_ty = match self.lookup_type(ptr.ty) {
SpirvType::Pointer { pointee } => pointee,
let pointee_ty = match self.lookup_type(ptr.ty) {
SpirvType::Pointer { pointee, .. } => pointee,
other => self.fatal(format!("non-pointer type: {other:?}")),
};

// FIXME(eddyb) this isn't efficient, `recover_access_chain_from_offset`
// could instead be doing all the extra digging itself.
let mut indices = SmallVec::<[_; 8]>::new();
let mut leaf_ty = pointee_ty;
while let Some((inner_indices, inner_ty)) = self.recover_access_chain_from_offset(
leaf_ty,
Size::ZERO,
Expand All @@ -429,7 +430,8 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
.then(|| self.type_ptr_to(leaf_ty))?;

let leaf_ptr = if indices.is_empty() {
assert_ty_eq!(self, ptr.ty, leaf_ptr_ty);
// Compare pointee types instead of pointer types as storage class might be different.
assert_ty_eq!(self, pointee_ty, leaf_ty);
ptr
} else {
let indices = indices
Expand Down Expand Up @@ -586,7 +588,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
let ptr = ptr.strip_ptrcasts();
let ptr_id = ptr.def(self);
let original_pointee_ty = match self.lookup_type(ptr.ty) {
SpirvType::Pointer { pointee } => pointee,
SpirvType::Pointer { pointee, .. } => pointee,
other => self.fatal(format!("gep called on non-pointer type: {other:?}")),
};

Expand Down Expand Up @@ -1926,6 +1928,25 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
return ptr;
}

// No cast is needed if only the storage class mismatches.
let ptr_pointee = match self.lookup_type(ptr.ty) {
SpirvType::Pointer { pointee, .. } => pointee,
other => self.fatal(format!(
"pointercast called on non-pointer source type: {other:?}"
)),
};
let dest_pointee = match self.lookup_type(dest_ty) {
SpirvType::Pointer { pointee, .. } => pointee,
other => self.fatal(format!(
"pointercast called on non-pointer dest type: {other:?}"
)),
};

// FIXME(jwollen) Do we need to choose `dest_ty` if it has a fixed storage class and `ptr` has none?
if ptr_pointee == dest_pointee {
return ptr;
}

// Strip a previous `pointercast`, to reveal the original pointer type.
let ptr = ptr.strip_ptrcasts();

Expand All @@ -1934,17 +1955,16 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
}

let ptr_pointee = match self.lookup_type(ptr.ty) {
SpirvType::Pointer { pointee } => pointee,
SpirvType::Pointer { pointee, .. } => pointee,
other => self.fatal(format!(
"pointercast called on non-pointer source type: {other:?}"
)),
};
let dest_pointee = match self.lookup_type(dest_ty) {
SpirvType::Pointer { pointee } => pointee,
other => self.fatal(format!(
"pointercast called on non-pointer dest type: {other:?}"
)),
};

if ptr_pointee == dest_pointee {
return ptr;
}

let dest_pointee_size = self.lookup_type(dest_pointee).sizeof(self);

if let Some((indices, _)) = self.recover_access_chain_from_offset(
Expand Down Expand Up @@ -2324,7 +2344,7 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
.and_then(|size| Some(Size::from_bytes(u64::try_from(size).ok()?)));

let elem_ty = match self.lookup_type(ptr.ty) {
SpirvType::Pointer { pointee } => pointee,
SpirvType::Pointer { pointee, .. } => pointee,
_ => self.fatal(format!(
"memset called on non-pointer type: {}",
self.debug_type(ptr.ty)
Expand Down Expand Up @@ -2696,7 +2716,7 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
(callee.def(self), return_type, arguments)
}

SpirvType::Pointer { pointee } => match self.lookup_type(pointee) {
SpirvType::Pointer { pointee, .. } => match self.lookup_type(pointee) {
SpirvType::Function {
return_type,
arguments,
Expand Down
22 changes: 19 additions & 3 deletions crates/rustc_codegen_spirv/src/builder/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ use crate::maybe_pqp_cg_ssa as rustc_codegen_ssa;
use crate::abi::ConvSpirvType;
use crate::builder_spirv::{BuilderCursor, SpirvValue, SpirvValueExt};
use crate::codegen_cx::CodegenCx;
use crate::spirv_type::SpirvType;
use rspirv::spirv::Word;
use crate::spirv_type::{SpirvType, StorageClassKind};
use rspirv::spirv::{StorageClass, Word};
use rustc_codegen_ssa::mir::operand::{OperandRef, OperandValue};
use rustc_codegen_ssa::mir::place::PlaceRef;
use rustc_codegen_ssa::traits::{
Expand Down Expand Up @@ -104,7 +104,23 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {

// HACK(eddyb) like the `CodegenCx` method but with `self.span()` awareness.
pub fn type_ptr_to(&self, ty: Word) -> Word {
SpirvType::Pointer { pointee: ty }.def(self.span(), self)
SpirvType::Pointer {
pointee: ty,
storage_class: StorageClassKind::Inferred,
}
.def(self.span(), self)
}

pub fn type_ptr_with_storage_class_to(
&self,
ty: Word,
storage_class: StorageClassKind,
) -> Word {
SpirvType::Pointer {
pointee: ty,
storage_class,
}
.def(self.span(), self)
}

// TODO: Definitely add tests to make sure this impl is right.
Expand Down
24 changes: 10 additions & 14 deletions crates/rustc_codegen_spirv/src/builder/spirv_asm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use crate::maybe_pqp_cg_ssa as rustc_codegen_ssa;
use super::Builder;
use crate::builder_spirv::{BuilderCursor, SpirvValue};
use crate::codegen_cx::CodegenCx;
use crate::spirv_type::SpirvType;
use crate::spirv_type::{SpirvType, StorageClassKind};
use rspirv::dr;
use rspirv::grammar::{LogicalOperand, OperandKind, OperandQuantifier, reflect};
use rspirv::spirv::{
Expand Down Expand Up @@ -307,19 +307,14 @@ impl<'cx, 'tcx> Builder<'cx, 'tcx> {
}
.def(self.span(), self),
Op::TypePointer => {
let storage_class = inst.operands[0].unwrap_storage_class();
if storage_class != StorageClass::Generic {
self.struct_err("TypePointer in asm! requires `Generic` storage class")
.with_note(format!(
"`{storage_class:?}` storage class was specified"
))
.with_help(format!(
"the storage class will be inferred automatically (e.g. to `{storage_class:?}`)"
))
.emit();
}
// The storage class can be specified explicitly or inferred later by using StorageClass::Generic.
let storage_class = match inst.operands[0].unwrap_storage_class() {
StorageClass::Generic => StorageClassKind::Inferred,
storage_class => StorageClassKind::Explicit(storage_class),
};
SpirvType::Pointer {
pointee: inst.operands[1].unwrap_id_ref(),
storage_class,
}
.def(self.span(), self)
}
Expand Down Expand Up @@ -678,6 +673,7 @@ impl<'cx, 'tcx> Builder<'cx, 'tcx> {

TyPat::Pointer(_, pat) => SpirvType::Pointer {
pointee: subst_ty_pat(cx, pat, ty_vars, leftover_operands)?,
storage_class: StorageClassKind::Inferred,
}
.def(DUMMY_SP, cx),

Expand Down Expand Up @@ -931,7 +927,7 @@ impl<'cx, 'tcx> Builder<'cx, 'tcx> {
Some(match kind {
TypeofKind::Plain => ty,
TypeofKind::Dereference => match self.lookup_type(ty) {
SpirvType::Pointer { pointee } => pointee,
SpirvType::Pointer { pointee, .. } => pointee,
other => {
self.tcx.dcx().span_err(
span,
Expand All @@ -953,7 +949,7 @@ impl<'cx, 'tcx> Builder<'cx, 'tcx> {
self.check_reg(span, reg);
if let Some(place) = place {
match self.lookup_type(place.val.llval.ty) {
SpirvType::Pointer { pointee } => Some(pointee),
SpirvType::Pointer { pointee, .. } => Some(pointee),
other => {
self.tcx.dcx().span_err(
span,
Expand Down
2 changes: 1 addition & 1 deletion crates/rustc_codegen_spirv/src/builder_spirv.rs
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ impl SpirvValue {
match entry.val {
SpirvConst::PtrTo { pointee } => {
let ty = match cx.lookup_type(self.ty) {
SpirvType::Pointer { pointee } => pointee,
SpirvType::Pointer { pointee, .. } => pointee,
ty => bug!("load called on value that wasn't a pointer: {:?}", ty),
};
// FIXME(eddyb) deduplicate this `if`-`else` and its other copies.
Expand Down
6 changes: 3 additions & 3 deletions crates/rustc_codegen_spirv/src/codegen_cx/constant.rs
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ impl<'tcx> ConstCodegenMethods<'tcx> for CodegenCx<'tcx> {
let (base_addr, _base_addr_space) = match self.tcx.global_alloc(alloc_id) {
GlobalAlloc::Memory(alloc) => {
let pointee = match self.lookup_type(ty) {
SpirvType::Pointer { pointee } => pointee,
SpirvType::Pointer { pointee, .. } => pointee,
other => self.tcx.dcx().fatal(format!(
"GlobalAlloc::Memory type not implemented: {}",
other.debug(ty, self)
Expand All @@ -259,7 +259,7 @@ impl<'tcx> ConstCodegenMethods<'tcx> for CodegenCx<'tcx> {
.global_alloc(self.tcx.vtable_allocation((vty, dyn_ty.principal())))
.unwrap_memory();
let pointee = match self.lookup_type(ty) {
SpirvType::Pointer { pointee } => pointee,
SpirvType::Pointer { pointee, .. } => pointee,
other => self.tcx.dcx().fatal(format!(
"GlobalAlloc::VTable type not implemented: {}",
other.debug(ty, self)
Expand Down Expand Up @@ -328,7 +328,7 @@ impl<'tcx> CodegenCx<'tcx> {
if let Some(SpirvConst::ConstDataFromAlloc(alloc)) =
self.builder.lookup_const_by_id(pointee)
{
if let SpirvType::Pointer { pointee } = self.lookup_type(ty) {
if let SpirvType::Pointer { pointee, .. } = self.lookup_type(ty) {
let mut offset = Size::ZERO;
let init = self.read_from_const_alloc(alloc, &mut offset, pointee);
return self.static_addr_of(init, alloc.inner().align, None);
Expand Down
11 changes: 8 additions & 3 deletions crates/rustc_codegen_spirv/src/codegen_cx/declare.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use crate::abi::ConvSpirvType;
use crate::attr::AggregatedSpirvAttributes;
use crate::builder_spirv::{SpirvConst, SpirvValue, SpirvValueExt};
use crate::custom_decorations::{CustomDecoration, SrcLocDecoration};
use crate::spirv_type::SpirvType;
use crate::spirv_type::{SpirvType, StorageClassKind};
use itertools::Itertools;
use rspirv::spirv::{FunctionControl, LinkageType, StorageClass, Word};
use rustc_attr::InlineAttr;
Expand Down Expand Up @@ -267,7 +267,12 @@ impl<'tcx> CodegenCx<'tcx> {
}

fn declare_global(&self, span: Span, ty: Word) -> SpirvValue {
let ptr_ty = SpirvType::Pointer { pointee: ty }.def(span, self);
// Could be explicitly StorageClass::Private but is inferred anyway.
let ptr_ty = SpirvType::Pointer {
pointee: ty,
storage_class: StorageClassKind::Inferred,
}
.def(span, self);
// FIXME(eddyb) figure out what the correct storage class is.
let result = self
.emit_global()
Expand Down Expand Up @@ -353,7 +358,7 @@ impl<'tcx> StaticCodegenMethods for CodegenCx<'tcx> {
Err(_) => return,
};
let value_ty = match self.lookup_type(g.ty) {
SpirvType::Pointer { pointee } => pointee,
SpirvType::Pointer { pointee, .. } => pointee,
other => self.tcx.dcx().fatal(format!(
"global had non-pointer type {}",
other.debug(g.ty, self)
Expand Down
2 changes: 1 addition & 1 deletion crates/rustc_codegen_spirv/src/codegen_cx/entry.rs
Original file line number Diff line number Diff line change
Expand Up @@ -967,7 +967,7 @@ impl<'tcx> CodegenCx<'tcx> {
| SpirvType::Matrix { element, .. }
| SpirvType::Array { element, .. }
| SpirvType::RuntimeArray { element }
| SpirvType::Pointer { pointee: element }
| SpirvType::Pointer { pointee: element, .. }
| SpirvType::InterfaceBlock {
inner_type: element,
} => recurse(cx, element, has_bool, must_be_flat),
Expand Down
15 changes: 12 additions & 3 deletions crates/rustc_codegen_spirv/src/codegen_cx/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ mod type_;
use crate::builder::{ExtInst, InstructionTable};
use crate::builder_spirv::{BuilderCursor, BuilderSpirv, SpirvConst, SpirvValue, SpirvValueKind};
use crate::custom_decorations::{CustomDecoration, SrcLocDecoration, ZombieDecoration};
use crate::spirv_type::{SpirvType, SpirvTypePrinter, TypeCache};
use crate::spirv_type::{SpirvType, SpirvTypePrinter, StorageClassKind, TypeCache};
use crate::symbols::Symbols;
use crate::target::SpirvTarget;

Expand Down Expand Up @@ -234,11 +234,19 @@ impl<'tcx> CodegenCx<'tcx> {
}

pub fn type_ptr_to(&self, ty: Word) -> Word {
SpirvType::Pointer { pointee: ty }.def(DUMMY_SP, self)
SpirvType::Pointer {
pointee: ty,
storage_class: StorageClassKind::Inferred,
}
.def(DUMMY_SP, self)
}

pub fn type_ptr_to_ext(&self, ty: Word, _address_space: AddressSpace) -> Word {
SpirvType::Pointer { pointee: ty }.def(DUMMY_SP, self)
SpirvType::Pointer {
pointee: ty,
storage_class: StorageClassKind::Inferred,
}
.def(DUMMY_SP, self)
}

/// Zombie system:
Expand Down Expand Up @@ -866,6 +874,7 @@ impl<'tcx> MiscCodegenMethods<'tcx> for CodegenCx<'tcx> {

let ty = SpirvType::Pointer {
pointee: function.ty,
storage_class: StorageClassKind::Inferred,
}
.def(span, self);

Expand Down
2 changes: 1 addition & 1 deletion crates/rustc_codegen_spirv/src/codegen_cx/type_.rs
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ impl<'tcx> BaseTypeCodegenMethods<'tcx> for CodegenCx<'tcx> {
}
fn element_type(&self, ty: Self::Type) -> Self::Type {
match self.lookup_type(ty) {
SpirvType::Pointer { pointee } => pointee,
SpirvType::Pointer { pointee, .. } => pointee,
SpirvType::Vector { element, .. } => element,
spirv_type => self.tcx.dcx().fatal(format!(
"element_type called on invalid type: {spirv_type:?}"
Expand Down
14 changes: 11 additions & 3 deletions crates/rustc_codegen_spirv/src/linker/specializer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1615,6 +1615,14 @@ impl<'a, S: Specialization> InferCx<'a, S> {

#[allow(clippy::match_same_arms)]
Ok(match (a.clone(), b.clone()) {
// Concrete result types explicitly created inside functions
// can be assigned to instances.
// FIXME(jwollen) do we need to infere instance generics?
(InferOperand::Instance(_), InferOperand::Concrete(new))
| (InferOperand::Concrete(new), InferOperand::Instance(_)) => {
InferOperand::Concrete(new)
}

// Instances of "generic" globals/functions must be of the same ID,
// and their `generic_args` inference variables must be unified.
(
Expand Down Expand Up @@ -1999,13 +2007,13 @@ impl<'a, S: Specialization> InferCx<'a, S> {

if let Some(type_of_result) = type_of_result {
// Keep the (instantiated) *Result Type*, for future instructions to use
// (but only if it has any `InferVar`s at all).
// if it has any `InferVar`s at all or if it was a concrete type.
match type_of_result {
InferOperand::Var(_) | InferOperand::Instance(_) => {
InferOperand::Var(_) | InferOperand::Instance(_) | InferOperand::Concrete(_) => {
self.type_of_result
.insert(inst.result_id.unwrap(), type_of_result);
}
InferOperand::Unknown | InferOperand::Concrete(_) => {}
InferOperand::Unknown => {}
}
}
}
Expand Down
Loading
Loading