@@ -11,7 +11,9 @@ use ide_db::{
11
11
helpers:: mod_path_to_ast,
12
12
imports:: insert_use:: { insert_use, ImportScope } ,
13
13
search:: { FileReference , ReferenceCategory , SearchScope } ,
14
- syntax_helpers:: node_ext:: { preorder_expr, walk_expr, walk_pat, walk_patterns_in_expr} ,
14
+ syntax_helpers:: node_ext:: {
15
+ for_each_tail_expr, preorder_expr, walk_expr, walk_pat, walk_patterns_in_expr,
16
+ } ,
15
17
FxIndexSet , RootDatabase ,
16
18
} ;
17
19
use itertools:: Itertools ;
@@ -78,7 +80,7 @@ pub(crate) fn extract_function(acc: &mut Assists, ctx: &AssistContext<'_>) -> Op
78
80
} ;
79
81
80
82
let body = extraction_target ( & node, range) ?;
81
- let container_info = body. analyze_container ( & ctx. sema ) ?;
83
+ let ( container_info, contains_tail_expr ) = body. analyze_container ( & ctx. sema ) ?;
82
84
83
85
let ( locals_used, self_param) = body. analyze ( & ctx. sema ) ;
84
86
@@ -119,6 +121,7 @@ pub(crate) fn extract_function(acc: &mut Assists, ctx: &AssistContext<'_>) -> Op
119
121
ret_ty,
120
122
body,
121
123
outliving_locals,
124
+ contains_tail_expr,
122
125
mods : container_info,
123
126
} ;
124
127
@@ -245,6 +248,8 @@ struct Function {
245
248
ret_ty : RetType ,
246
249
body : FunctionBody ,
247
250
outliving_locals : Vec < OutlivedLocal > ,
251
+ /// Whether at least one of the container's tail expr is contained in the range we're extracting.
252
+ contains_tail_expr : bool ,
248
253
mods : ContainerInfo ,
249
254
}
250
255
@@ -265,7 +270,7 @@ enum ParamKind {
265
270
MutRef ,
266
271
}
267
272
268
- #[ derive( Debug , Eq , PartialEq ) ]
273
+ #[ derive( Debug ) ]
269
274
enum FunType {
270
275
Unit ,
271
276
Single ( hir:: Type ) ,
@@ -294,7 +299,6 @@ struct ControlFlow {
294
299
#[ derive( Clone , Debug ) ]
295
300
struct ContainerInfo {
296
301
is_const : bool ,
297
- is_in_tail : bool ,
298
302
parent_loop : Option < SyntaxNode > ,
299
303
/// The function's return type, const's type etc.
300
304
ret_type : Option < hir:: Type > ,
@@ -743,7 +747,10 @@ impl FunctionBody {
743
747
( res, self_param)
744
748
}
745
749
746
- fn analyze_container ( & self , sema : & Semantics < ' _ , RootDatabase > ) -> Option < ContainerInfo > {
750
+ fn analyze_container (
751
+ & self ,
752
+ sema : & Semantics < ' _ , RootDatabase > ,
753
+ ) -> Option < ( ContainerInfo , bool ) > {
747
754
let mut ancestors = self . parent ( ) ?. ancestors ( ) ;
748
755
let infer_expr_opt = |expr| sema. type_of_expr ( & expr?) . map ( TypeInfo :: adjusted) ;
749
756
let mut parent_loop = None ;
@@ -815,28 +822,36 @@ impl FunctionBody {
815
822
}
816
823
} ;
817
824
} ;
818
- let container_tail = match expr? {
819
- ast:: Expr :: BlockExpr ( block) => block. tail_expr ( ) ,
820
- expr => Some ( expr) ,
821
- } ;
822
- let is_in_tail =
823
- container_tail. zip ( self . tail_expr ( ) ) . map_or ( false , |( container_tail, body_tail) | {
824
- container_tail. syntax ( ) . text_range ( ) . contains_range ( body_tail. syntax ( ) . text_range ( ) )
825
+
826
+ let expr = expr?;
827
+ let contains_tail_expr = if let Some ( body_tail) = self . tail_expr ( ) {
828
+ let mut contains_tail_expr = false ;
829
+ let tail_expr_range = body_tail. syntax ( ) . text_range ( ) ;
830
+ for_each_tail_expr ( & expr, & mut |e| {
831
+ if tail_expr_range. contains_range ( e. syntax ( ) . text_range ( ) ) {
832
+ contains_tail_expr = true ;
833
+ }
825
834
} ) ;
835
+ contains_tail_expr
836
+ } else {
837
+ false
838
+ } ;
826
839
827
840
let parent = self . parent ( ) ?;
828
841
let parents = generic_parents ( & parent) ;
829
842
let generic_param_lists = parents. iter ( ) . filter_map ( |it| it. generic_param_list ( ) ) . collect ( ) ;
830
843
let where_clauses = parents. iter ( ) . filter_map ( |it| it. where_clause ( ) ) . collect ( ) ;
831
844
832
- Some ( ContainerInfo {
833
- is_in_tail,
834
- is_const,
835
- parent_loop,
836
- ret_type : ty,
837
- generic_param_lists,
838
- where_clauses,
839
- } )
845
+ Some ( (
846
+ ContainerInfo {
847
+ is_const,
848
+ parent_loop,
849
+ ret_type : ty,
850
+ generic_param_lists,
851
+ where_clauses,
852
+ } ,
853
+ contains_tail_expr,
854
+ ) )
840
855
}
841
856
842
857
fn return_ty ( & self , ctx : & AssistContext < ' _ > ) -> Option < RetType > {
@@ -1368,7 +1383,7 @@ impl FlowHandler {
1368
1383
None => FlowHandler :: None ,
1369
1384
Some ( flow_kind) => {
1370
1385
let action = flow_kind. clone ( ) ;
1371
- if * ret_ty == FunType :: Unit {
1386
+ if let FunType :: Unit = ret_ty {
1372
1387
match flow_kind {
1373
1388
FlowKind :: Return ( None )
1374
1389
| FlowKind :: Break ( _, None )
@@ -1633,7 +1648,7 @@ impl Function {
1633
1648
1634
1649
fn make_ret_ty ( & self , ctx : & AssistContext < ' _ > , module : hir:: Module ) -> Option < ast:: RetType > {
1635
1650
let fun_ty = self . return_type ( ctx) ;
1636
- let handler = if self . mods . is_in_tail {
1651
+ let handler = if self . contains_tail_expr {
1637
1652
FlowHandler :: None
1638
1653
} else {
1639
1654
FlowHandler :: from_ret_ty ( self , & fun_ty)
@@ -1707,7 +1722,7 @@ fn make_body(
1707
1722
fun : & Function ,
1708
1723
) -> ast:: BlockExpr {
1709
1724
let ret_ty = fun. return_type ( ctx) ;
1710
- let handler = if fun. mods . is_in_tail {
1725
+ let handler = if fun. contains_tail_expr {
1711
1726
FlowHandler :: None
1712
1727
} else {
1713
1728
FlowHandler :: from_ret_ty ( fun, & ret_ty)
@@ -1946,7 +1961,7 @@ fn update_external_control_flow(handler: &FlowHandler, syntax: &SyntaxNode) {
1946
1961
if nested_scope. is_none ( ) {
1947
1962
if let Some ( expr) = ast:: Expr :: cast ( e. clone ( ) ) {
1948
1963
match expr {
1949
- ast:: Expr :: ReturnExpr ( return_expr) if nested_scope . is_none ( ) => {
1964
+ ast:: Expr :: ReturnExpr ( return_expr) => {
1950
1965
let expr = return_expr. expr ( ) ;
1951
1966
if let Some ( replacement) = make_rewritten_flow ( handler, expr) {
1952
1967
ted:: replace ( return_expr. syntax ( ) , replacement. syntax ( ) )
@@ -5582,6 +5597,153 @@ impl <T, U> Struct<T, U> where T: Into<i32> + Copy, U: Debug {
5582
5597
fn $0fun_name<T, V>(t: T, v: V) -> i32 where T: Into<i32> + Copy, V: Into<i32> {
5583
5598
t.into() + v.into()
5584
5599
}
5600
+ "# ,
5601
+ ) ;
5602
+ }
5603
+
5604
+ #[ test]
5605
+ fn non_tail_expr_of_tail_expr_loop ( ) {
5606
+ check_assist (
5607
+ extract_function,
5608
+ r#"
5609
+ pub fn f() {
5610
+ loop {
5611
+ $0if true {
5612
+ continue;
5613
+ }$0
5614
+
5615
+ if false {
5616
+ break;
5617
+ }
5618
+ }
5619
+ }
5620
+ "# ,
5621
+ r#"
5622
+ pub fn f() {
5623
+ loop {
5624
+ if let ControlFlow::Break(_) = fun_name() {
5625
+ continue;
5626
+ }
5627
+
5628
+ if false {
5629
+ break;
5630
+ }
5631
+ }
5632
+ }
5633
+
5634
+ fn $0fun_name() -> ControlFlow<()> {
5635
+ if true {
5636
+ return ControlFlow::Break(());
5637
+ }
5638
+ ControlFlow::Continue(())
5639
+ }
5640
+ "# ,
5641
+ ) ;
5642
+ }
5643
+
5644
+ #[ test]
5645
+ fn non_tail_expr_of_tail_if_block ( ) {
5646
+ // FIXME: double semicolon
5647
+ check_assist (
5648
+ extract_function,
5649
+ r#"
5650
+ //- minicore: option, try
5651
+ impl<T> core::ops::Try for Option<T> {
5652
+ type Output = T;
5653
+ type Residual = Option<!>;
5654
+ }
5655
+ impl<T> core::ops::FromResidual for Option<T> {}
5656
+
5657
+ fn f() -> Option<()> {
5658
+ if true {
5659
+ let a = $0if true {
5660
+ Some(())?
5661
+ } else {
5662
+ ()
5663
+ }$0;
5664
+ Some(a)
5665
+ } else {
5666
+ None
5667
+ }
5668
+ }
5669
+ "# ,
5670
+ r#"
5671
+ impl<T> core::ops::Try for Option<T> {
5672
+ type Output = T;
5673
+ type Residual = Option<!>;
5674
+ }
5675
+ impl<T> core::ops::FromResidual for Option<T> {}
5676
+
5677
+ fn f() -> Option<()> {
5678
+ if true {
5679
+ let a = fun_name()?;;
5680
+ Some(a)
5681
+ } else {
5682
+ None
5683
+ }
5684
+ }
5685
+
5686
+ fn $0fun_name() -> Option<()> {
5687
+ Some(if true {
5688
+ Some(())?
5689
+ } else {
5690
+ ()
5691
+ })
5692
+ }
5693
+ "# ,
5694
+ ) ;
5695
+ }
5696
+
5697
+ #[ test]
5698
+ fn tail_expr_of_tail_block_nested ( ) {
5699
+ check_assist (
5700
+ extract_function,
5701
+ r#"
5702
+ //- minicore: option, try
5703
+ impl<T> core::ops::Try for Option<T> {
5704
+ type Output = T;
5705
+ type Residual = Option<!>;
5706
+ }
5707
+ impl<T> core::ops::FromResidual for Option<T> {}
5708
+
5709
+ fn f() -> Option<()> {
5710
+ if true {
5711
+ $0{
5712
+ let a = if true {
5713
+ Some(())?
5714
+ } else {
5715
+ ()
5716
+ };
5717
+ Some(a)
5718
+ }$0
5719
+ } else {
5720
+ None
5721
+ }
5722
+ }
5723
+ "# ,
5724
+ r#"
5725
+ impl<T> core::ops::Try for Option<T> {
5726
+ type Output = T;
5727
+ type Residual = Option<!>;
5728
+ }
5729
+ impl<T> core::ops::FromResidual for Option<T> {}
5730
+
5731
+ fn f() -> Option<()> {
5732
+ if true {
5733
+ fun_name()?
5734
+ } else {
5735
+ None
5736
+ }
5737
+ }
5738
+
5739
+ fn $0fun_name() -> Option<()> {
5740
+ let a = if true {
5741
+ Some(())?
5742
+ } else {
5743
+ ()
5744
+ };
5745
+ Some(a)
5746
+ }
5585
5747
"# ,
5586
5748
) ;
5587
5749
}
0 commit comments