From 062bc02dc7af12336c41d04a6e4496aeca61c279 Mon Sep 17 00:00:00 2001 From: bjorn3 <17426603+bjorn3@users.noreply.github.com> Date: Thu, 5 Dec 2024 19:52:23 +0000 Subject: [PATCH 1/2] Avoid using make_direct_deprecated() in extern "ptx-kernel" This method will be removed in the future as it produces a broken ABI that depends on cg_llvm implementation details. After this PR wasm32-unknown-unknown is the only remaining user of make_direct_deprecated(). --- compiler/rustc_target/src/callconv/nvptx64.rs | 39 +++++++++++-------- compiler/rustc_ty_utils/src/abi.rs | 11 ++---- .../nvptx-kernel-args-abi-v7.rs | 20 +--------- 3 files changed, 28 insertions(+), 42 deletions(-) diff --git a/compiler/rustc_target/src/callconv/nvptx64.rs b/compiler/rustc_target/src/callconv/nvptx64.rs index 2e8b16d3a938e..73273396ecf21 100644 --- a/compiler/rustc_target/src/callconv/nvptx64.rs +++ b/compiler/rustc_target/src/callconv/nvptx64.rs @@ -1,5 +1,5 @@ use super::{ArgAttribute, ArgAttributes, ArgExtension, CastTarget}; -use crate::abi::call::{ArgAbi, FnAbi, PassMode, Reg, Size, Uniform}; +use crate::abi::call::{ArgAbi, FnAbi, Reg, Size, Uniform}; use crate::abi::{HasDataLayout, TyAbiInterface}; fn classify_ret(ret: &mut ArgAbi<'_, Ty>) { @@ -53,22 +53,29 @@ where Ty: TyAbiInterface<'a, C> + Copy, C: HasDataLayout, { - if matches!(arg.mode, PassMode::Pair(..)) && (arg.layout.is_adt() || arg.layout.is_tuple()) { - let align_bytes = arg.layout.align.abi.bytes(); - - let unit = match align_bytes { - 1 => Reg::i8(), - 2 => Reg::i16(), - 4 => Reg::i32(), - 8 => Reg::i64(), - 16 => Reg::i128(), - _ => unreachable!("Align is given as power of 2 no larger than 16 bytes"), - }; - arg.cast_to(Uniform::new(unit, Size::from_bytes(2 * align_bytes))); - } else { - // FIXME: find a better way to do this. See https://github.com/rust-lang/rust/issues/117271. 
- arg.make_direct_deprecated(); + match arg.mode { + super::PassMode::Ignore | super::PassMode::Direct(_) => return, + super::PassMode::Pair(_, _) => {} + super::PassMode::Cast { .. } => unreachable!(), + super::PassMode::Indirect { .. } => {} } + + // FIXME only allow structs and wide pointers here + // panic!( + // "`extern \"ptx-kernel\"` doesn't allow passing types other than primitives and structs" + // ); + + let align_bytes = arg.layout.align.abi.bytes(); + + let unit = match align_bytes { + 1 => Reg::i8(), + 2 => Reg::i16(), + 4 => Reg::i32(), + 8 => Reg::i64(), + 16 => Reg::i128(), + _ => unreachable!("Align is given as power of 2 no larger than 16 bytes"), + }; + arg.cast_to(Uniform::new(unit, arg.layout.size)); } pub(crate) fn compute_abi_info(fn_abi: &mut FnAbi<'_, Ty>) { diff --git a/compiler/rustc_ty_utils/src/abi.rs b/compiler/rustc_ty_utils/src/abi.rs index c528179ae0e7a..169f3a78c26a7 100644 --- a/compiler/rustc_ty_utils/src/abi.rs +++ b/compiler/rustc_ty_utils/src/abi.rs @@ -489,21 +489,16 @@ fn fn_abi_sanity_check<'tcx>( // have to allow it -- but we absolutely shouldn't let any more targets do // that. (Also see <https://github.com/rust-lang/rust/issues/115666>.) // - // The unstable abi `PtxKernel` also uses Direct for now. - // It needs to switch to something else before stabilization can happen. - // (See issue: https://github.com/rust-lang/rust/issues/117271) - // - // And finally the unadjusted ABI is ill specified and uses Direct for all - // args, but unfortunately we need it for calling certain LLVM intrinsics. + // The unadjusted ABI also uses Direct for all args and is ill-specified, + // but unfortunately we need it for calling certain LLVM intrinsics. 
match spec_abi { ExternAbi::Unadjusted => {} - ExternAbi::PtxKernel => {} ExternAbi::C { unwind: _ } if matches!(&*tcx.sess.target.arch, "wasm32" | "wasm64") => {} _ => { panic!( - "`PassMode::Direct` for aggregates only allowed for \"unadjusted\" and \"ptx-kernel\" functions and on wasm\n\ + "`PassMode::Direct` for aggregates only allowed for \"unadjusted\" functions and on wasm\n\ Problematic type: {:#?}", arg.layout, ); diff --git a/tests/assembly/nvptx-kernel-abi/nvptx-kernel-args-abi-v7.rs b/tests/assembly/nvptx-kernel-abi/nvptx-kernel-args-abi-v7.rs index fb3a325a41f81..21a31b6bb66d8 100644 --- a/tests/assembly/nvptx-kernel-abi/nvptx-kernel-args-abi-v7.rs +++ b/tests/assembly/nvptx-kernel-abi/nvptx-kernel-args-abi-v7.rs @@ -166,7 +166,7 @@ pub unsafe extern "ptx-kernel" fn f_f32_arg(_a: f32) {} pub unsafe extern "ptx-kernel" fn f_f64_arg(_a: f64) {} // CHECK: .visible .entry f_single_u8_arg( -// CHECK: .param .align 1 .b8 f_single_u8_arg_param_0[1] +// CHECK: .param .u8 f_single_u8_arg_param_0 #[no_mangle] pub unsafe extern "ptx-kernel" fn f_single_u8_arg(_a: SingleU8) {} @@ -242,22 +242,6 @@ pub unsafe extern "ptx-kernel" fn f_float_array_arg(_a: [f32; 5]) {} //pub unsafe extern "ptx-kernel" fn f_u128_array_arg(_a: [u128; 5]) {} // CHECK: .visible .entry f_u32_slice_arg( -// CHECK: .param .u64 f_u32_slice_arg_param_0 -// CHECK: .param .u64 f_u32_slice_arg_param_1 +// CHECK: .param .align 8 .b8 f_u32_slice_arg_param_0[16] #[no_mangle] pub unsafe extern "ptx-kernel" fn f_u32_slice_arg(_a: &[u32]) {} - -// CHECK: .visible .entry f_tuple_u8_u8_arg( -// CHECK: .param .align 1 .b8 f_tuple_u8_u8_arg_param_0[2] -#[no_mangle] -pub unsafe extern "ptx-kernel" fn f_tuple_u8_u8_arg(_a: (u8, u8)) {} - -// CHECK: .visible .entry f_tuple_u32_u32_arg( -// CHECK: .param .align 4 .b8 f_tuple_u32_u32_arg_param_0[8] -#[no_mangle] -pub unsafe extern "ptx-kernel" fn f_tuple_u32_u32_arg(_a: (u32, u32)) {} - -// CHECK: .visible .entry f_tuple_u8_u8_u32_arg( -// CHECK: .param .align 
4 .b8 f_tuple_u8_u8_u32_arg_param_0[8] -#[no_mangle] -pub unsafe extern "ptx-kernel" fn f_tuple_u8_u8_u32_arg(_a: (u8, u8, u32)) {} From 1c1c13a18411f8fd00024c0eb8fc88dee08bda96 Mon Sep 17 00:00:00 2001 From: bjorn3 <17426603+bjorn3@users.noreply.github.com> Date: Fri, 6 Dec 2024 09:46:41 +0000 Subject: [PATCH 2/2] Restore previous ABI for f_single_u8_arg --- compiler/rustc_target/src/callconv/nvptx64.rs | 11 ++++++++++- .../nvptx-kernel-abi/nvptx-kernel-args-abi-v7.rs | 2 +- 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/compiler/rustc_target/src/callconv/nvptx64.rs b/compiler/rustc_target/src/callconv/nvptx64.rs index 73273396ecf21..c64164372a11d 100644 --- a/compiler/rustc_target/src/callconv/nvptx64.rs +++ b/compiler/rustc_target/src/callconv/nvptx64.rs @@ -75,7 +75,16 @@ where 16 => Reg::i128(), _ => unreachable!("Align is given as power of 2 no larger than 16 bytes"), }; - arg.cast_to(Uniform::new(unit, arg.layout.size)); + if arg.layout.size.bytes() / align_bytes == 1 { + // Make sure we pass the struct as array at the LLVM IR level and not as a single integer. 
+ arg.cast_to(CastTarget { + prefix: [Some(unit), None, None, None, None, None, None, None], + rest: Uniform::new(unit, Size::ZERO), + attrs: ArgAttributes::new(), + }); + } else { + arg.cast_to(Uniform::new(unit, arg.layout.size)); + } } pub(crate) fn compute_abi_info(fn_abi: &mut FnAbi<'_, Ty>) { diff --git a/tests/assembly/nvptx-kernel-abi/nvptx-kernel-args-abi-v7.rs b/tests/assembly/nvptx-kernel-abi/nvptx-kernel-args-abi-v7.rs index 21a31b6bb66d8..b3bfc66a5a570 100644 --- a/tests/assembly/nvptx-kernel-abi/nvptx-kernel-args-abi-v7.rs +++ b/tests/assembly/nvptx-kernel-abi/nvptx-kernel-args-abi-v7.rs @@ -166,7 +166,7 @@ pub unsafe extern "ptx-kernel" fn f_f32_arg(_a: f32) {} pub unsafe extern "ptx-kernel" fn f_f64_arg(_a: f64) {} // CHECK: .visible .entry f_single_u8_arg( -// CHECK: .param .u8 f_single_u8_arg_param_0 +// CHECK: .param .align 1 .b8 f_single_u8_arg_param_0[1] #[no_mangle] pub unsafe extern "ptx-kernel" fn f_single_u8_arg(_a: SingleU8) {}