@@ -16,7 +16,7 @@ use syntax::{
16
16
edit_in_place:: { AttrsOwnerEdit , Indent } ,
17
17
make, HasName ,
18
18
} ,
19
- ted, AstNode , NodeOrToken , SyntaxNode , T ,
19
+ ted, AstNode , NodeOrToken , SyntaxKind , SyntaxNode , T ,
20
20
} ;
21
21
use text_edit:: TextRange ;
22
22
@@ -40,10 +40,10 @@ use crate::assist_context::{AssistContext, Assists};
40
40
// ```
41
41
// ->
42
42
// ```
43
- // fn main() {
44
- // #[derive(PartialEq, Eq)]
45
- // enum Bool { True, False }
43
+ // #[derive(PartialEq, Eq)]
44
+ // enum Bool { True, False }
46
45
//
46
+ // fn main() {
47
47
// let bool = Bool::True;
48
48
//
49
49
// if bool == Bool::True {
@@ -270,6 +270,15 @@ fn replace_usages(
270
270
}
271
271
_ => ( ) ,
272
272
}
273
+ } else if let Some ( ( ty_annotation, initializer) ) = find_assoc_const_usage ( & new_name)
274
+ {
275
+ edit. replace ( ty_annotation. syntax ( ) . text_range ( ) , "Bool" ) ;
276
+ replace_bool_expr ( edit, initializer) ;
277
+ } else if let Some ( receiver) = find_method_call_expr_usage ( & new_name) {
278
+ edit. replace (
279
+ receiver. syntax ( ) . text_range ( ) ,
280
+ format ! ( "({} == Bool::True)" , receiver) ,
281
+ ) ;
273
282
} else if new_name. syntax ( ) . ancestors ( ) . find_map ( ast:: UseTree :: cast) . is_none ( ) {
274
283
// for any other usage in an expression, replace it with a check that it is the true variant
275
284
if let Some ( ( record_field, expr) ) = new_name
@@ -413,6 +422,26 @@ fn find_record_pat_field_usage(name: &ast::NameLike) -> Option<ast::Pat> {
413
422
}
414
423
}
415
424
425
+ fn find_assoc_const_usage ( name : & ast:: NameLike ) -> Option < ( ast:: Type , ast:: Expr ) > {
426
+ let const_ = name. syntax ( ) . parent ( ) . and_then ( ast:: Const :: cast) ?;
427
+ if const_. syntax ( ) . parent ( ) . and_then ( ast:: AssocItemList :: cast) . is_none ( ) {
428
+ return None ;
429
+ }
430
+
431
+ Some ( ( const_. ty ( ) ?, const_. body ( ) ?) )
432
+ }
433
+
434
+ fn find_method_call_expr_usage ( name : & ast:: NameLike ) -> Option < ast:: Expr > {
435
+ let method_call = name. syntax ( ) . ancestors ( ) . find_map ( ast:: MethodCallExpr :: cast) ?;
436
+ let receiver = method_call. receiver ( ) ?;
437
+
438
+ if !receiver. syntax ( ) . descendants ( ) . contains ( name. syntax ( ) ) {
439
+ return None ;
440
+ }
441
+
442
+ Some ( receiver)
443
+ }
444
+
416
445
/// Adds the definition of the new enum before the target node.
417
446
fn add_enum_def (
418
447
edit : & mut SourceChangeBuilder ,
@@ -430,18 +459,31 @@ fn add_enum_def(
430
459
. any ( |module| module. nearest_non_block_module ( ctx. db ( ) ) != * target_module) ;
431
460
let enum_def = make_bool_enum ( make_enum_pub) ;
432
461
433
- let indent = IndentLevel :: from_node ( & target_node) ;
462
+ let insert_before = node_to_insert_before ( target_node) ;
463
+ let indent = IndentLevel :: from_node ( & insert_before) ;
434
464
enum_def. reindent_to ( indent) ;
435
465
436
466
ted:: insert_all (
437
- ted:: Position :: before ( & edit. make_syntax_mut ( target_node ) ) ,
467
+ ted:: Position :: before ( & edit. make_syntax_mut ( insert_before ) ) ,
438
468
vec ! [
439
469
enum_def. syntax( ) . clone( ) . into( ) ,
440
470
make:: tokens:: whitespace( & format!( "\n \n {indent}" ) ) . into( ) ,
441
471
] ,
442
472
) ;
443
473
}
444
474
475
+ /// Finds where to put the new enum definition.
476
+ /// Tries to find the ast node at the nearest module or at top-level, otherwise just
477
+ /// returns the input node.
478
+ fn node_to_insert_before ( target_node : SyntaxNode ) -> SyntaxNode {
479
+ target_node
480
+ . ancestors ( )
481
+ . take_while ( |it| !matches ! ( it. kind( ) , SyntaxKind :: MODULE | SyntaxKind :: SOURCE_FILE ) )
482
+ . filter ( |it| ast:: Item :: can_cast ( it. kind ( ) ) )
483
+ . last ( )
484
+ . unwrap_or ( target_node)
485
+ }
486
+
445
487
fn make_bool_enum ( make_pub : bool ) -> ast:: Enum {
446
488
let enum_def = make:: enum_ (
447
489
if make_pub { Some ( make:: visibility_pub ( ) ) } else { None } ,
@@ -491,10 +533,10 @@ fn main() {
491
533
}
492
534
"# ,
493
535
r#"
494
- fn main() {
495
- #[derive(PartialEq, Eq)]
496
- enum Bool { True, False }
536
+ #[derive(PartialEq, Eq)]
537
+ enum Bool { True, False }
497
538
539
+ fn main() {
498
540
let foo = Bool::True;
499
541
500
542
if foo == Bool::True {
@@ -520,10 +562,10 @@ fn main() {
520
562
}
521
563
"# ,
522
564
r#"
523
- fn main() {
524
- #[derive(PartialEq, Eq)]
525
- enum Bool { True, False }
565
+ #[derive(PartialEq, Eq)]
566
+ enum Bool { True, False }
526
567
568
+ fn main() {
527
569
let foo = Bool::True;
528
570
529
571
if foo == Bool::False {
@@ -545,10 +587,10 @@ fn main() {
545
587
}
546
588
"# ,
547
589
r#"
548
- fn main() {
549
- #[derive(PartialEq, Eq)]
550
- enum Bool { True, False }
590
+ #[derive(PartialEq, Eq)]
591
+ enum Bool { True, False }
551
592
593
+ fn main() {
552
594
let foo: Bool = Bool::False;
553
595
}
554
596
"# ,
@@ -565,10 +607,10 @@ fn main() {
565
607
}
566
608
"# ,
567
609
r#"
568
- fn main() {
569
- #[derive(PartialEq, Eq)]
570
- enum Bool { True, False }
610
+ #[derive(PartialEq, Eq)]
611
+ enum Bool { True, False }
571
612
613
+ fn main() {
572
614
let foo = if 1 == 2 { Bool::True } else { Bool::False };
573
615
}
574
616
"# ,
@@ -590,10 +632,10 @@ fn main() {
590
632
}
591
633
"# ,
592
634
r#"
593
- fn main() {
594
- #[derive(PartialEq, Eq)]
595
- enum Bool { True, False }
635
+ #[derive(PartialEq, Eq)]
636
+ enum Bool { True, False }
596
637
638
+ fn main() {
597
639
let foo = Bool::False;
598
640
let bar = true;
599
641
@@ -619,10 +661,10 @@ fn main() {
619
661
}
620
662
"# ,
621
663
r#"
622
- fn main() {
623
- #[derive(PartialEq, Eq)]
624
- enum Bool { True, False }
664
+ #[derive(PartialEq, Eq)]
665
+ enum Bool { True, False }
625
666
667
+ fn main() {
626
668
let foo = Bool::True;
627
669
628
670
if *&foo == Bool::True {
@@ -645,10 +687,10 @@ fn main() {
645
687
}
646
688
"# ,
647
689
r#"
648
- fn main() {
649
- #[derive(PartialEq, Eq)]
650
- enum Bool { True, False }
690
+ #[derive(PartialEq, Eq)]
691
+ enum Bool { True, False }
651
692
693
+ fn main() {
652
694
let foo: Bool;
653
695
foo = Bool::True;
654
696
}
@@ -671,10 +713,10 @@ fn main() {
671
713
}
672
714
"# ,
673
715
r#"
674
- fn main() {
675
- #[derive(PartialEq, Eq)]
676
- enum Bool { True, False }
716
+ #[derive(PartialEq, Eq)]
717
+ enum Bool { True, False }
677
718
719
+ fn main() {
678
720
let foo = Bool::True;
679
721
let bar = foo == Bool::False;
680
722
@@ -702,11 +744,11 @@ fn main() {
702
744
}
703
745
"# ,
704
746
r#"
747
+ #[derive(PartialEq, Eq)]
748
+ enum Bool { True, False }
749
+
705
750
fn main() {
706
751
if !"foo".chars().any(|c| {
707
- #[derive(PartialEq, Eq)]
708
- enum Bool { True, False }
709
-
710
752
let foo = Bool::True;
711
753
foo == Bool::True
712
754
}) {
@@ -1244,6 +1286,38 @@ fn main() {
1244
1286
)
1245
1287
}
1246
1288
1289
+ #[ test]
1290
+ fn field_method_chain_usage ( ) {
1291
+ check_assist (
1292
+ bool_to_enum,
1293
+ r#"
1294
+ struct Foo {
1295
+ $0bool: bool,
1296
+ }
1297
+
1298
+ fn main() {
1299
+ let foo = Foo { bool: true };
1300
+
1301
+ foo.bool.then(|| 2);
1302
+ }
1303
+ "# ,
1304
+ r#"
1305
+ #[derive(PartialEq, Eq)]
1306
+ enum Bool { True, False }
1307
+
1308
+ struct Foo {
1309
+ bool: Bool,
1310
+ }
1311
+
1312
+ fn main() {
1313
+ let foo = Foo { bool: Bool::True };
1314
+
1315
+ (foo.bool == Bool::True).then(|| 2);
1316
+ }
1317
+ "# ,
1318
+ )
1319
+ }
1320
+
1247
1321
#[ test]
1248
1322
fn field_non_bool ( ) {
1249
1323
cov_mark:: check!( not_applicable_non_bool_field) ;
@@ -1445,6 +1519,90 @@ pub mod bar {
1445
1519
)
1446
1520
}
1447
1521
1522
+ #[ test]
1523
+ fn const_in_impl_cross_file ( ) {
1524
+ check_assist (
1525
+ bool_to_enum,
1526
+ r#"
1527
+ //- /main.rs
1528
+ mod foo;
1529
+
1530
+ struct Foo;
1531
+
1532
+ impl Foo {
1533
+ pub const $0BOOL: bool = true;
1534
+ }
1535
+
1536
+ //- /foo.rs
1537
+ use crate::Foo;
1538
+
1539
+ fn foo() -> bool {
1540
+ Foo::BOOL
1541
+ }
1542
+ "# ,
1543
+ r#"
1544
+ //- /main.rs
1545
+ mod foo;
1546
+
1547
+ struct Foo;
1548
+
1549
+ #[derive(PartialEq, Eq)]
1550
+ pub enum Bool { True, False }
1551
+
1552
+ impl Foo {
1553
+ pub const BOOL: Bool = Bool::True;
1554
+ }
1555
+
1556
+ //- /foo.rs
1557
+ use crate::{Foo, Bool};
1558
+
1559
+ fn foo() -> bool {
1560
+ Foo::BOOL == Bool::True
1561
+ }
1562
+ "# ,
1563
+ )
1564
+ }
1565
+
1566
+ #[ test]
1567
+ fn const_in_trait ( ) {
1568
+ check_assist (
1569
+ bool_to_enum,
1570
+ r#"
1571
+ trait Foo {
1572
+ const $0BOOL: bool;
1573
+ }
1574
+
1575
+ impl Foo for usize {
1576
+ const BOOL: bool = true;
1577
+ }
1578
+
1579
+ fn main() {
1580
+ if <usize as Foo>::BOOL {
1581
+ println!("foo");
1582
+ }
1583
+ }
1584
+ "# ,
1585
+ r#"
1586
+ #[derive(PartialEq, Eq)]
1587
+ enum Bool { True, False }
1588
+
1589
+ trait Foo {
1590
+ const BOOL: Bool;
1591
+ }
1592
+
1593
+ impl Foo for usize {
1594
+ const BOOL: Bool = Bool::True;
1595
+ }
1596
+
1597
+ fn main() {
1598
+ if <usize as Foo>::BOOL == Bool::True {
1599
+ println!("foo");
1600
+ }
1601
+ }
1602
+ "# ,
1603
+ )
1604
+ }
1605
+
1448
1606
#[ test]
1449
1607
fn const_non_bool ( ) {
1450
1608
cov_mark:: check!( not_applicable_non_bool_const) ;
0 commit comments