Skip to content

Commit af4d6c7

Browse files
committed
interpret: refactor dyn trait handling
We can check that the vtable is for the right trait very early, and then just pass the type around.
1 parent 0de24a5 commit af4d6c7

File tree

7 files changed

+90
-112
lines changed

7 files changed

+90
-112
lines changed

compiler/rustc_const_eval/src/interpret/cast.rs

+1-8
Original file line numberDiff line numberDiff line change
@@ -383,7 +383,6 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> {
383383
match (&src_pointee_ty.kind(), &dest_pointee_ty.kind()) {
384384
(&ty::Array(_, length), &ty::Slice(_)) => {
385385
let ptr = self.read_pointer(src)?;
386-
// u64 cast is from usize to u64, which is always good
387386
let val = Immediate::new_slice(
388387
ptr,
389388
length.eval_target_usize(*self.tcx, self.param_env),
@@ -401,13 +400,7 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> {
401400
let (old_data, old_vptr) = val.to_scalar_pair();
402401
let old_data = old_data.to_pointer(self)?;
403402
let old_vptr = old_vptr.to_pointer(self)?;
404-
let (ty, old_trait) = self.get_ptr_vtable(old_vptr)?;
405-
if old_trait != data_a.principal() {
406-
throw_ub!(InvalidVTableTrait {
407-
expected_trait: data_a,
408-
vtable_trait: old_trait,
409-
});
410-
}
403+
let ty = self.get_ptr_vtable_ty(old_vptr, Some(data_a))?;
411404
let new_vptr = self.get_vtable_ptr(ty, data_b.principal())?;
412405
self.write_immediate(Immediate::new_dyn_trait(old_data, new_vptr, self), dest)
413406
}

compiler/rustc_const_eval/src/interpret/memory.rs

+14-5
Original file line numberDiff line numberDiff line change
@@ -867,19 +867,28 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> {
867867
.ok_or_else(|| err_ub!(InvalidFunctionPointer(Pointer::new(alloc_id, offset))).into())
868868
}
869869

870-
pub fn get_ptr_vtable(
870+
/// Get the dynamic type of the given vtable pointer.
871+
/// If `expected_trait` is `Some`, it must be a vtable for the given trait.
872+
pub fn get_ptr_vtable_ty(
871873
&self,
872874
ptr: Pointer<Option<M::Provenance>>,
873-
) -> InterpResult<'tcx, (Ty<'tcx>, Option<ty::PolyExistentialTraitRef<'tcx>>)> {
875+
expected_trait: Option<&'tcx ty::List<ty::PolyExistentialPredicate<'tcx>>>,
876+
) -> InterpResult<'tcx, Ty<'tcx>> {
874877
trace!("get_ptr_vtable({:?})", ptr);
875878
let (alloc_id, offset, _tag) = self.ptr_get_alloc_id(ptr)?;
876879
if offset.bytes() != 0 {
877880
throw_ub!(InvalidVTablePointer(Pointer::new(alloc_id, offset)))
878881
}
879-
match self.tcx.try_get_global_alloc(alloc_id) {
880-
Some(GlobalAlloc::VTable(ty, trait_ref)) => Ok((ty, trait_ref)),
881-
_ => throw_ub!(InvalidVTablePointer(Pointer::new(alloc_id, offset))),
882+
let Some(GlobalAlloc::VTable(ty, vtable_trait)) = self.tcx.try_get_global_alloc(alloc_id)
883+
else {
884+
throw_ub!(InvalidVTablePointer(Pointer::new(alloc_id, offset)))
885+
};
886+
if let Some(expected_trait) = expected_trait {
887+
if vtable_trait != expected_trait.principal() {
888+
throw_ub!(InvalidVTableTrait { expected_trait, vtable_trait });
889+
}
882890
}
891+
Ok(ty)
883892
}
884893

885894
pub fn alloc_mark_immutable(&mut self, id: AllocId) -> InterpResult<'tcx> {

compiler/rustc_const_eval/src/interpret/place.rs

-49
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@ use tracing::{instrument, trace};
99

1010
use rustc_ast::Mutability;
1111
use rustc_middle::mir;
12-
use rustc_middle::ty;
1312
use rustc_middle::ty::layout::{LayoutOf, TyAndLayout};
1413
use rustc_middle::ty::Ty;
1514
use rustc_middle::{bug, span_bug};
@@ -1017,54 +1016,6 @@ where
10171016
let layout = self.layout_of(raw.ty)?;
10181017
Ok(self.ptr_to_mplace(ptr.into(), layout))
10191018
}
1020-
1021-
/// Turn a place with a `dyn Trait` type into a place with the actual dynamic type.
1022-
/// Aso returns the vtable.
1023-
pub(super) fn unpack_dyn_trait(
1024-
&self,
1025-
mplace: &MPlaceTy<'tcx, M::Provenance>,
1026-
expected_trait: &'tcx ty::List<ty::PolyExistentialPredicate<'tcx>>,
1027-
) -> InterpResult<'tcx, (MPlaceTy<'tcx, M::Provenance>, Pointer<Option<M::Provenance>>)> {
1028-
assert!(
1029-
matches!(mplace.layout.ty.kind(), ty::Dynamic(_, _, ty::Dyn)),
1030-
"`unpack_dyn_trait` only makes sense on `dyn*` types"
1031-
);
1032-
let vtable = mplace.meta().unwrap_meta().to_pointer(self)?;
1033-
let (ty, vtable_trait) = self.get_ptr_vtable(vtable)?;
1034-
if expected_trait.principal() != vtable_trait {
1035-
throw_ub!(InvalidVTableTrait { expected_trait, vtable_trait });
1036-
}
1037-
// This is a kind of transmute, from a place with unsized type and metadata to
1038-
// a place with sized type and no metadata.
1039-
let layout = self.layout_of(ty)?;
1040-
let mplace =
1041-
MPlaceTy { mplace: MemPlace { meta: MemPlaceMeta::None, ..mplace.mplace }, layout };
1042-
Ok((mplace, vtable))
1043-
}
1044-
1045-
/// Turn a `dyn* Trait` type into an value with the actual dynamic type.
1046-
/// Also returns the vtable.
1047-
pub(super) fn unpack_dyn_star<P: Projectable<'tcx, M::Provenance>>(
1048-
&self,
1049-
val: &P,
1050-
expected_trait: &'tcx ty::List<ty::PolyExistentialPredicate<'tcx>>,
1051-
) -> InterpResult<'tcx, (P, Pointer<Option<M::Provenance>>)> {
1052-
assert!(
1053-
matches!(val.layout().ty.kind(), ty::Dynamic(_, _, ty::DynStar)),
1054-
"`unpack_dyn_star` only makes sense on `dyn*` types"
1055-
);
1056-
let data = self.project_field(val, 0)?;
1057-
let vtable = self.project_field(val, 1)?;
1058-
let vtable = self.read_pointer(&vtable.to_op(self)?)?;
1059-
let (ty, vtable_trait) = self.get_ptr_vtable(vtable)?;
1060-
if expected_trait.principal() != vtable_trait {
1061-
throw_ub!(InvalidVTableTrait { expected_trait, vtable_trait });
1062-
}
1063-
// `data` is already the right thing but has the wrong type. So we transmute it.
1064-
let layout = self.layout_of(ty)?;
1065-
let data = data.transmute(layout, self)?;
1066-
Ok((data, vtable))
1067-
}
10681019
}
10691020

