Skip to content

Commit cc7b98d

Browse files
committed
[WIP] transmute_from!(@allow_shrink)
gherrit-pr-id: I10874e2bc703fb6b7fcdea050b8971de869a850a
1 parent eec3ec9 commit cc7b98d

File tree

4 files changed

+99
-60
lines changed

4 files changed

+99
-60
lines changed

src/layout.rs

Lines changed: 64 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -623,7 +623,9 @@ mod cast_from_raw {
623623
/// [cast_from_raw]: crate::pointer::SizeCompat::cast_from_raw
624624
//
625625
// TODO(#1817): Support Sized->Unsized and Unsized->Sized casts
626-
pub(crate) fn cast_from_raw<Src, Dst>(src: PtrInner<'_, Src>) -> PtrInner<'_, Dst>
626+
pub(crate) fn cast_from_raw<const ALLOW_SHRINK: bool, Src, Dst>(
627+
src: PtrInner<'_, Src>,
628+
) -> PtrInner<'_, Dst>
627629
where
628630
Src: KnownLayout<PointerMetadata = usize> + ?Sized,
629631
Dst: KnownLayout<PointerMetadata = usize> + ?Sized,
@@ -694,12 +696,19 @@ mod cast_from_raw {
694696
/// `Src`'s alignment must not be smaller than `Dst`'s alignment.
695697
#[derive(Copy, Clone)]
696698
struct CastParams {
697-
offset_delta_elems: usize,
698-
elem_multiple: usize,
699+
// `offset_delta / dst.elem_size = offset_delta_elems_num / denom`
700+
offset_delta_elems_num: usize,
701+
// `src.elem_size / dst.elem_size = elem_multiple_num / denom`
702+
elem_multiple_num: usize,
703+
denom: NonZeroUsize,
699704
}
700705

701706
impl CastParams {
702-
const fn try_compute(src: &DstLayout, dst: &DstLayout) -> Option<CastParams> {
707+
const fn try_compute(
708+
src: &DstLayout,
709+
dst: &DstLayout,
710+
allow_shrink: bool,
711+
) -> Option<CastParams> {
703712
if src.align.get() < dst.align.get() {
704713
return None;
705714
}
@@ -724,33 +733,60 @@ mod cast_from_raw {
724733
return None;
725734
};
726735

727-
// PANICS: `dst_elem_size: NonZeroUsize`, so this won't div by zero.
728-
#[allow(clippy::arithmetic_side_effects)]
729-
let delta_mod_other_elem = offset_delta % dst_elem_size.get();
736+
const fn gcd(a: usize, b: usize) -> usize {
737+
if a == 0 {
738+
b
739+
} else {
740+
#[allow(clippy::arithmetic_side_effects)]
741+
gcd(b % a, a)
742+
}
743+
}
730744

731-
// PANICS: `dst_elem_size: NonZeroUsize`, so this won't div by zero.
745+
let gcd = gcd(gcd(offset_delta, src.elem_size), dst_elem_size.get());
746+
// PANICS: `dst_elem_size.get()` is non-zero, and so `denom`
747+
// will be non-zero.
748+
749+
#[allow(clippy::arithmetic_side_effects)]
750+
let offset_delta_elems_num = offset_delta / gcd;
732751
#[allow(clippy::arithmetic_side_effects)]
733-
let elem_remainder = src.elem_size % dst_elem_size.get();
752+
let elem_multiple_num = src.elem_size / gcd;
753+
// PANICS: `dst_elem_size` is non-zero, and `gcd` is no greater
754+
// than it by construction. Thus, this should be at least 1.
755+
let denom = NonZeroUsize::new(dst_elem_size.get() / gcd)
756+
.expect("CastParams::try_compute: denom should be non-zero");
734757

735-
if delta_mod_other_elem != 0 || src.elem_size < dst.elem_size || elem_remainder != 0
736-
{
758+
if denom.get() != 1 && !allow_shrink {
737759
return None;
738760
}
739761

740-
// PANICS: `dst_elem_size: NonZeroUsize`, so this won't div by zero.
741-
#[allow(clippy::arithmetic_side_effects)]
742-
let offset_delta_elems = offset_delta / dst_elem_size.get();
762+
// // PANICS: `dst_elem_size: NonZeroUsize`, so this won't div by zero.
763+
// #[allow(clippy::arithmetic_side_effects)]
764+
// let delta_mod_other_elem = offset_delta % dst_elem_size.get();
743765

744-
// PANICS: `dst_elem_size: NonZeroUsize`, so this won't div by zero.
745-
#[allow(clippy::arithmetic_side_effects)]
746-
let elem_multiple = src.elem_size / dst_elem_size.get();
766+
// // PANICS: `dst_elem_size: NonZeroUsize`, so this won't div by zero.
767+
// #[allow(clippy::arithmetic_side_effects)]
768+
// let elem_remainder = src.elem_size % dst_elem_size.get();
769+
770+
// if delta_mod_other_elem != 0 || src.elem_size < dst.elem_size || elem_remainder != 0
771+
// {
772+
// return None;
773+
// }
774+
775+
// // PANICS: `dst_elem_size: NonZeroUsize`, so this won't div by zero.
776+
// #[allow(clippy::arithmetic_side_effects)]
777+
// let offset_delta_elems = offset_delta / dst_elem_size.get();
778+
779+
// // PANICS: `dst_elem_size: NonZeroUsize`, so this won't div by zero.
780+
// #[allow(clippy::arithmetic_side_effects)]
781+
// let elem_multiple = src.elem_size / dst_elem_size.get();
747782

748783
// SAFETY: We checked above that `src.align >= dst.align`.
749784
Some(CastParams {
750785
// SAFETY: We checked above that this is an exact ratio.
751-
offset_delta_elems,
786+
offset_delta_elems_num,
752787
// SAFETY: We checked above that this is an exact ratio.
753-
elem_multiple,
788+
elem_multiple_num,
789+
denom,
754790
})
755791
}
756792

@@ -774,24 +810,25 @@ mod cast_from_raw {
774810
// metadata, this math will not overflow, and the returned value
775811
// will describe a `Dst` of the same size.
776812
#[allow(unstable_name_collisions)]
777-
unsafe {
778-
self.offset_delta_elems
779-
.unchecked_add(src_meta.unchecked_mul(self.elem_multiple))
780-
}
813+
let num = unsafe {
814+
self.offset_delta_elems_num
815+
.unchecked_add(src_meta.unchecked_mul(self.elem_multiple_num))
816+
};
817+
num / self.denom.get()
781818
}
782819
}
783820

784-
trait Params<Src: ?Sized> {
821+
trait Params<const ALLOW_SHRINK: bool, Src: ?Sized> {
785822
const CAST_PARAMS: CastParams;
786823
}
787824

788-
impl<Src, Dst> Params<Src> for Dst
825+
impl<const ALLOW_SHRINK: bool, Src, Dst> Params<ALLOW_SHRINK, Src> for Dst
789826
where
790827
Src: KnownLayout + ?Sized,
791828
Dst: KnownLayout<PointerMetadata = usize> + ?Sized,
792829
{
793830
const CAST_PARAMS: CastParams =
794-
match CastParams::try_compute(&Src::LAYOUT, &Dst::LAYOUT) {
831+
match CastParams::try_compute(&Src::LAYOUT, &Dst::LAYOUT, ALLOW_SHRINK) {
795832
Some(params) => params,
796833
None => const_panic!(
797834
"cannot `transmute_ref!` or `transmute_mut!` between incompatible types"
@@ -800,7 +837,7 @@ mod cast_from_raw {
800837
}
801838

802839
let src_meta = <Src as KnownLayout>::pointer_to_metadata(src.as_non_null().as_ptr());
803-
let params = <Dst as Params<Src>>::CAST_PARAMS;
840+
let params = <Dst as Params<ALLOW_SHRINK, Src>>::CAST_PARAMS;
804841

805842
// SAFETY: `src: PtrInner`, and so by invariant on `PtrInner`, `src`'s
806843
// referent is no larger than `isize::MAX`.

src/macros.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -263,7 +263,7 @@ macro_rules! transmute_ref {
263263
// - `Src: IntoBytes + Immutable`
264264
// - `Dst: FromBytes + Immutable`
265265
unsafe {
266-
t.transmute_ref()
266+
t.transmute_ref::<false>()
267267
}
268268
}
269269
}}

src/util/macro_util.rs

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -713,7 +713,9 @@ impl<'a, Src, Dst> Wrap<&'a Src, &'a Dst> {
713713
/// - `mem::align_of::<Dst>() <= mem::align_of::<Src>()`
714714
#[inline(always)]
715715
#[must_use]
716-
pub const unsafe fn transmute_ref(self) -> &'a Dst {
716+
pub const unsafe fn transmute_ref<const ALLOW_SHRINK: bool>(self) -> &'a Dst {
717+
// TODO: Permit ALLOW_SHRINK = true and do a different static_assert!
718+
// (namely, size_of::<Dst>() <= size_of::<Src>()).
717719
static_assert!(Src, Dst => mem::size_of::<Dst>() == mem::size_of::<Src>());
718720
static_assert!(Src, Dst => mem::align_of::<Dst>() <= mem::align_of::<Src>());
719721

@@ -777,7 +779,7 @@ pub trait TransmuteRefDst<'a> {
777779
type Dst: ?Sized;
778780

779781
#[must_use]
780-
fn transmute_ref(self) -> &'a Self::Dst;
782+
fn transmute_ref<const ALLOW_SHRINK: bool>(self) -> &'a Self::Dst;
781783
}
782784

783785
impl<'a, Src: ?Sized, Dst: ?Sized> TransmuteRefDst<'a> for Wrap<&'a Src, &'a Dst>
@@ -788,18 +790,18 @@ where
788790
type Dst = Dst;
789791

790792
#[inline(always)]
791-
fn transmute_ref(self) -> &'a Dst {
793+
fn transmute_ref<const ALLOW_SHRINK: bool>(self) -> &'a Dst {
792794
static_assert!(Src: ?Sized + KnownLayout, Dst: ?Sized + KnownLayout => {
793795
Src::LAYOUT.align.get() >= Dst::LAYOUT.align.get()
794796
}, "cannot transmute reference when destination type has higher alignment than source type");
795797

796798
// SAFETY: We only use `S` as `S<Src>` and `D` as `D<Dst>`.
797799
unsafe {
798-
unsafe_with_size_compat!(<S<Src>, D<Dst>> {
800+
unsafe_with_size_compat!(<S<Src>, D<ALLOW_SHRINK, Dst>>, {
799801
let ptr = Ptr::from_ref(self.0)
800802
.transmute::<S<Src>, invariant::Valid, BecauseImmutable>()
801803
.recall_validity::<invariant::Initialized, _>()
802-
.transmute::<D<Dst>, invariant::Initialized, (crate::pointer::BecauseMutationCompatible, _)>()
804+
.transmute::<D<{ALLOW_SHRINK}, Dst>, invariant::Initialized, (crate::pointer::BecauseMutationCompatible, _)>()
803805
.recall_validity::<invariant::Valid, _>();
804806

805807
#[allow(unused_unsafe)]
@@ -817,7 +819,7 @@ where
817819
pub trait TransmuteMutDst<'a> {
818820
type Dst: ?Sized;
819821
#[must_use]
820-
fn transmute_mut(self) -> &'a mut Self::Dst;
822+
fn transmute_mut<const ALLOW_SHRINK: bool>(self) -> &'a mut Self::Dst;
821823
}
822824

823825
impl<'a, Src: ?Sized, Dst: ?Sized> TransmuteMutDst<'a> for Wrap<&'a mut Src, &'a mut Dst>
@@ -828,18 +830,18 @@ where
828830
type Dst = Dst;
829831

830832
#[inline(always)]
831-
fn transmute_mut(self) -> &'a mut Dst {
833+
fn transmute_mut<const ALLOW_SHRINK: bool>(self) -> &'a mut Dst {
832834
static_assert!(Src: ?Sized + KnownLayout, Dst: ?Sized + KnownLayout => {
833835
Src::LAYOUT.align.get() >= Dst::LAYOUT.align.get()
834836
}, "cannot transmute reference when destination type has higher alignment than source type");
835837

836838
// SAFETY: We only use `S` as `S<Src>` and `D` as `D<Dst>`.
837839
unsafe {
838-
unsafe_with_size_compat!(<S<Src>, D<Dst>> {
840+
unsafe_with_size_compat!(<S<Src>, D<ALLOW_SHRINK, Dst>>, {
839841
let ptr = Ptr::from_mut(self.0)
840842
.transmute::<S<Src>, invariant::Valid, _>()
841843
.recall_validity::<invariant::Initialized, (_, (_, _))>()
842-
.transmute::<D<Dst>, invariant::Initialized, _>()
844+
.transmute::<D<{ALLOW_SHRINK}, Dst>, invariant::Initialized, _>()
843845
.recall_validity::<invariant::Valid, (_, (_, _))>();
844846

845847
#[allow(unused_unsafe)]

0 commit comments

Comments
 (0)