Skip to content

Commit 87e2c31

Browse files
committed
Auto merge of #15667 - rmehri01:bool_to_enum_top_level, r=Veykril
fix: make bool_to_enum assist create enum at top-level This pr makes the `bool_to_enum` assist create the `enum` at the next closest module block or at top-level, which fixes a few tricky cases such as with an associated `const` in a trait or module: ```rust trait Foo { const $0BOOL: bool; } impl Foo for usize { const BOOL: bool = true; } fn main() { if <usize as Foo>::BOOL { println!("foo"); } } ``` Which now properly produces: ```rust #[derive(PartialEq, Eq)] enum Bool { True, False } trait Foo { const BOOL: Bool; } impl Foo for usize { const BOOL: Bool = Bool::True; } fn main() { if <usize as Foo>::BOOL == Bool::True { println!("foo"); } } ``` I also think it's a bit nicer, especially for local variables, but didn't really know to do it in the first PR :)
2 parents f19479a + 1b3e5b2 commit 87e2c31

File tree

2 files changed

+194
-36
lines changed

2 files changed

+194
-36
lines changed

crates/ide-assists/src/handlers/bool_to_enum.rs

+191-33
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ use syntax::{
1616
edit_in_place::{AttrsOwnerEdit, Indent},
1717
make, HasName,
1818
},
19-
ted, AstNode, NodeOrToken, SyntaxNode, T,
19+
ted, AstNode, NodeOrToken, SyntaxKind, SyntaxNode, T,
2020
};
2121
use text_edit::TextRange;
2222

@@ -40,10 +40,10 @@ use crate::assist_context::{AssistContext, Assists};
4040
// ```
4141
// ->
4242
// ```
43-
// fn main() {
44-
// #[derive(PartialEq, Eq)]
45-
// enum Bool { True, False }
43+
// #[derive(PartialEq, Eq)]
44+
// enum Bool { True, False }
4645
//
46+
// fn main() {
4747
// let bool = Bool::True;
4848
//
4949
// if bool == Bool::True {
@@ -270,6 +270,15 @@ fn replace_usages(
270270
}
271271
_ => (),
272272
}
273+
} else if let Some((ty_annotation, initializer)) = find_assoc_const_usage(&new_name)
274+
{
275+
edit.replace(ty_annotation.syntax().text_range(), "Bool");
276+
replace_bool_expr(edit, initializer);
277+
} else if let Some(receiver) = find_method_call_expr_usage(&new_name) {
278+
edit.replace(
279+
receiver.syntax().text_range(),
280+
format!("({} == Bool::True)", receiver),
281+
);
273282
} else if new_name.syntax().ancestors().find_map(ast::UseTree::cast).is_none() {
274283
// for any other usage in an expression, replace it with a check that it is the true variant
275284
if let Some((record_field, expr)) = new_name
@@ -413,6 +422,26 @@ fn find_record_pat_field_usage(name: &ast::NameLike) -> Option<ast::Pat> {
413422
}
414423
}
415424

