1
+ use hir:: Semantics ;
2
+ use ide_db:: RootDatabase ;
3
+ use stdx:: format_to;
1
4
use syntax:: ast:: { self , AstNode } ;
2
5
3
6
use crate :: { AssistContext , AssistId , AssistKind , Assists } ;
@@ -24,6 +27,7 @@ pub(crate) fn convert_two_arm_bool_match_to_matches_macro(
24
27
acc : & mut Assists ,
25
28
ctx : & AssistContext < ' _ > ,
26
29
) -> Option < ( ) > {
30
+ use ArmBodyExpression :: * ;
27
31
let match_expr = ctx. find_node_at_offset :: < ast:: MatchExpr > ( ) ?;
28
32
let match_arm_list = match_expr. match_arm_list ( ) ?;
29
33
let mut arms = match_arm_list. arms ( ) ;
@@ -33,21 +37,20 @@ pub(crate) fn convert_two_arm_bool_match_to_matches_macro(
33
37
cov_mark:: hit!( non_two_arm_match) ;
34
38
return None ;
35
39
}
36
- let first_arm_expr = first_arm. expr ( ) ;
37
- let second_arm_expr = second_arm. expr ( ) ;
40
+ let first_arm_expr = first_arm. expr ( ) ?;
41
+ let second_arm_expr = second_arm. expr ( ) ?;
42
+ let first_arm_body = is_bool_literal_expr ( & ctx. sema , & first_arm_expr) ?;
43
+ let second_arm_body = is_bool_literal_expr ( & ctx. sema , & second_arm_expr) ?;
38
44
39
- let invert_matches = if is_bool_literal_expr ( & first_arm_expr, true )
40
- && is_bool_literal_expr ( & second_arm_expr, false )
41
- {
42
- false
43
- } else if is_bool_literal_expr ( & first_arm_expr, false )
44
- && is_bool_literal_expr ( & second_arm_expr, true )
45
- {
46
- true
47
- } else {
45
+ if !matches ! (
46
+ ( & first_arm_body, & second_arm_body) ,
47
+ ( Literal ( true ) , Literal ( false ) )
48
+ | ( Literal ( false ) , Literal ( true ) )
49
+ | ( Expression ( _) , Literal ( false ) )
50
+ ) {
48
51
cov_mark:: hit!( non_invert_bool_literal_arms) ;
49
52
return None ;
50
- } ;
53
+ }
51
54
52
55
let target_range = ctx. sema . original_range ( match_expr. syntax ( ) ) . range ;
53
56
let expr = match_expr. expr ( ) ?;
@@ -59,28 +62,55 @@ pub(crate) fn convert_two_arm_bool_match_to_matches_macro(
59
62
|builder| {
60
63
let mut arm_str = String :: new ( ) ;
61
64
if let Some ( pat) = & first_arm. pat ( ) {
62
- arm_str += & pat. to_string ( ) ;
65
+ format_to ! ( arm_str, "{ pat}" ) ;
63
66
}
64
67
if let Some ( guard) = & first_arm. guard ( ) {
65
68
arm_str += & format ! ( " {guard}" ) ;
66
69
}
67
- if invert_matches {
68
- builder. replace ( target_range, format ! ( "!matches!({expr}, {arm_str})" ) ) ;
69
- } else {
70
- builder. replace ( target_range, format ! ( "matches!({expr}, {arm_str})" ) ) ;
71
- }
70
+
71
+ let replace_with = match ( first_arm_body, second_arm_body) {
72
+ ( Literal ( true ) , Literal ( false ) ) => {
73
+ format ! ( "matches!({expr}, {arm_str})" )
74
+ }
75
+ ( Literal ( false ) , Literal ( true ) ) => {
76
+ format ! ( "!matches!({expr}, {arm_str})" )
77
+ }
78
+ ( Expression ( body_expr) , Literal ( false ) ) => {
79
+ arm_str. push_str ( match & first_arm. guard ( ) {
80
+ Some ( _) => " && " ,
81
+ _ => " if " ,
82
+ } ) ;
83
+ format ! ( "matches!({expr}, {arm_str}{body_expr})" )
84
+ }
85
+ _ => {
86
+ unreachable ! ( )
87
+ }
88
+ } ;
89
+ builder. replace ( target_range, replace_with) ;
72
90
} ,
73
91
)
74
92
}
75
93
76
- fn is_bool_literal_expr ( expr : & Option < ast:: Expr > , expect_bool : bool ) -> bool {
77
- if let Some ( ast:: Expr :: Literal ( lit) ) = expr {
94
+ enum ArmBodyExpression {
95
+ Literal ( bool ) ,
96
+ Expression ( ast:: Expr ) ,
97
+ }
98
+
99
+ fn is_bool_literal_expr (
100
+ sema : & Semantics < ' _ , RootDatabase > ,
101
+ expr : & ast:: Expr ,
102
+ ) -> Option < ArmBodyExpression > {
103
+ if let ast:: Expr :: Literal ( lit) = expr {
78
104
if let ast:: LiteralKind :: Bool ( b) = lit. kind ( ) {
79
- return b == expect_bool ;
105
+ return Some ( ArmBodyExpression :: Literal ( b ) ) ;
80
106
}
81
107
}
82
108
83
- return false ;
109
+ if !sema. type_of_expr ( expr) ?. original . is_bool ( ) {
110
+ return None ;
111
+ }
112
+
113
+ Some ( ArmBodyExpression :: Expression ( expr. clone ( ) ) )
84
114
}
85
115
86
116
#[ cfg( test) ]
@@ -121,21 +151,6 @@ fn foo(a: Option<u32>) -> bool {
121
151
) ;
122
152
}
123
153
124
- #[ test]
125
- fn not_applicable_non_bool_literal_arms ( ) {
126
- cov_mark:: check!( non_invert_bool_literal_arms) ;
127
- check_assist_not_applicable (
128
- convert_two_arm_bool_match_to_matches_macro,
129
- r#"
130
- fn foo(a: Option<u32>) -> bool {
131
- match a$0 {
132
- Some(val) => val == 3,
133
- _ => false
134
- }
135
- }
136
- "# ,
137
- ) ;
138
- }
139
154
#[ test]
140
155
fn not_applicable_both_false_arms ( ) {
141
156
cov_mark:: check!( non_invert_bool_literal_arms) ;
@@ -291,4 +306,40 @@ fn main() {
291
306
}" ,
292
307
) ;
293
308
}
309
+
310
+ #[ test]
311
+ fn convert_non_literal_bool ( ) {
312
+ check_assist (
313
+ convert_two_arm_bool_match_to_matches_macro,
314
+ r#"
315
+ fn main() {
316
+ match 0$0 {
317
+ a @ 0..15 => a == 0,
318
+ _ => false,
319
+ }
320
+ }
321
+ "# ,
322
+ r#"
323
+ fn main() {
324
+ matches!(0, a @ 0..15 if a == 0)
325
+ }
326
+ "# ,
327
+ ) ;
328
+ check_assist (
329
+ convert_two_arm_bool_match_to_matches_macro,
330
+ r#"
331
+ fn main() {
332
+ match 0$0 {
333
+ a @ 0..15 if thing() => a == 0,
334
+ _ => false,
335
+ }
336
+ }
337
+ "# ,
338
+ r#"
339
+ fn main() {
340
+ matches!(0, a @ 0..15 if thing() && a == 0)
341
+ }
342
+ "# ,
343
+ ) ;
344
+ }
294
345
}
0 commit comments