Skip to content

Commit 9973b11

Browse files
committed
Auto merge of rust-lang#15012 - lowr:patch/generate-fn-async-ret-ty, r=HKalbasi
Infer return type for async function in `generate_function` Part of rust-lang#10122 In `generate_function` assist, when we infer the return type of async function we're generating, we should retrieve the type of parent await expression rather than the call expression itself.
2 parents 9c03aa1 + 32768fe commit 9973b11

File tree

1 file changed

+84
-18
lines changed

1 file changed

+84
-18
lines changed

crates/ide-assists/src/handlers/generate_function.rs

Lines changed: 84 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -291,12 +291,9 @@ impl FunctionBuilder {
291291
let await_expr = call.syntax().parent().and_then(ast::AwaitExpr::cast);
292292
let is_async = await_expr.is_some();
293293

294-
let (ret_type, should_focus_return_type) = make_return_type(
295-
ctx,
296-
&ast::Expr::CallExpr(call.clone()),
297-
target_module,
298-
&mut necessary_generic_params,
299-
);
294+
let expr_for_ret_ty = await_expr.map_or_else(|| call.clone().into(), |it| it.into());
295+
let (ret_type, should_focus_return_type) =
296+
make_return_type(ctx, &expr_for_ret_ty, target_module, &mut necessary_generic_params);
300297

301298
let (generic_param_list, where_clause) =
302299
fn_generic_params(ctx, necessary_generic_params, &target)?;
@@ -338,12 +335,9 @@ impl FunctionBuilder {
338335
let await_expr = call.syntax().parent().and_then(ast::AwaitExpr::cast);
339336
let is_async = await_expr.is_some();
340337

341-
let (ret_type, should_focus_return_type) = make_return_type(
342-
ctx,
343-
&ast::Expr::MethodCallExpr(call.clone()),
344-
target_module,
345-
&mut necessary_generic_params,
346-
);
338+
let expr_for_ret_ty = await_expr.map_or_else(|| call.clone().into(), |it| it.into());
339+
let (ret_type, should_focus_return_type) =
340+
make_return_type(ctx, &expr_for_ret_ty, target_module, &mut necessary_generic_params);
347341

348342
let (generic_param_list, where_clause) =
349343
fn_generic_params(ctx, necessary_generic_params, &target)?;
@@ -429,12 +423,12 @@ impl FunctionBuilder {
429423
/// user can change the `todo!` function body.
430424
fn make_return_type(
431425
ctx: &AssistContext<'_>,
432-
call: &ast::Expr,
426+
expr: &ast::Expr,
433427
target_module: Module,
434428
necessary_generic_params: &mut FxHashSet<hir::GenericParam>,
435429
) -> (Option<ast::RetType>, bool) {
436430
let (ret_ty, should_focus_return_type) = {
437-
match ctx.sema.type_of_expr(call).map(TypeInfo::original) {
431+
match ctx.sema.type_of_expr(expr).map(TypeInfo::original) {
438432
Some(ty) if ty.is_unknown() => (Some(make::ty_placeholder()), true),
439433
None => (Some(make::ty_placeholder()), true),
440434
Some(ty) if ty.is_unit() => (None, false),
@@ -2268,13 +2262,13 @@ impl Foo {
22682262
check_assist(
22692263
generate_function,
22702264
r"
2271-
fn foo() {
2272-
$0bar(42).await();
2265+
async fn foo() {
2266+
$0bar(42).await;
22732267
}
22742268
",
22752269
r"
2276-
fn foo() {
2277-
bar(42).await();
2270+
async fn foo() {
2271+
bar(42).await;
22782272
}
22792273
22802274
async fn bar(arg: i32) ${0:-> _} {
@@ -2284,6 +2278,28 @@ async fn bar(arg: i32) ${0:-> _} {
22842278
)
22852279
}
22862280

2281+
#[test]
2282+
fn return_type_for_async_fn() {
2283+
check_assist(
2284+
generate_function,
2285+
r"
2286+
//- minicore: result
2287+
async fn foo() {
2288+
if Err(()) = $0bar(42).await {}
2289+
}
2290+
",
2291+
r"
2292+
async fn foo() {
2293+
if Err(()) = bar(42).await {}
2294+
}
2295+
2296+
async fn bar(arg: i32) -> Result<_, ()> {
2297+
${0:todo!()}
2298+
}
2299+
",
2300+
);
2301+
}
2302+
22872303
#[test]
22882304
fn create_method() {
22892305
check_assist(
@@ -2401,6 +2417,31 @@ fn foo() {S.bar();}
24012417
)
24022418
}
24032419

2420+
#[test]
2421+
fn create_async_method() {
2422+
check_assist(
2423+
generate_function,
2424+
r"
2425+
//- minicore: result
2426+
struct S;
2427+
async fn foo() {
2428+
if let Err(()) = S.$0bar(42).await {}
2429+
}
2430+
",
2431+
r"
2432+
struct S;
2433+
impl S {
2434+
async fn bar(&self, arg: i32) -> Result<_, ()> {
2435+
${0:todo!()}
2436+
}
2437+
}
2438+
async fn foo() {
2439+
if let Err(()) = S.bar(42).await {}
2440+
}
2441+
",
2442+
)
2443+
}
2444+
24042445
#[test]
24052446
fn create_static_method() {
24062447
check_assist(
@@ -2421,6 +2462,31 @@ fn foo() {S::bar();}
24212462
)
24222463
}
24232464

2465+
#[test]
2466+
fn create_async_static_method() {
2467+
check_assist(
2468+
generate_function,
2469+
r"
2470+
//- minicore: result
2471+
struct S;
2472+
async fn foo() {
2473+
if let Err(()) = S::$0bar(42).await {}
2474+
}
2475+
",
2476+
r"
2477+
struct S;
2478+
impl S {
2479+
async fn bar(arg: i32) -> Result<_, ()> {
2480+
${0:todo!()}
2481+
}
2482+
}
2483+
async fn foo() {
2484+
if let Err(()) = S::bar(42).await {}
2485+
}
2486+
",
2487+
)
2488+
}
2489+
24242490
#[test]
24252491
fn create_generic_static_method() {
24262492
check_assist(

0 commit comments

Comments
 (0)