Skip to content

[WIP] transmute_from!(@allow_shrink) #2487

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: I32ffeea758b53073aa461ab41c217e5b8f0bc4e4
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 66 additions & 27 deletions src/layout.rs
Original file line number Diff line number Diff line change
Expand Up @@ -623,7 +623,9 @@ mod cast_from_raw {
/// [cast_from_raw]: crate::pointer::SizeCompat::cast_from_raw
//
// TODO(#1817): Support Sized->Unsized and Unsized->Sized casts
pub(crate) fn cast_from_raw<Src, Dst>(src: PtrInner<'_, Src>) -> PtrInner<'_, Dst>
pub(crate) fn cast_from_raw<Src, Dst, const ALLOW_SHRINK: bool>(
src: PtrInner<'_, Src>,
) -> PtrInner<'_, Dst>
where
Src: KnownLayout<PointerMetadata = usize> + ?Sized,
Dst: KnownLayout<PointerMetadata = usize> + ?Sized,
Expand Down Expand Up @@ -694,12 +696,19 @@ mod cast_from_raw {
/// `Src`'s alignment must not be smaller than `Dst`'s alignment.
#[derive(Copy, Clone)]
struct CastParams {
offset_delta_elems: usize,
elem_multiple: usize,
// `offset_delta / dst.elem_size = offset_delta_elems_num / denom`
offset_delta_elems_num: usize,
// `src.elem_size / dst.elem_size = elem_multiple_num / denom`
elem_multiple_num: usize,
denom: NonZeroUsize,
}

impl CastParams {
const fn try_compute(src: &DstLayout, dst: &DstLayout) -> Option<CastParams> {
const fn try_compute(
src: &DstLayout,
dst: &DstLayout,
allow_shrink: bool,
) -> Option<CastParams> {
if src.align.get() < dst.align.get() {
return None;
}
Expand All @@ -724,33 +733,62 @@ mod cast_from_raw {
return None;
};

// PANICS: `dst_elem_size: NonZeroUsize`, so this won't div by zero.
#[allow(clippy::arithmetic_side_effects)]
let delta_mod_other_elem = offset_delta % dst_elem_size.get();
const fn gcd(a: usize, b: usize) -> usize {
if a == 0 {
b
} else {
#[allow(clippy::arithmetic_side_effects)]
gcd(b % a, a)
}
}

// PANICS: `dst_elem_size: NonZeroUsize`, so this won't div by zero.
let gcd = gcd(gcd(offset_delta, src.elem_size), dst_elem_size.get());
// PANICS: `dst_elem_size.get()` is non-zero, and so `denom`
// will be non-zero.

#[allow(clippy::arithmetic_side_effects)]
let offset_delta_elems_num = offset_delta / gcd;
#[allow(clippy::arithmetic_side_effects)]
let elem_remainder = src.elem_size % dst_elem_size.get();
let elem_multiple_num = src.elem_size / gcd;
// PANICS: `dst_elem_size` is non-zero, and `gcd` is no greater
// than it by construction. Thus, this should be at least 1.
let denom = match NonZeroUsize::new(dst_elem_size.get() / gcd) {
Some(d) => d,
None => const_panic!("CastParams::try_compute: denom should be non-zero"),
};

if delta_mod_other_elem != 0 || src.elem_size < dst.elem_size || elem_remainder != 0
{
if denom.get() != 1 && !allow_shrink {
return None;
}

// PANICS: `dst_elem_size: NonZeroUsize`, so this won't div by zero.
#[allow(clippy::arithmetic_side_effects)]
let offset_delta_elems = offset_delta / dst_elem_size.get();
// // PANICS: `dst_elem_size: NonZeroUsize`, so this won't div by zero.
// #[allow(clippy::arithmetic_side_effects)]
// let delta_mod_other_elem = offset_delta % dst_elem_size.get();

// PANICS: `dst_elem_size: NonZeroUsize`, so this won't div by zero.
#[allow(clippy::arithmetic_side_effects)]
let elem_multiple = src.elem_size / dst_elem_size.get();
// // PANICS: `dst_elem_size: NonZeroUsize`, so this won't div by zero.
// #[allow(clippy::arithmetic_side_effects)]
// let elem_remainder = src.elem_size % dst_elem_size.get();

// if delta_mod_other_elem != 0 || src.elem_size < dst.elem_size || elem_remainder != 0
// {
// return None;
// }

// // PANICS: `dst_elem_size: NonZeroUsize`, so this won't div by zero.
// #[allow(clippy::arithmetic_side_effects)]
// let offset_delta_elems = offset_delta / dst_elem_size.get();

// // PANICS: `dst_elem_size: NonZeroUsize`, so this won't div by zero.
// #[allow(clippy::arithmetic_side_effects)]
// let elem_multiple = src.elem_size / dst_elem_size.get();

// SAFETY: We checked above that `src.align >= dst.align`.
Some(CastParams {
// SAFETY: We checked above that this is an exact ratio.
offset_delta_elems,
offset_delta_elems_num,
// SAFETY: We checked above that this is an exact ratio.
elem_multiple,
elem_multiple_num,
denom,
})
}

Expand All @@ -774,24 +812,25 @@ mod cast_from_raw {
// metadata, this math will not overflow, and the returned value
// will describe a `Dst` of the same size.
#[allow(unstable_name_collisions)]
unsafe {
self.offset_delta_elems
.unchecked_add(src_meta.unchecked_mul(self.elem_multiple))
}
let num = unsafe {
self.offset_delta_elems_num
.unchecked_add(src_meta.unchecked_mul(self.elem_multiple_num))
};
num / self.denom.get()
}
}

trait Params<Src: ?Sized> {
trait Params<Src: ?Sized, const ALLOW_SHRINK: bool> {
const CAST_PARAMS: CastParams;
}

impl<Src, Dst> Params<Src> for Dst
impl<Src, Dst, const ALLOW_SHRINK: bool> Params<Src, ALLOW_SHRINK> for Dst
where
Src: KnownLayout + ?Sized,
Dst: KnownLayout<PointerMetadata = usize> + ?Sized,
{
const CAST_PARAMS: CastParams =
match CastParams::try_compute(&Src::LAYOUT, &Dst::LAYOUT) {
match CastParams::try_compute(&Src::LAYOUT, &Dst::LAYOUT, ALLOW_SHRINK) {
Some(params) => params,
None => const_panic!(
"cannot `transmute_ref!` or `transmute_mut!` between incompatible types"
Expand All @@ -800,7 +839,7 @@ mod cast_from_raw {
}

let src_meta = <Src as KnownLayout>::pointer_to_metadata(src.as_non_null().as_ptr());
let params = <Dst as Params<Src>>::CAST_PARAMS;
let params = <Dst as Params<Src, ALLOW_SHRINK>>::CAST_PARAMS;

// SAFETY: `src: PtrInner`, and so by invariant on `PtrInner`, `src`'s
// referent is no larger than `isize::MAX`.
Expand Down
51 changes: 46 additions & 5 deletions src/macros.rs
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,18 @@ macro_rules! transmute {
/// `Dst: Sized`.
#[macro_export]
macro_rules! transmute_ref {
($e:expr) => {{
(#![allow(shrink)] $e:expr) => {
$crate::__transmute_ref_inner!(true, $e)
};
($e:expr) => {
$crate::__transmute_ref_inner!(false, $e)
};
}

#[macro_export]
#[doc(hidden)]
macro_rules! __transmute_ref_inner {
($allow_shrink:literal, $e:expr) => {{
// NOTE: This must be a macro (rather than a function with trait bounds)
// because there's no way, in a generic context, to enforce that two
// types have the same size or alignment.
Expand Down Expand Up @@ -263,10 +274,10 @@ macro_rules! transmute_ref {
// - `Src: IntoBytes + Immutable`
// - `Dst: FromBytes + Immutable`
unsafe {
t.transmute_ref()
t.transmute_ref::<$allow_shrink>()
}
}
}}
}};
}

/// Safely transmutes a mutable reference of one type to a mutable reference of
Expand Down Expand Up @@ -400,7 +411,18 @@ macro_rules! transmute_ref {
/// ```
#[macro_export]
macro_rules! transmute_mut {
($e:expr) => {{
(#![allow(shrink)] $e:expr) => {
$crate::__transmute_mut_inner!(true, $e)
};
($e:expr) => {
$crate::__transmute_mut_inner!(false, $e)
};
}

#[doc(hidden)]
#[macro_export]
macro_rules! __transmute_mut_inner {
($allow_shrink:literal, $e:expr) => {{
// NOTE: This must be a macro (rather than a function with trait bounds)
// because, for backwards-compatibility on v0.8.x, we use the autoref
// specialization trick to dispatch to different `transmute_mut`
Expand All @@ -414,7 +436,7 @@ macro_rules! transmute_mut {
#[allow(unused)]
use $crate::util::macro_util::TransmuteMutDst as _;
let t = $crate::util::macro_util::Wrap::new(e);
t.transmute_mut()
t.transmute_mut::<$allow_shrink>()
}}
}

Expand Down Expand Up @@ -1162,6 +1184,15 @@ mod tests {
let x: &SliceDst<U16, u8> = transmute_ref!(slice_dst_big);
assert_eq!(x, slice_dst_small);

let bytes = &[0, 1, 2, 3, 4, 5, 6, 7][..];
let slice_dst_big = SliceDst::<[u8; 4], [u8; 4]>::ref_from_bytes(bytes).unwrap();
let slice_dst_small = SliceDst::<[u8; 3], [u8; 3]>::ref_from_bytes(&bytes[..6]).unwrap();
let x: &SliceDst<[u8; 3], [u8; 3]> = transmute_ref!(
#![allow(shrink)]
slice_dst_big
);
assert_eq!(x, slice_dst_small);

// Test that it's legal to transmute a reference while shrinking the
// lifetime (note that `X` has the lifetime `'static`).
let x: &[u8; 8] = transmute_ref!(X);
Expand Down Expand Up @@ -1350,6 +1381,16 @@ mod tests {
let slice_dst_small = SliceDst::<U16, u8>::mut_from_bytes(&mut bytes[..]).unwrap();
let x: &mut SliceDst<U16, u8> = transmute_mut!(slice_dst_big);
assert_eq!(x, slice_dst_small);

let mut bytes = [0, 1, 2, 3, 4, 5, 6, 7];
let slice_dst_big = SliceDst::<[u8; 4], [u8; 4]>::mut_from_bytes(&mut bytes[..]).unwrap();
let mut bytes = [0, 1, 2, 3, 4, 5];
let slice_dst_small = SliceDst::<[u8; 3], [u8; 3]>::mut_from_bytes(&mut bytes[..]).unwrap();
let x: &mut SliceDst<[u8; 3], [u8; 3]> = transmute_mut!(
#![allow(shrink)]
slice_dst_big
);
assert_eq!(x, slice_dst_small);
}

#[test]
Expand Down
24 changes: 14 additions & 10 deletions src/util/macro_util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -713,7 +713,9 @@ impl<'a, Src, Dst> Wrap<&'a Src, &'a Dst> {
/// - `mem::align_of::<Dst>() <= mem::align_of::<Src>()`
#[inline(always)]
#[must_use]
pub const unsafe fn transmute_ref(self) -> &'a Dst {
pub const unsafe fn transmute_ref<const ALLOW_SHRINK: bool>(self) -> &'a Dst {
// TODO: Permit ALLOW_SHRINK = true and do a different static_assert!
// (namely, size_of::<Dst>() <= size_of::<Src>()).
static_assert!(Src, Dst => mem::size_of::<Dst>() == mem::size_of::<Src>());
static_assert!(Src, Dst => mem::align_of::<Dst>() <= mem::align_of::<Src>());

Expand Down Expand Up @@ -750,11 +752,13 @@ impl<'a, Src, Dst> Wrap<&'a mut Src, &'a mut Dst> {
/// - `mem::align_of::<Dst>() <= mem::align_of::<Src>()`
#[inline(always)]
#[must_use]
pub fn transmute_mut(self) -> &'a mut Dst
pub fn transmute_mut<const ALLOW_SHRINK: bool>(self) -> &'a mut Dst
where
Src: FromBytes + IntoBytes,
Dst: FromBytes + IntoBytes,
{
// TODO: Permit ALLOW_SHRINK = true and do a different static_assert!
// (namely, size_of::<Dst>() <= size_of::<Src>()).
static_assert!(Src, Dst => mem::size_of::<Dst>() == mem::size_of::<Src>());
static_assert!(Src, Dst => mem::align_of::<Dst>() <= mem::align_of::<Src>());

Expand All @@ -777,7 +781,7 @@ pub trait TransmuteRefDst<'a> {
type Dst: ?Sized;

#[must_use]
fn transmute_ref(self) -> &'a Self::Dst;
fn transmute_ref<const ALLOW_SHRINK: bool>(self) -> &'a Self::Dst;
}

impl<'a, Src: ?Sized, Dst: ?Sized> TransmuteRefDst<'a> for Wrap<&'a Src, &'a Dst>
Expand All @@ -788,18 +792,18 @@ where
type Dst = Dst;

#[inline(always)]
fn transmute_ref(self) -> &'a Dst {
fn transmute_ref<const ALLOW_SHRINK: bool>(self) -> &'a Dst {
static_assert!(Src: ?Sized + KnownLayout, Dst: ?Sized + KnownLayout => {
Src::LAYOUT.align.get() >= Dst::LAYOUT.align.get()
}, "cannot transmute reference when destination type has higher alignment than source type");

// SAFETY: We only use `S` as `S<Src>` and `D` as `D<Dst>`.
unsafe {
unsafe_with_size_compat!(<S<Src>, D<Dst>> {
unsafe_with_size_compat!(<S<Src>, D<Dst, {ALLOW_SHRINK}>>, {
let ptr = Ptr::from_ref(self.0)
.transmute::<S<Src>, invariant::Valid, BecauseImmutable>()
.recall_validity::<invariant::Initialized, _>()
.transmute::<D<Dst>, invariant::Initialized, (crate::pointer::BecauseMutationCompatible, _)>()
.transmute::<D<Dst, {ALLOW_SHRINK}>, invariant::Initialized, (crate::pointer::BecauseMutationCompatible, _)>()
.recall_validity::<invariant::Valid, _>();

#[allow(unused_unsafe)]
Expand All @@ -817,7 +821,7 @@ where
pub trait TransmuteMutDst<'a> {
type Dst: ?Sized;
#[must_use]
fn transmute_mut(self) -> &'a mut Self::Dst;
fn transmute_mut<const ALLOW_SHRINK: bool>(self) -> &'a mut Self::Dst;
}

impl<'a, Src: ?Sized, Dst: ?Sized> TransmuteMutDst<'a> for Wrap<&'a mut Src, &'a mut Dst>
Expand All @@ -828,18 +832,18 @@ where
type Dst = Dst;

#[inline(always)]
fn transmute_mut(self) -> &'a mut Dst {
fn transmute_mut<const ALLOW_SHRINK: bool>(self) -> &'a mut Dst {
static_assert!(Src: ?Sized + KnownLayout, Dst: ?Sized + KnownLayout => {
Src::LAYOUT.align.get() >= Dst::LAYOUT.align.get()
}, "cannot transmute reference when destination type has higher alignment than source type");

// SAFETY: We only use `S` as `S<Src>` and `D` as `D<Dst>`.
unsafe {
unsafe_with_size_compat!(<S<Src>, D<Dst>> {
unsafe_with_size_compat!(<S<Src>, D<Dst, {ALLOW_SHRINK}>>, {
let ptr = Ptr::from_mut(self.0)
.transmute::<S<Src>, invariant::Valid, _>()
.recall_validity::<invariant::Initialized, (_, (_, _))>()
.transmute::<D<Dst>, invariant::Initialized, _>()
.transmute::<D<Dst, {ALLOW_SHRINK}>, invariant::Initialized, _>()
.recall_validity::<invariant::Valid, (_, (_, _))>();

#[allow(unused_unsafe)]
Expand Down
Loading
Loading