@@ -53,6 +53,7 @@ use crate::deref_separator::deref_finder;
53
53
use crate :: simplify;
54
54
use crate :: util:: expand_aggregate;
55
55
use crate :: MirPass ;
56
+ use hir:: GeneratorKind ;
56
57
use rustc_data_structures:: fx:: FxHashMap ;
57
58
use rustc_hir as hir;
58
59
use rustc_hir:: lang_items:: LangItem ;
@@ -215,6 +216,7 @@ struct SuspensionPoint<'tcx> {
215
216
216
217
struct TransformVisitor < ' tcx > {
217
218
tcx : TyCtxt < ' tcx > ,
219
+ is_async_kind : bool ,
218
220
state_adt_ref : AdtDef < ' tcx > ,
219
221
state_substs : SubstsRef < ' tcx > ,
220
222
@@ -239,28 +241,30 @@ struct TransformVisitor<'tcx> {
239
241
}
240
242
241
243
impl < ' tcx > TransformVisitor < ' tcx > {
242
- // Make a GeneratorState variant assignment. `core::ops::GeneratorState` only has single
243
- // element tuple variants, so we can just write to the downcasted first field and then set the
244
+ // Make a `GeneratorState` or `Poll` variant assignment.
245
+ //
246
+ // `core::ops::GeneratorState` only has single element tuple variants,
247
+ // so we can just write to the downcasted first field and then set the
244
248
// discriminant to the appropriate variant.
245
- fn make_state (
246
- & self ,
247
- idx : VariantIdx ,
248
- val : Operand < ' tcx > ,
249
- source_info : SourceInfo ,
250
- ) -> impl Iterator < Item = Statement < ' tcx > > {
249
+ fn make_state ( & self , idx : VariantIdx ) -> ( AggregateKind < ' tcx > , Option < Ty < ' tcx > > ) {
251
250
let kind = AggregateKind :: Adt ( self . state_adt_ref . did ( ) , idx, self . state_substs , None , None ) ;
251
+
252
+ // `Poll::Pending`
253
+ if self . is_async_kind && idx == VariantIdx :: new ( 1 ) {
254
+ assert_eq ! ( self . state_adt_ref. variant( idx) . fields. len( ) , 0 ) ;
255
+
256
+ return ( kind, None ) ;
257
+ }
258
+
259
+ // else: `Poll::Ready(x)`, `GeneratorState::Yielded(x)` or `GeneratorState::Complete(x)`
252
260
assert_eq ! ( self . state_adt_ref. variant( idx) . fields. len( ) , 1 ) ;
261
+
253
262
let ty = self
254
263
. tcx
255
264
. bound_type_of ( self . state_adt_ref . variant ( idx) . fields [ 0 ] . did )
256
265
. subst ( self . tcx , self . state_substs ) ;
257
- expand_aggregate (
258
- Place :: return_place ( ) ,
259
- std:: iter:: once ( ( val, ty) ) ,
260
- kind,
261
- source_info,
262
- self . tcx ,
263
- )
266
+
267
+ ( kind, Some ( ty) )
264
268
}
265
269
266
270
// Create a Place referencing a generator struct field
@@ -331,22 +335,44 @@ impl<'tcx> MutVisitor<'tcx> for TransformVisitor<'tcx> {
331
335
} ) ;
332
336
333
337
let ret_val = match data. terminator ( ) . kind {
334
- TerminatorKind :: Return => Some ( (
335
- VariantIdx :: new ( 1 ) ,
336
- None ,
337
- Operand :: Move ( Place :: from ( self . new_ret_local ) ) ,
338
- None ,
339
- ) ) ,
338
+ TerminatorKind :: Return => {
339
+ Some ( ( true , None , Operand :: Move ( Place :: from ( self . new_ret_local ) ) , None ) )
340
+ }
340
341
TerminatorKind :: Yield { ref value, resume, resume_arg, drop } => {
341
- Some ( ( VariantIdx :: new ( 0 ) , Some ( ( resume, resume_arg) ) , value. clone ( ) , drop) )
342
+ Some ( ( false , Some ( ( resume, resume_arg) ) , value. clone ( ) , drop) )
342
343
}
343
344
_ => None ,
344
345
} ;
345
346
346
- if let Some ( ( state_idx , resume, v, drop) ) = ret_val {
347
+ if let Some ( ( is_return , resume, v, drop) ) = ret_val {
347
348
let source_info = data. terminator ( ) . source_info ;
348
349
// We must assign the value first in case it gets declared dead below
349
- data. statements . extend ( self . make_state ( state_idx, v, source_info) ) ;
350
+ let state_idx = VariantIdx :: new ( match ( is_return, self . is_async_kind ) {
351
+ ( true , false ) => 1 , // GeneratorState::Complete
352
+ ( false , false ) => 0 , // GeneratorState::Yielded
353
+ ( true , true ) => 0 , // Poll::Ready
354
+ ( false , true ) => 1 , // Poll::Pending
355
+ } ) ;
356
+ let ( kind, ty) = self . make_state ( state_idx) ;
357
+ if let Some ( ty) = ty {
358
+ data. statements . extend ( expand_aggregate (
359
+ Place :: return_place ( ) ,
360
+ std:: iter:: once ( ( v, ty) ) ,
361
+ kind,
362
+ source_info,
363
+ self . tcx ,
364
+ ) ) ;
365
+ } else {
366
+ // TODO: assert `val` is nil
367
+ data. statements . extend ( expand_aggregate (
368
+ Place :: return_place ( ) ,
369
+ std:: iter:: empty ( ) ,
370
+ kind,
371
+ source_info,
372
+ self . tcx ,
373
+ ) ) ;
374
+ }
375
+
350
376
let state = if let Some ( ( resume, mut resume_arg) ) = resume {
351
377
// Yield
352
378
let state = RESERVED_VARIANTS + self . suspension_points . len ( ) ;
@@ -1268,10 +1294,20 @@ impl<'tcx> MirPass<'tcx> for StateTransform {
1268
1294
}
1269
1295
} ;
1270
1296
1271
- // Compute GeneratorState<yield_ty, return_ty>
1272
- let state_did = tcx. require_lang_item ( LangItem :: GeneratorState , None ) ;
1273
- let state_adt_ref = tcx. adt_def ( state_did) ;
1274
- let state_substs = tcx. intern_substs ( & [ yield_ty. into ( ) , body. return_ty ( ) . into ( ) ] ) ;
1297
+ let is_async_kind = body. generator_kind ( ) . unwrap ( ) != GeneratorKind :: Gen ;
1298
+ let ( state_adt_ref, state_substs) = if is_async_kind {
1299
+ // Compute Poll<return_ty>
1300
+ let state_did = tcx. require_lang_item ( LangItem :: Poll , None ) ;
1301
+ let state_adt_ref = tcx. adt_def ( state_did) ;
1302
+ let state_substs = tcx. intern_substs ( & [ body. return_ty ( ) . into ( ) ] ) ;
1303
+ ( state_adt_ref, state_substs)
1304
+ } else {
1305
+ // Compute GeneratorState<yield_ty, return_ty>
1306
+ let state_did = tcx. require_lang_item ( LangItem :: GeneratorState , None ) ;
1307
+ let state_adt_ref = tcx. adt_def ( state_did) ;
1308
+ let state_substs = tcx. intern_substs ( & [ yield_ty. into ( ) , body. return_ty ( ) . into ( ) ] ) ;
1309
+ ( state_adt_ref, state_substs)
1310
+ } ;
1275
1311
let ret_ty = tcx. mk_adt ( state_adt_ref, state_substs) ;
1276
1312
1277
1313
// We rename RETURN_PLACE which has type mir.return_ty to new_ret_local
@@ -1327,9 +1363,11 @@ impl<'tcx> MirPass<'tcx> for StateTransform {
1327
1363
// Run the transformation which converts Places from Local to generator struct
1328
1364
// accesses for locals in `remap`.
1329
1365
// It also rewrites `return x` and `yield y` as writing a new generator state and returning
1330
- // GeneratorState::Complete(x) and GeneratorState::Yielded(y) respectively.
1366
+ // either GeneratorState::Complete(x) and GeneratorState::Yielded(y),
1367
+ // or Poll::Ready(x) and Poll::Pending respectively depending on `is_async_kind`.
1331
1368
let mut transform = TransformVisitor {
1332
1369
tcx,
1370
+ is_async_kind,
1333
1371
state_adt_ref,
1334
1372
state_substs,
1335
1373
remap,
@@ -1367,7 +1405,7 @@ impl<'tcx> MirPass<'tcx> for StateTransform {
1367
1405
1368
1406
body. generator . as_mut ( ) . unwrap ( ) . generator_drop = Some ( drop_shim) ;
1369
1407
1370
- // Create the Generator::resume function
1408
+ // Create the Generator::resume / Future::poll function
1371
1409
create_generator_resume_function ( tcx, transform, body, can_return) ;
1372
1410
1373
1411
// Run derefer to fix Derefs that are not in the first place
0 commit comments