Skip to content

Commit 73150c3

Browse files
committed
fix: wrap method call exprs in parens
1 parent bce4be9 commit 73150c3

File tree

1 file changed

+48
-0
lines changed

1 file changed

+48
-0
lines changed

crates/ide-assists/src/handlers/bool_to_enum.rs

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -274,6 +274,11 @@ fn replace_usages(
274274
{
275275
edit.replace(ty_annotation.syntax().text_range(), "Bool");
276276
replace_bool_expr(edit, initializer);
277+
} else if let Some(receiver) = find_method_call_expr_usage(&new_name) {
278+
edit.replace(
279+
receiver.syntax().text_range(),
280+
format!("({} == Bool::True)", receiver),
281+
);
277282
} else if new_name.syntax().ancestors().find_map(ast::UseTree::cast).is_none() {
278283
// for any other usage in an expression, replace it with a check that it is the true variant
279284
if let Some((record_field, expr)) = new_name
@@ -426,6 +431,17 @@ fn find_assoc_const_usage(name: &ast::NameLike) -> Option<(ast::Type, ast::Expr)
426431
Some((const_.ty()?, const_.body()?))
427432
}
428433

434+
fn find_method_call_expr_usage(name: &ast::NameLike) -> Option<ast::Expr> {
435+
let method_call = name.syntax().ancestors().find_map(ast::MethodCallExpr::cast)?;
436+
let receiver = method_call.receiver()?;
437+
438+
if !receiver.syntax().descendants().contains(name.syntax()) {
439+
return None;
440+
}
441+
442+
Some(receiver)
443+
}
444+
429445
/// Adds the definition of the new enum before the target node.
430446
fn add_enum_def(
431447
edit: &mut SourceChangeBuilder,
@@ -1287,6 +1303,38 @@ fn main() {
12871303
)
12881304
}
12891305

1306+
#[test]
1307+
fn field_method_chain_usage() {
1308+
check_assist(
1309+
bool_to_enum,
1310+
r#"
1311+
struct Foo {
1312+
$0bool: bool,
1313+
}
1314+
1315+
fn main() {
1316+
let foo = Foo { bool: true };
1317+
1318+
foo.bool.then(|| 2);
1319+
}
1320+
"#,
1321+
r#"
1322+
#[derive(PartialEq, Eq)]
1323+
enum Bool { True, False }
1324+
1325+
struct Foo {
1326+
bool: Bool,
1327+
}
1328+
1329+
fn main() {
1330+
let foo = Foo { bool: Bool::True };
1331+
1332+
(foo.bool == Bool::True).then(|| 2);
1333+
}
1334+
"#,
1335+
)
1336+
}
1337+
12901338
#[test]
12911339
fn field_non_bool() {
12921340
cov_mark::check!(not_applicable_non_bool_field);

0 commit comments

Comments
 (0)