Skip to content

Commit b05d5db

Browse files
committed
Ensure that drop order of async fn matches fn.
This commit modifies the lowering of `async fn` arguments so that the drop order matches the equivalent `fn`. Previously, async function arguments were lowered as shown below: async fn foo(<pattern>: <ty>) { async move { } } // <-- dropped as you "exit" the fn // ...becomes... fn foo(__arg0: <ty>) { async move { let <pattern> = __arg0; } // <-- dropped as you "exit" the async block } After this PR, async function arguments will be lowered as: async fn foo(<pattern>: <ty>, <pattern>: <ty>, <pattern>: <ty>) { async move { } } // <-- dropped as you "exit" the fn // ...becomes... fn foo(__arg0: <ty>, __arg1: <ty>, __arg2: <ty>) { async move { let __arg2 = __arg2; let <pattern> = __arg2; let __arg1 = __arg1; let <pattern> = __arg1; let __arg0 = __arg0; let <pattern> = __arg0; } // <-- dropped as you "exit" the async block } If `<pattern>` is a simple ident, then it is lowered to a single `let <pattern> = <pattern>;` statement as an optimization.
1 parent 47e0803 commit b05d5db

File tree

10 files changed

+401
-221
lines changed

10 files changed

+401
-221
lines changed

src/librustc/hir/lowering.rs