425+
fn find_assoc_const_usage(name: &ast::NameLike) -> Option<(ast::Type, ast::Expr)> {
426+
let const_ = name.syntax().parent().and_then(ast::Const::cast)?;
427+
if const_.syntax().parent().and_then(ast::AssocItemList::cast).is_none() {
428+
return None;
429+
}
430+
431+
Some((const_.ty()?, const_.body()?))
432+
}
433+
434+
fn find_method_call_expr_usage(name: &ast::NameLike) -> Option<ast::Expr> {
435+
let method_call = name.syntax().ancestors().find_map(ast::MethodCallExpr::cast)?;
436+
let receiver = method_call.receiver()?;
437+
438+
if !receiver.syntax().descendants().contains(name.syntax()) {
439+
return None;
440+
}
441+
442+
Some(receiver)
443+
}
444+
416445
/// Adds the definition of the new enum before the target node.
417446
fn add_enum_def(
418447
edit: &mut SourceChangeBuilder,
@@ -430,18 +459,31 @@ fn add_enum_def(
430459
.any(|module| module.nearest_non_block_module(ctx.db()) != *target_module);
431460
let enum_def = make_bool_enum(make_enum_pub);
432461

433-
let indent = IndentLevel::from_node(&target_node);
462+
let insert_before = node_to_insert_before(target_node);
463+
let indent = IndentLevel::from_node(&insert_before);
434464
enum_def.reindent_to(indent);
435465

436466
ted::insert_all(
437-
ted::Position::before(&edit.make_syntax_mut(target_node)),
467+
ted::Position::before(&edit.make_syntax_mut(insert_before)),
438468
vec![
439469
enum_def.syntax().clone().into(),
440470
make::tokens::whitespace(&format!("\n\n{indent}")).into(),
441471
],
442472
);
443473
}
444474

475+
/// Finds where to put the new enum definition.
476+
/// Tries to find the ast node at the nearest module or at top-level, otherwise just
477+
/// returns the input node.
478+
fn node_to_insert_before(target_node: SyntaxNode) -> SyntaxNode {
479+
target_node
480+
.ancestors()
481+
.take_while(|it| !matches!(it.kind(), SyntaxKind::MODULE | SyntaxKind::SOURCE_FILE))
482+
.filter(|it| ast::Item::can_cast(it.kind()))
483+
.last()
484+
.unwrap_or(target_node)
485+
}
486+
445487
fn make_bool_enum(make_pub: bool) -> ast::Enum {
446488
let enum_def = make::enum_(
447489
if make_pub { Some(make::visibility_pub()) } else { None },
@@ -491,10 +533,10 @@ fn main() {
491533
}
492534
"#,
493535
r#"
494-
fn main() {
495-
#[derive(PartialEq, Eq)]
496-
enum Bool { True, False }
536+
#[derive(PartialEq, Eq)]
537+
enum Bool { True, False }
497538
539+
fn main() {
498540
let foo = Bool::True;
499541
500542
if foo == Bool::True {
@@ -520,10 +562,10 @@ fn main() {
520562
}
521563
"#,
522564
r#"
523-
fn main() {
524-
#[derive(PartialEq, Eq)]
525-
enum Bool { True, False }
565+
#[derive(PartialEq, Eq)]
566+
enum Bool { True, False }
526567
568+
fn main() {
527569
let foo = Bool::True;
528570
529571
if foo == Bool::False {
@@ -545,10 +587,10 @@ fn main() {
545587
}
546588
"#,
547589
r#"
548-
fn main() {
549-
#[derive(PartialEq, Eq)]
550-
enum Bool { True, False }
590+
#[derive(PartialEq, Eq)]
591+
enum Bool { True, False }
551592
593+
fn main() {
552594
let foo: Bool = Bool::False;
553595
}
554596
"#,
@@ -565,10 +607,10 @@ fn main() {
565607
}
566608
"#,
567609
r#"
568-
fn main() {
569-
#[derive(PartialEq, Eq)]
570-
enum Bool { True, False }
610+
#[derive(PartialEq, Eq)]
611+
enum Bool { True, False }
571612
613+
fn main() {
572614
let foo = if 1 == 2 { Bool::True } else { Bool::False };
573615
}
574616
"#,
@@ -590,10 +632,10 @@ fn main() {
590632
}
591633
"#,
592634
r#"
593-
fn main() {
594-
#[derive(PartialEq, Eq)]
595-
enum Bool { True, False }
635+
#[derive(PartialEq, Eq)]
636+
enum Bool { True, False }
596637
638+
fn main() {
597639
let foo = Bool::False;
598640
let bar = true;
599641
@@ -619,10 +661,10 @@ fn main() {
619661
}
620662
"#,
621663
r#"
622-
fn main() {
623-
#[derive(PartialEq, Eq)]
624-
enum Bool { True, False }
664+
#[derive(PartialEq, Eq)]
665+
enum Bool { True, False }
625666
667+
fn main() {
626668
let foo = Bool::True;
627669
628670
if *&foo == Bool::True {
@@ -645,10 +687,10 @@ fn main() {
645687
}
646688
"#,
647689
r#"
648-
fn main() {
649-
#[derive(PartialEq, Eq)]
650-
enum Bool { True, False }
690+
#[derive(PartialEq, Eq)]
691+
enum Bool { True, False }
651692
693+
fn main() {
652694
let foo: Bool;
653695
foo = Bool::True;
654696
}
@@ -671,10 +713,10 @@ fn main() {
671713
}
672714
"#,
673715
r#"
674-
fn main() {
675-
#[derive(PartialEq, Eq)]
676-
enum Bool { True, False }
716+
#[derive(PartialEq, Eq)]
717+
enum Bool { True, False }
677718
719+
fn main() {
678720
let foo = Bool::True;
679721
let bar = foo == Bool::False;
680722
@@ -702,11 +744,11 @@ fn main() {
702744
}
703745
"#,
704746
r#"
747+
#[derive(PartialEq, Eq)]
748+
enum Bool { True, False }
749+
705750
fn main() {
706751
if !"foo".chars().any(|c| {
707-
#[derive(PartialEq, Eq)]
708-
enum Bool { True, False }
709-
710752
let foo = Bool::True;
711753
foo == Bool::True
712754
}) {
@@ -1244,6 +1286,38 @@ fn main() {
12441286
)
12451287
}
12461288

1289+
#[test]
1290+
fn field_method_chain_usage() {
1291+
check_assist(
1292+
bool_to_enum,
1293+
r#"
1294+
struct Foo {
1295+
$0bool: bool,
1296+
}
1297+
1298+
fn main() {
1299+
let foo = Foo { bool: true };
1300+
1301+
foo.bool.then(|| 2);
1302+
}
1303+
"#,
1304+
r#"
1305+
#[derive(PartialEq, Eq)]
1306+
enum Bool { True, False }
1307+
1308+
struct Foo {
1309+
bool: Bool,
1310+
}
1311+
1312+
fn main() {
1313+
let foo = Foo { bool: Bool::True };
1314+
1315+
(foo.bool == Bool::True).then(|| 2);
1316+
}
1317+
"#,
1318+
)
1319+
}
1320+
12471321
#[test]
12481322
fn field_non_bool() {
12491323
cov_mark::check!(not_applicable_non_bool_field);
@@ -1445,6 +1519,90 @@ pub mod bar {
14451519
)
14461520
}
14471521

1522+
#[test]
1523+
fn const_in_impl_cross_file() {
1524+
check_assist(
1525+
bool_to_enum,
1526+
r#"
1527+
//- /main.rs
1528+
mod foo;
1529+
1530+
struct Foo;
1531+
1532+
impl Foo {
1533+
pub const $0BOOL: bool = true;
1534+
}
1535+
1536+
//- /foo.rs
1537+
use crate::Foo;
1538+
1539+
fn foo() -> bool {
1540+
Foo::BOOL
1541+
}
1542+
"#,
1543+
r#"
1544+
//- /main.rs
1545+
mod foo;
1546+
1547+
struct Foo;
1548+
1549+
#[derive(PartialEq, Eq)]
1550+
pub enum Bool { True, False }
1551+
1552+
impl Foo {
1553+
pub const BOOL: Bool = Bool::True;
1554+
}
1555+
1556+
//- /foo.rs
1557+
use crate::{Foo, Bool};
1558+
1559+
fn foo() -> bool {
1560+
Foo::BOOL == Bool::True
1561+
}
1562+
"#,
1563+
)
1564+
}
1565+
1566+
#[test]
1567+
fn const_in_trait() {
1568+
check_assist(
1569+
bool_to_enum,
1570+
r#"
1571+
trait Foo {
1572+
const $0BOOL: bool;
1573+
}
1574+
1575+
impl Foo for usize {
1576+
const BOOL: bool = true;
1577+
}
1578+
1579+
fn main() {
1580+
if <usize as Foo>::BOOL {
1581+
println!("foo");
1582+
}
1583+
}
1584+
"#,
1585+
r#"
1586+
#[derive(PartialEq, Eq)]
1587+
enum Bool { True, False }
1588+
1589+
trait Foo {
1590+
const BOOL: Bool;
1591+
}
1592+
1593+
impl Foo for usize {
1594+
const BOOL: Bool = Bool::True;
1595+
}
1596+
1597+
fn main() {
1598+
if <usize as Foo>::BOOL == Bool::True {
1599+
println!("foo");
1600+
}
1601+
}
1602+
"#,
1603+
)
1604+
}
1605+
14481606
#[test]
14491607
fn const_non_bool() {
14501608
cov_mark::check!(not_applicable_non_bool_const);

crates/ide-assists/src/tests/generated.rs

+3-3
Original file line numberDiff line numberDiff line change
@@ -294,10 +294,10 @@ fn main() {
294294
}
295295
"#####,
296296
r#####"
297-
fn main() {
298-
#[derive(PartialEq, Eq)]
299-
enum Bool { True, False }
297+
#[derive(PartialEq, Eq)]
298+
enum Bool { True, False }
300299
300+
fn main() {
301301
let bool = Bool::True;
302302
303303
if bool == Bool::True {

0 commit comments

Comments
 (0)