10701021
// Some nodes are used a lot. Make sure they don't unintentionally get bigger.

compiler/rustc_const_eval/src/interpret/terminator.rs

+18-19
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
use std::borrow::Cow;
22

33
use either::Either;
4+
use rustc_middle::ty::TyCtxt;
45
use tracing::trace;
56

67
use rustc_middle::span_bug;
@@ -827,20 +828,19 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> {
827828
};
828829

829830
// Obtain the underlying trait we are working on, and the adjusted receiver argument.
830-
let (vptr, dyn_ty, adjusted_receiver) = if let ty::Dynamic(data, _, ty::DynStar) =
831+
let (dyn_trait, dyn_ty, adjusted_recv) = if let ty::Dynamic(data, _, ty::DynStar) =
831832
receiver_place.layout.ty.kind()
832833
{
833-
let (recv, vptr) = self.unpack_dyn_star(&receiver_place, data)?;
834-
let (dyn_ty, _dyn_trait) = self.get_ptr_vtable(vptr)?;
834+
let recv = self.unpack_dyn_star(&receiver_place, data)?;
835835

836-
(vptr, dyn_ty, recv.ptr())
836+
(data.principal(), recv.layout.ty, recv.ptr())
837837
} else {
838838
// Doesn't have to be a `dyn Trait`, but the unsized tail must be `dyn Trait`.
839839
// (For that reason we also cannot use `unpack_dyn_trait`.)
840840
let receiver_tail = self
841841
.tcx
842842
.struct_tail_erasing_lifetimes(receiver_place.layout.ty, self.param_env);
843-
let ty::Dynamic(data, _, ty::Dyn) = receiver_tail.kind() else {
843+
let ty::Dynamic(receiver_trait, _, ty::Dyn) = receiver_tail.kind() else {
844844
span_bug!(
845845
self.cur_span(),
846846
"dynamic call on non-`dyn` type {}",
@@ -851,25 +851,24 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> {
851851

852852
// Get the required information from the vtable.
853853
let vptr = receiver_place.meta().unwrap_meta().to_pointer(self)?;
854-
let (dyn_ty, dyn_trait) = self.get_ptr_vtable(vptr)?;
855-
if dyn_trait != data.principal() {
856-
throw_ub!(InvalidVTableTrait {
857-
expected_trait: data,
858-
vtable_trait: dyn_trait,
859-
});
860-
}
854+
let dyn_ty = self.get_ptr_vtable_ty(vptr, Some(receiver_trait))?;
861855

862856
// It might be surprising that we use a pointer as the receiver even if this
863857
// is a by-val case; this works because by-val passing of an unsized `dyn
864858
// Trait` to a function is actually desugared to a pointer.
865-
(vptr, dyn_ty, receiver_place.ptr())
859+
(receiver_trait.principal(), dyn_ty, receiver_place.ptr())
866860
};
867861

868862
// Now determine the actual method to call. We can do that in two different ways and
869863
// compare them to ensure everything fits.
870-
let Some(ty::VtblEntry::Method(fn_inst)) =
871-
self.get_vtable_entries(vptr)?.get(idx).copied()
872-
else {
864+
let vtable_entries = if let Some(dyn_trait) = dyn_trait {
865+
let trait_ref = dyn_trait.with_self_ty(*self.tcx, dyn_ty);
866+
let trait_ref = self.tcx.erase_regions(trait_ref);
867+
self.tcx.vtable_entries(trait_ref)
868+
} else {
869+
TyCtxt::COMMON_VTABLE_ENTRIES
870+
};
871+
let Some(ty::VtblEntry::Method(fn_inst)) = vtable_entries.get(idx).copied() else {
873872
// FIXME(fee1-dead) these could be variants of the UB info enum instead of this
874873
throw_ub_custom!(fluent::const_eval_dyn_call_not_a_method);
875874
};
@@ -898,7 +897,7 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> {
898897
let receiver_ty = Ty::new_mut_ptr(self.tcx.tcx, dyn_ty);
899898
args[0] = FnArg::Copy(
900899
ImmTy::from_immediate(
901-
Scalar::from_maybe_pointer(adjusted_receiver, self).into(),
900+
Scalar::from_maybe_pointer(adjusted_recv, self).into(),
902901
self.layout_of(receiver_ty)?,
903902
)
904903
.into(),
@@ -974,11 +973,11 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> {
974973
let place = match place.layout.ty.kind() {
975974
ty::Dynamic(data, _, ty::Dyn) => {
976975
// Dropping a trait object. Need to find actual drop fn.
977-
self.unpack_dyn_trait(&place, data)?.0
976+
self.unpack_dyn_trait(&place, data)?
978977
}
979978
ty::Dynamic(data, _, ty::DynStar) => {
980979
// Dropping a `dyn*`. Need to find actual drop fn.
981-
self.unpack_dyn_star(&place, data)?.0
980+
self.unpack_dyn_star(&place, data)?
982981
}
983982
_ => {
984983
debug_assert_eq!(

compiler/rustc_const_eval/src/interpret/traits.rs

+48-18
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
use rustc_middle::mir::interpret::{InterpResult, Pointer};
22
use rustc_middle::ty::layout::LayoutOf;
3-
use rustc_middle::ty::{self, Ty, TyCtxt};
3+
use rustc_middle::ty::{self, Ty};
44
use rustc_target::abi::{Align, Size};
55
use tracing::trace;
66

77
use super::util::ensure_monomorphic_enough;
8-
use super::{InterpCx, Machine};
8+
use super::{InterpCx, MPlaceTy, Machine, MemPlaceMeta, OffsetMode, Projectable};
99

1010
impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> {
1111
/// Creates a dynamic vtable for the given type and vtable origin. This is used only for
@@ -33,28 +33,58 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> {
3333
Ok(vtable_ptr.into())
3434
}
3535

36-
/// Returns a high-level representation of the entries of the given vtable.
37-
pub fn get_vtable_entries(
38-
&self,
39-
vtable: Pointer<Option<M::Provenance>>,
40-
) -> InterpResult<'tcx, &'tcx [ty::VtblEntry<'tcx>]> {
41-
let (ty, poly_trait_ref) = self.get_ptr_vtable(vtable)?;
42-
Ok(if let Some(poly_trait_ref) = poly_trait_ref {
43-
let trait_ref = poly_trait_ref.with_self_ty(*self.tcx, ty);
44-
let trait_ref = self.tcx.erase_regions(trait_ref);
45-
self.tcx.vtable_entries(trait_ref)
46-
} else {
47-
TyCtxt::COMMON_VTABLE_ENTRIES
48-
})
49-
}
50-
5136
pub fn get_vtable_size_and_align(
5237
&self,
5338
vtable: Pointer<Option<M::Provenance>>,
5439
) -> InterpResult<'tcx, (Size, Align)> {
55-
let (ty, _trait_ref) = self.get_ptr_vtable(vtable)?;
40+
let ty = self.get_ptr_vtable_ty(vtable, None)?;
5641
let layout = self.layout_of(ty)?;
5742
assert!(layout.is_sized(), "there are no vtables for unsized types");
5843
Ok((layout.size, layout.align.abi))
5944
}
45+
46+
/// Turn a place with a `dyn Trait` type into a place with the actual dynamic type.
47+
pub(super) fn unpack_dyn_trait(
48+
&self,
49+
mplace: &MPlaceTy<'tcx, M::Provenance>,
50+
expected_trait: &'tcx ty::List<ty::PolyExistentialPredicate<'tcx>>,
51+
) -> InterpResult<'tcx, MPlaceTy<'tcx, M::Provenance>> {
52+
assert!(
53+
matches!(mplace.layout.ty.kind(), ty::Dynamic(_, _, ty::Dyn)),
54+
"`unpack_dyn_trait` only makes sense on `dyn*` types"
55+
);
56+
let vtable = mplace.meta().unwrap_meta().to_pointer(self)?;
57+
let ty = self.get_ptr_vtable_ty(vtable, Some(expected_trait))?;
58+
// This is a kind of transmute, from a place with unsized type and metadata to
59+
// a place with sized type and no metadata.
60+
let layout = self.layout_of(ty)?;
61+
let mplace = mplace.offset_with_meta(
62+
Size::ZERO,
63+
OffsetMode::Wrapping,
64+
MemPlaceMeta::None,
65+
layout,
66+
self,
67+
)?;
68+
Ok(mplace)
69+
}
70+
71+
/// Turn a `dyn* Trait` type into an value with the actual dynamic type.
72+
pub(super) fn unpack_dyn_star<P: Projectable<'tcx, M::Provenance>>(
73+
&self,
74+
val: &P,
75+
expected_trait: &'tcx ty::List<ty::PolyExistentialPredicate<'tcx>>,
76+
) -> InterpResult<'tcx, P> {
77+
assert!(
78+
matches!(val.layout().ty.kind(), ty::Dynamic(_, _, ty::DynStar)),
79+
"`unpack_dyn_star` only makes sense on `dyn*` types"
80+
);
81+
let data = self.project_field(val, 0)?;
82+
let vtable = self.project_field(val, 1)?;
83+
let vtable = self.read_pointer(&vtable.to_op(self)?)?;
84+
let ty = self.get_ptr_vtable_ty(vtable, Some(expected_trait))?;
85+
// `data` is already the right thing but has the wrong type. So we transmute it.
86+
let layout = self.layout_of(ty)?;
87+
let data = data.transmute(layout, self)?;
88+
Ok(data)
89+
}
6090
}

compiler/rustc_const_eval/src/interpret/validity.rs

+7-11
Original file line numberDiff line numberDiff line change
@@ -343,20 +343,16 @@ impl<'rt, 'tcx, M: Machine<'tcx>> ValidityVisitor<'rt, 'tcx, M> {
343343
match tail.kind() {
344344
ty::Dynamic(data, _, ty::Dyn) => {
345345
let vtable = meta.unwrap_meta().to_pointer(self.ecx)?;
346-
// Make sure it is a genuine vtable pointer.
347-
let (_dyn_ty, dyn_trait) = try_validation!(
348-
self.ecx.get_ptr_vtable(vtable),
346+
// Make sure it is a genuine vtable pointer for the right trait.
347+
try_validation!(
348+
self.ecx.get_ptr_vtable_ty(vtable, Some(data)),
349349
self.path,
350350
Ub(DanglingIntPointer(..) | InvalidVTablePointer(..)) =>
351-
InvalidVTablePtr { value: format!("{vtable}") }
351+
InvalidVTablePtr { value: format!("{vtable}") },
352+
Ub(InvalidVTableTrait { expected_trait, vtable_trait }) => {
353+
InvalidMetaWrongTrait { expected_trait, vtable_trait: *vtable_trait }
354+
},
352355
);
353-
// Make sure it is for the right trait.
354-
if dyn_trait != data.principal() {
355-
throw_validation_failure!(
356-
self.path,
357-
InvalidMetaWrongTrait { expected_trait: data, vtable_trait: dyn_trait }
358-
);
359-
}
360356
}
361357
ty::Slice(..) | ty::Str => {
362358
let _len = meta.unwrap_meta().to_target_usize(self.ecx)?;

compiler/rustc_const_eval/src/interpret/visitor.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ pub trait ValueVisitor<'tcx, M: Machine<'tcx>>: Sized {
9595
// unsized values are never immediate, so we can assert_mem_place
9696
let op = v.to_op(self.ecx())?;
9797
let dest = op.assert_mem_place();
98-
let inner_mplace = self.ecx().unpack_dyn_trait(&dest, data)?.0;
98+
let inner_mplace = self.ecx().unpack_dyn_trait(&dest, data)?;
9999
trace!("walk_value: dyn object layout: {:#?}", inner_mplace.layout);
100100
// recurse with the inner type
101101
return self.visit_field(v, 0, &inner_mplace.into());
@@ -104,7 +104,7 @@ pub trait ValueVisitor<'tcx, M: Machine<'tcx>>: Sized {
104104
// DynStar types. Very different from a dyn type (but strangely part of the
105105
// same variant in `TyKind`): These are pairs where the 2nd component is the
106106
// vtable, and the first component is the data (which must be ptr-sized).
107-
let data = self.ecx().unpack_dyn_star(v, data)?.0;
107+
let data = self.ecx().unpack_dyn_star(v, data)?;
108108
return self.visit_field(v, 0, &data);
109109
}
110110
// Slices do not need special handling here: they have `Array` field

0 commit comments

Comments
 (0)