Skip to content

Commit 54b2df2

Browse files
committed
Implement stack-based AsyncDestrucorCtorBuilder
1 parent 6d8bf2a commit 54b2df2

File tree

1 file changed

+222
-1
lines changed
  • compiler/rustc_mir_transform/src

1 file changed

+222
-1
lines changed

compiler/rustc_mir_transform/src/shim.rs

+222-1
Original file line numberDiff line numberDiff line change
@@ -1170,7 +1170,7 @@ fn build_async_destructor_ctor_shim<'tcx>(
11701170
| ty::Ref(_, _, _)
11711171
| ty::FnDef(_, _)
11721172
| ty::FnPtr(_)
1173-
| ty::Never
1173+
| ty::Never => GlueStrategy::Empty,
11741174
};
11751175

11761176
let return_bb;
@@ -1607,6 +1607,227 @@ fn build_async_destructor_ctor_shim<'tcx>(
16071607
new_body(source, blocks, locals, sig.inputs().len(), span)
16081608
}
16091609

1610+
const ASYNC_DESTRUCTOR_CTOR_ARG_COUNT: usize = 1;
1611+
1612+
struct AsyncDestructorCtorShimBuilder<'tcx> {
1613+
tcx: TyCtxt<'tcx>,
1614+
def_id: DefId,
1615+
self_ty: Ty<'tcx>,
1616+
span: Span,
1617+
source_info: SourceInfo,
1618+
1619+
stack: Vec<(Local, Ty<'tcx>)>,
1620+
last_bb: BasicBlock,
1621+
1622+
locals: IndexVec<Local, LocalDecl<'tcx>>,
1623+
bbs: IndexVec<BasicBlock, BasicBlockData<'tcx>>,
1624+
1625+
// Cached stuff
1626+
chain_combinator: Option<(DefId, EarlyBinder<Ty<'tcx>>)>,
1627+
ready_unit: Option<(DefId, Ty<'tcx>)>,
1628+
}
1629+
1630+
impl<'tcx> AsyncDestructorCtorShimBuilder<'tcx> {
1631+
fn new(tcx: TyCtxt<'tcx>, def_id: DefId, self_ty: Ty<'tcx>) -> Self {
1632+
let span = tcx.def_span(def_id);
1633+
let Some(sig) = tcx.fn_sig(def_id).instantiate(tcx, &[self_ty.into()]).no_bound_vars()
1634+
else {
1635+
span_bug!(span, "async_drop_in_place_raw with bound vars for `{self_ty}`");
1636+
};
1637+
1638+
let source_info = SourceInfo::outermost(span);
1639+
1640+
debug_assert_eq!(sig.inputs().len(), ASYNC_DESTRUCTOR_CTOR_ARG_COUNT);
1641+
let locals = local_decls_for_sig(&sig, span);
1642+
1643+
AsyncDestructorCtorShimBuilder {
1644+
tcx,
1645+
def_id,
1646+
self_ty,
1647+
span,
1648+
source_info,
1649+
1650+
stack: Vec::new(),
1651+
last_bb: BasicBlock::new(0),
1652+
1653+
locals,
1654+
bbs: IndexVec::from([BasicBlockData::new(None)]),
1655+
1656+
chain_combinator: None,
1657+
ready_unit: None,
1658+
}
1659+
}
1660+
1661+
fn put_self(&mut self) {
1662+
let last_bb = &mut self.bbs[self.last_bb];
1663+
debug_assert!(last_bb.terminator.is_none());
1664+
1665+
let self_ptr = Local::new(1);
1666+
// We need to create a new local to be able to "consume" it with
1667+
// a combinator
1668+
let local = self.locals.push(self.locals[self_ptr].clone());
1669+
last_bb.statements.extend_from_slice(&[
1670+
Statement { source_info: self.source_info, kind: StatementKind::StorageLive(local) },
1671+
Statement {
1672+
source_info: self.source_info,
1673+
kind: StatementKind::Assign(Box::new((
1674+
local.into(),
1675+
Rvalue::Use(Operand::Copy(self_ptr.into())),
1676+
))),
1677+
},
1678+
]);
1679+
1680+
// We use pointee types so that they are used for instantiation
1681+
// of combinators
1682+
self.stack.push((local, self.self_ty));
1683+
}
1684+
1685+
fn put_field(&mut self, field: FieldIdx, ty: Ty<'tcx>) {
1686+
let last_bb = &mut self.bbs[self.last_bb];
1687+
debug_assert!(last_bb.terminator.is_none());
1688+
1689+
// We need to create a new local to be able to "consume" it with
1690+
// a combinator
1691+
let local = self.locals.push(LocalDecl::with_source_info(ty, self.source_info));
1692+
let self_ptr = Local::new(1);
1693+
last_bb.statements.extend_from_slice(&[
1694+
Statement { source_info: self.source_info, kind: StatementKind::StorageLive(local) },
1695+
Statement {
1696+
source_info: self.source_info,
1697+
kind: StatementKind::Assign(Box::new((
1698+
local.into(),
1699+
Rvalue::AddressOf(
1700+
Mutability::Mut,
1701+
self.tcx.mk_place_field(
1702+
self.tcx.mk_place_deref(self_ptr.into()),
1703+
field,
1704+
ty,
1705+
),
1706+
),
1707+
))),
1708+
},
1709+
]);
1710+
1711+
// We use pointee types so that they are used for instantiation
1712+
// of combinators
1713+
self.stack.push((local, ty));
1714+
}
1715+
1716+
fn ready_unit(&mut self) {
1717+
let tcx = self.tcx;
1718+
let (function, ty) = *self.ready_unit.get_or_insert_with(|| {
1719+
(
1720+
tcx.require_lang_item(LangItem::FutureReadyUnitCtor, Some(self.span)),
1721+
tcx.type_of(tcx.require_lang_item(LangItem::FutureReadyUnit, Some(self.span)))
1722+
.instantiate_identity(),
1723+
)
1724+
});
1725+
self.apply_combinator::<0, _>(function, |_, _| ty)
1726+
}
1727+
1728+
fn chain(&mut self) {
1729+
let tcx = self.tcx;
1730+
let (function, ty) = *self.chain_combinator.get_or_insert_with(|| {
1731+
(
1732+
tcx.require_lang_item(LangItem::FutureChainCtor, Some(self.span)),
1733+
tcx.type_of(tcx.require_lang_item(LangItem::FutureChain, Some(self.span))),
1734+
)
1735+
});
1736+
self.apply_combinator::<2, _>(function, |tcx, args| ty.instantiate(tcx, args))
1737+
}
1738+
1739+
fn return_(mut self) -> Body<'tcx> {
1740+
let last_bb = &mut self.bbs[self.last_bb];
1741+
debug_assert!(last_bb.terminator.is_none());
1742+
1743+
let &[(output_local, _)] = self.stack.as_slice() else {
1744+
span_bug!(
1745+
self.span,
1746+
"async destructor ctor shim builder finished with invalid number of stack items: expected 1 found {}",
1747+
self.stack.len(),
1748+
)
1749+
};
1750+
let return_local = Local::new(0);
1751+
1752+
last_bb.statements.extend_from_slice(&[
1753+
Statement {
1754+
source_info: self.source_info,
1755+
kind: StatementKind::Assign(Box::new((
1756+
return_local.into(),
1757+
Rvalue::Use(Operand::Move(output_local.into())),
1758+
))),
1759+
},
1760+
Statement {
1761+
source_info: self.source_info,
1762+
kind: StatementKind::StorageDead(output_local),
1763+
},
1764+
]);
1765+
last_bb.terminator =
1766+
Some(Terminator { source_info: self.source_info, kind: TerminatorKind::Return });
1767+
1768+
let source = MirSource::from_instance(ty::InstanceDef::AsyncDropGlueCtorShim(
1769+
self.def_id,
1770+
self.self_ty,
1771+
));
1772+
new_body(source, self.bbs, self.locals, ASYNC_DESTRUCTOR_CTOR_ARG_COUNT, self.span)
1773+
}
1774+
1775+
fn apply_combinator<const ARITY: usize, F>(&mut self, function: DefId, mut ty_combinator: F)
1776+
where
1777+
F: FnMut(TyCtxt<'tcx>, &[ty::GenericArg<'tcx>]) -> Ty<'tcx>,
1778+
{
1779+
let operands = self
1780+
.stack
1781+
.last_chunk::<ARITY>()
1782+
.expect("async destructor ctor shim combinator tried to consume too many items");
1783+
1784+
let generic_args = operands.each_ref().map(|&(_, t)| t.into());
1785+
let dest_ty = ty_combinator(self.tcx, &generic_args);
1786+
1787+
let target = self.bbs.push(BasicBlockData {
1788+
statements: {
1789+
let mut stmts = Vec::with_capacity(ARITY + 1);
1790+
stmts.extend(operands.iter().map(|&(l, _)| Statement {
1791+
source_info: self.source_info,
1792+
kind: StatementKind::StorageDead(l),
1793+
}));
1794+
stmts
1795+
},
1796+
terminator: None,
1797+
is_cleanup: false,
1798+
});
1799+
1800+
let args =
1801+
operands.iter().map(|&(l, _)| respan(self.span, Operand::Move(l.into()))).collect();
1802+
let dest =
1803+
self.locals.push(LocalDecl::with_source_info(dest_ty, self.source_info).immutable());
1804+
1805+
let last_bb = &mut self.bbs[self.last_bb];
1806+
debug_assert!(last_bb.terminator.is_none());
1807+
last_bb.statements.push(Statement {
1808+
source_info: self.source_info,
1809+
kind: StatementKind::StorageLive(dest),
1810+
});
1811+
last_bb.terminator = Some(Terminator {
1812+
source_info: self.source_info,
1813+
kind: TerminatorKind::Call {
1814+
func: Operand::function_handle(self.tcx, function, generic_args, self.span),
1815+
args,
1816+
destination: dest.into(),
1817+
target: Some(target),
1818+
// TODO: Figure out unwind (even tho they shouldn't panic?)
1819+
unwind: UnwindAction::Continue,
1820+
call_source: CallSource::Misc,
1821+
fn_span: self.span,
1822+
},
1823+
});
1824+
1825+
drop(self.stack.drain(self.stack.len() - ARITY..));
1826+
self.stack.push((dest, dest_ty));
1827+
self.last_bb = target;
1828+
}
1829+
}
1830+
16101831
fn build_construct_coroutine_by_move_shim<'tcx>(
16111832
tcx: TyCtxt<'tcx>,
16121833
coroutine_closure_def_id: DefId,

0 commit comments

Comments
 (0)