+36-3
Original file line numberDiff line numberDiff line change
@@ -2996,8 +2996,33 @@ impl<'a> LoweringContext<'a> {
29962996
if let IsAsync::Async { closure_id, ref arguments, .. } = asyncness {
29972997
let mut body = body.clone();
29982998

2999+
// Async function arguments are lowered into the closure body so that they are
3000+
// captured and so that the drop order matches the equivalent non-async functions.
3001+
//
3002+
// async fn foo(<pattern>: <ty>, <pattern>: <ty>, <pattern>: <ty>) {
3003+
// async move {
3004+
// }
3005+
// }
3006+
//
3007+
// // ...becomes...
3008+
// fn foo(__arg0: <ty>, __arg1: <ty>, __arg2: <ty>) {
3009+
// async move {
3010+
// let __arg2 = __arg2;
3011+
// let <pattern> = __arg2;
3012+
// let __arg1 = __arg1;
3013+
// let <pattern> = __arg1;
3014+
// let __arg0 = __arg0;
3015+
// let <pattern> = __arg0;
3016+
// }
3017+
// }
3018+
//
3019+
// If `<pattern>` is a simple ident, then it is lowered to a single
3020+
// `let <pattern> = <pattern>;` statement as an optimization.
29993021
for a in arguments.iter().rev() {
3000-
body.stmts.insert(0, a.stmt.clone());
3022+
if let Some(pat_stmt) = a.pat_stmt.clone() {
3023+
body.stmts.insert(0, pat_stmt);
3024+
}
3025+
body.stmts.insert(0, a.move_stmt.clone());
30013026
}
30023027

30033028
let async_expr = this.make_async_expr(
@@ -3093,7 +3118,11 @@ impl<'a> LoweringContext<'a> {
30933118
let mut decl = decl.clone();
30943119
// Replace the arguments of this async function with the generated
30953120
// arguments that will be moved into the closure.
3096-
decl.inputs = arguments.clone().drain(..).map(|a| a.arg).collect();
3121+
for (i, a) in arguments.clone().drain(..).enumerate() {
3122+
if let Some(arg) = a.arg {
3123+
decl.inputs[i] = arg;
3124+
}
3125+
}
30973126
lower_fn(&decl)
30983127
} else {
30993128
lower_fn(decl)
@@ -3590,7 +3619,11 @@ impl<'a> LoweringContext<'a> {
35903619
let mut sig = sig.clone();
35913620
// Replace the arguments of this async function with the generated
35923621
// arguments that will be moved into the closure.
3593-
sig.decl.inputs = arguments.clone().drain(..).map(|a| a.arg).collect();
3622+
for (i, a) in arguments.clone().drain(..).enumerate() {
3623+
if let Some(arg) = a.arg {
3624+
sig.decl.inputs[i] = arg;
3625+
}
3626+
}
35943627
lower_method(&sig)
35953628
} else {
35963629
lower_method(sig)

src/librustc/hir/map/def_collector.rs

+9-4
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,9 @@ impl<'a> DefCollector<'a> {
9494
// Walk the generated arguments for the `async fn`.
9595
for a in arguments {
9696
use visit::Visitor;
97-
this.visit_ty(&a.arg.ty);
97+
if let Some(arg) = &a.arg {
98+
this.visit_ty(&arg.ty);
99+
}
98100
}
99101

100102
// We do not invoke `walk_fn_decl` as this will walk the arguments that are being
@@ -105,10 +107,13 @@ impl<'a> DefCollector<'a> {
105107
*closure_id, DefPathData::ClosureExpr, REGULAR_SPACE, span,
106108
);
107109
this.with_parent(closure_def, |this| {
110+
use visit::Visitor;
111+
// Walk each of the generated statements before the regular block body.
108112
for a in arguments {
109-
use visit::Visitor;
110-
// Walk each of the generated statements before the regular block body.
111-
this.visit_stmt(&a.stmt);
113+
this.visit_stmt(&a.move_stmt);
114+
if let Some(pat_stmt) = &a.pat_stmt {
115+
this.visit_stmt(&pat_stmt);
116+
}
112117
}
113118

114119
visit::walk_block(this, &body);

src/librustc/lint/context.rs

+10-5
Original file line numberDiff line numberDiff line change
@@ -1334,14 +1334,19 @@ impl<'a, T: EarlyLintPass> ast_visit::Visitor<'a> for EarlyContextAndPass<'a, T>
13341334
if let ast::IsAsync::Async { ref arguments, .. } = header.asyncness.node {
13351335
for a in arguments {
13361336
// Visit the argument..
1337-
self.visit_pat(&a.arg.pat);
1338-
if let ast::ArgSource::AsyncFn(pat) = &a.arg.source {
1339-
self.visit_pat(pat);
1337+
if let Some(arg) = &a.arg {
1338+
self.visit_pat(&arg.pat);
1339+
if let ast::ArgSource::AsyncFn(pat) = &arg.source {
1340+
self.visit_pat(pat);
1341+
}
1342+
self.visit_ty(&arg.ty);
13401343
}
1341-
self.visit_ty(&a.arg.ty);
13421344

13431345
// ..and the statement.
1344-
self.visit_stmt(&a.stmt);
1346+
self.visit_stmt(&a.move_stmt);
1347+
if let Some(pat_stmt) = &a.pat_stmt {
1348+
self.visit_stmt(&pat_stmt);
1349+
}
13451350
}
13461351
}
13471352
}

src/librustc_resolve/lib.rs

+12-3
Original file line numberDiff line numberDiff line change
@@ -862,7 +862,13 @@ impl<'a, 'tcx> Visitor<'tcx> for Resolver<'a> {
862862
// Walk the generated async arguments if this is an `async fn`, otherwise walk the
863863
// normal arguments.
864864
if let IsAsync::Async { ref arguments, .. } = asyncness {
865-
for a in arguments { add_argument(&a.arg); }
865+
for (i, a) in arguments.iter().enumerate() {
866+
if let Some(arg) = &a.arg {
867+
add_argument(&arg);
868+
} else {
869+
add_argument(&declaration.inputs[i]);
870+
}
871+
}
866872
} else {
867873
for a in &declaration.inputs { add_argument(a); }
868874
}
@@ -882,8 +888,11 @@ impl<'a, 'tcx> Visitor<'tcx> for Resolver<'a> {
882888
let mut body = body.clone();
883889
// Insert the generated statements into the body before attempting to
884890
// resolve names.
885-
for a in arguments {
886-
body.stmts.insert(0, a.stmt.clone());
891+
for a in arguments.iter().rev() {
892+
if let Some(pat_stmt) = a.pat_stmt.clone() {
893+
body.stmts.insert(0, pat_stmt);
894+
}
895+
body.stmts.insert(0, a.move_stmt.clone());
887896
}
888897
self.visit_block(&body);
889898
} else {

src/libsyntax/ast.rs

+8-4
Original file line numberDiff line numberDiff line change
@@ -1865,10 +1865,14 @@ pub enum Unsafety {
18651865
pub struct AsyncArgument {
18661866
/// `__arg0`
18671867
pub ident: Ident,
1868-
/// `__arg0: <ty>` argument to replace existing function argument `<pat>: <ty>`.
1869-
pub arg: Arg,
1870-
/// `let <pat>: <ty> = __arg0;` statement to be inserted at the start of the block.
1871-
pub stmt: Stmt,
1868+
/// `__arg0: <ty>` argument to replace existing function argument `<pat>: <ty>`. Only if
1869+
/// argument is not a simple binding.
1870+
pub arg: Option<Arg>,
1871+
/// `let __arg0 = __arg0;` statement to be inserted at the start of the block.
1872+
pub move_stmt: Stmt,
1873+
/// `let <pat> = __arg0;` statement to be inserted at the start of the block, after matching
1874+
/// move statement. Only if argument is not a simple binding.
1875+
pub pat_stmt: Option<Stmt>,
18721876
}
18731877

18741878
#[derive(Clone, RustcEncodable, RustcDecodable, Debug)]

src/libsyntax/ext/placeholders.rs

+4-1
Original file line numberDiff line numberDiff line change
@@ -199,7 +199,10 @@ impl<'a, 'b> MutVisitor for PlaceholderExpander<'a, 'b> {
199199

200200
if let ast::IsAsync::Async { ref mut arguments, .. } = a {
201201
for argument in arguments.iter_mut() {
202-
self.next_id(&mut argument.stmt.id);
202+
self.next_id(&mut argument.move_stmt.id);
203+
if let Some(ref mut pat_stmt) = &mut argument.pat_stmt {
204+
self.next_id(&mut pat_stmt.id);
205+
}
203206
}
204207
}
205208
}

src/libsyntax/mut_visit.rs

+11-3
Original file line numberDiff line numberDiff line change
@@ -694,13 +694,21 @@ pub fn noop_visit_asyncness<T: MutVisitor>(asyncness: &mut IsAsync, vis: &mut T)
694694
IsAsync::Async { closure_id, return_impl_trait_id, ref mut arguments } => {
695695
vis.visit_id(closure_id);
696696
vis.visit_id(return_impl_trait_id);
697-
for AsyncArgument { ident, arg, stmt } in arguments.iter_mut() {
697+
for AsyncArgument { ident, arg, pat_stmt, move_stmt } in arguments.iter_mut() {
698698
vis.visit_ident(ident);
699-
vis.visit_arg(arg);
700-
visit_clobber(stmt, |stmt| {
699+
if let Some(arg) = arg {
700+
vis.visit_arg(arg);
701+
}
702+
visit_clobber(move_stmt, |stmt| {
701703
vis.flat_map_stmt(stmt)
702704
.expect_one("expected visitor to produce exactly one item")
703705
});
706+
visit_opt(pat_stmt, |stmt| {
707+
visit_clobber(stmt, |stmt| {
708+
vis.flat_map_stmt(stmt)
709+
.expect_one("expected visitor to produce exactly one item")
710+
})
711+
});
704712
}
705713
}
706714
IsAsync::NotAsync => {}

src/libsyntax/parse/parser.rs

+48-14
Original file line numberDiff line numberDiff line change
@@ -8880,25 +8880,44 @@ impl<'a> Parser<'a> {
88808880
let name = format!("__arg{}", index);
88818881
let ident = Ident::from_str(&name);
88828882

8883+
// Check if this is a ident pattern, if so, we can optimize and avoid adding a
8884+
// `let <pat> = __argN;` statement, instead just adding a `let <pat> = <pat>;`
8885+
// statement.
8886+
let (ident, is_simple_pattern) = match input.pat.node {
8887+
PatKind::Ident(_, ident, _) => (ident, true),
8888+
_ => (ident, false),
8889+
};
8890+
88838891
// Construct an argument representing `__argN: <ty>` to replace the argument of the
8884-
// async function.
8885-
let arg = Arg {
8886-
ty: input.ty.clone(),
8887-
id,
8892+
// async function if it isn't a simple pattern.
8893+
let arg = if is_simple_pattern {
8894+
None
8895+
} else {
8896+
Some(Arg {
8897+
ty: input.ty.clone(),
8898+
id,
8899+
pat: P(Pat {
8900+
id,
8901+
node: PatKind::Ident(
8902+
BindingMode::ByValue(Mutability::Immutable), ident, None,
8903+
),
8904+
span,
8905+
}),
8906+
source: ArgSource::AsyncFn(input.pat.clone()),
8907+
})
8908+
};
8909+
8910+
// Construct a `let __argN = __argN;` statement to insert at the top of the
8911+
// async closure. This makes sure that the argument is captured by the closure and
8912+
// that the drop order is correct.
8913+
let move_local = Local {
88888914
pat: P(Pat {
88898915
id,
88908916
node: PatKind::Ident(
88918917
BindingMode::ByValue(Mutability::Immutable), ident, None,
88928918
),
88938919
span,
88948920
}),
8895-
source: ArgSource::AsyncFn(input.pat.clone()),
8896-
};
8897-
8898-
// Construct a `let <pat> = __argN;` statement to insert at the top of the
8899-
// async closure.
8900-
let local = P(Local {
8901-
pat: input.pat.clone(),
89028921
// We explicitly do not specify the type for this statement. When the user's
89038922
// argument type is `impl Trait` then this would require the
89048923
// `impl_trait_in_bindings` feature to also be present for that same type to
@@ -8918,10 +8937,25 @@ impl<'a> Parser<'a> {
89188937
span,
89198938
attrs: ThinVec::new(),
89208939
source: LocalSource::AsyncFn,
8921-
});
8922-
let stmt = Stmt { id, node: StmtKind::Local(local), span, };
8940+
};
8941+
8942+
// Construct a `let <pat> = __argN;` statement to insert at the top of the
8943+
// async closure if this isn't a simple pattern.
8944+
let pat_stmt = if is_simple_pattern {
8945+
None
8946+
} else {
8947+
Some(Stmt {
8948+
id,
8949+
node: StmtKind::Local(P(Local {
8950+
pat: input.pat.clone(),
8951+
..move_local.clone()
8952+
})),
8953+
span,
8954+
})
8955+
};
89238956

8924-
arguments.push(AsyncArgument { ident, arg, stmt });
8957+
let move_stmt = Stmt { id, node: StmtKind::Local(P(move_local)), span };
8958+
arguments.push(AsyncArgument { ident, arg, pat_stmt, move_stmt });
89258959
}
89268960
}
89278961
}

0 commit comments

Comments
 (0)