Skip to content

Commit 024be65

Browse files
taiki-ecarllerche
authored andcommitted
Add for_await to process streams using a for loop (#2)
1 parent ba64220 commit 024be65

File tree

5 files changed

+145
-7
lines changed

5 files changed

+145
-7
lines changed

async-stream-impl/src/lib.rs

Lines changed: 76 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
extern crate proc_macro;
22

33
use proc_macro::{TokenStream, TokenTree};
4-
use proc_macro2::Span;
4+
use proc_macro2::{Group, Span, TokenStream as TokenStream2, TokenTree as TokenTree2};
55
use quote::quote;
66
use syn::visit_mut::VisitMut;
77

@@ -19,7 +19,7 @@ struct AsyncStreamEnumHack {
1919
}
2020

2121
impl AsyncStreamEnumHack {
22-
fn parse(input: TokenStream) -> Self {
22+
fn parse(input: TokenStream) -> syn::Result<Self> {
2323
macro_rules! n {
2424
($i:ident) => {
2525
$i.next().unwrap()
@@ -44,14 +44,15 @@ impl AsyncStreamEnumHack {
4444
n!(braces); // !
4545

4646
let inner = n!(braces);
47-
let syn::Block { stmts, .. } = syn::parse(inner.clone().into()).unwrap();
47+
let inner = replace_for_await(TokenStream2::from(TokenStream::from(inner)));
48+
let syn::Block { stmts, .. } = syn::parse2(inner.clone())?;
4849

4950
let macro_ident = syn::Ident::new(
5051
&format!("stream_{}", count_bangs(inner.into())),
5152
Span::call_site(),
5253
);
5354

54-
AsyncStreamEnumHack { stmts, macro_ident }
55+
Ok(AsyncStreamEnumHack { stmts, macro_ident })
5556
}
5657
}
5758

@@ -100,6 +101,42 @@ impl VisitMut for Scrub {
100101
syn::visit_mut::visit_expr_mut(self, i);
101102
self.is_xforming = prev;
102103
}
104+
syn::Expr::ForLoop(expr) => {
105+
syn::visit_mut::visit_expr_for_loop_mut(self, expr);
106+
// TODO: Should we allow other attributes?
107+
if expr.attrs.len() != 1 || !expr.attrs[0].path.is_ident("await") {
108+
return;
109+
}
110+
let syn::ExprForLoop {
111+
attrs,
112+
label,
113+
pat,
114+
expr,
115+
body,
116+
..
117+
} = expr;
118+
119+
let attr = attrs.pop().unwrap();
120+
if let Err(e) = syn::parse2::<syn::parse::Nothing>(attr.tokens) {
121+
*i = syn::parse2(e.to_compile_error()).unwrap();
122+
return;
123+
}
124+
125+
*i = syn::parse_quote! {{
126+
let mut __pinned = #expr;
127+
let mut __pinned = unsafe {
128+
::async_stream::reexport::Pin::new_unchecked(&mut __pinned)
129+
};
130+
#label
131+
loop {
132+
let #pat = match ::async_stream::reexport::next(&mut __pinned).await {
133+
::async_stream::reexport::Some(e) => e,
134+
::async_stream::reexport::None => break,
135+
};
136+
#body
137+
}
138+
}}
139+
}
103140
_ => syn::visit_mut::visit_expr_mut(self, i),
104141
}
105142
}
@@ -117,7 +154,10 @@ pub fn async_stream_impl(input: TokenStream) -> TokenStream {
117154
let AsyncStreamEnumHack {
118155
macro_ident,
119156
mut stmts,
120-
} = AsyncStreamEnumHack::parse(input);
157+
} = match AsyncStreamEnumHack::parse(input) {
158+
Ok(x) => x,
159+
Err(e) => return e.to_compile_error().into(),
160+
};
121161

122162
let mut scrub = Scrub {
123163
is_xforming: true,
@@ -156,7 +196,10 @@ pub fn async_try_stream_impl(input: TokenStream) -> TokenStream {
156196
let AsyncStreamEnumHack {
157197
macro_ident,
158198
mut stmts,
159-
} = AsyncStreamEnumHack::parse(input);
199+
} = match AsyncStreamEnumHack::parse(input) {
200+
Ok(x) => x,
201+
Err(e) => return e.to_compile_error().into(),
202+
};
160203

161204
let mut scrub = Scrub {
162205
is_xforming: true,
@@ -209,3 +252,30 @@ fn count_bangs(input: TokenStream) -> usize {
209252

210253
count
211254
}
255+
256+
fn replace_for_await(input: TokenStream2) -> TokenStream2 {
257+
let mut input = input.into_iter().peekable();
258+
let mut tokens = Vec::new();
259+
260+
while let Some(token) = input.next() {
261+
match token {
262+
TokenTree2::Ident(ident) => {
263+
match input.peek() {
264+
Some(TokenTree2::Ident(next)) if ident == "for" && next == "await" => {
265+
tokens.extend(quote!(#[#next]));
266+
let _ = input.next();
267+
}
268+
_ => {}
269+
}
270+
tokens.push(ident.into());
271+
}
272+
TokenTree2::Group(group) => {
273+
let stream = replace_for_await(group.stream());
274+
tokens.push(Group::new(group.delimiter(), stream).into());
275+
}
276+
_ => tokens.push(token),
277+
}
278+
}
279+
280+
tokens.into_iter().collect()
281+
}

async-stream/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ repository = "https://github.com/tokio-rs/async-stream"
2020
readme = "README.md"
2121

2222
[dependencies]
23-
async-stream-impl = { version = "0.1.0" }
23+
async-stream-impl = { version = "0.1.0", path = "../async-stream-impl" }
2424
futures-core-preview = "=0.3.0-alpha.18"
2525

2626
[dev-dependencies]

async-stream/src/lib.rs

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,7 @@
169169
//! [`proc-macro-hack`]: https://github.com/dtolnay/proc-macro-hack/
170170
171171
mod async_stream;
172+
mod next;
172173
#[doc(hidden)]
173174
pub mod yielder;
174175

@@ -179,6 +180,16 @@ pub use crate::async_stream::AsyncStream;
179180
#[doc(hidden)]
180181
pub use async_stream_impl::{AsyncStreamHack, AsyncTryStreamHack};
181182

183+
#[doc(hidden)]
184+
pub mod reexport {
185+
#[doc(hidden)]
186+
pub use crate::next::next;
187+
#[doc(hidden)]
188+
pub use std::option::Option::{None, Some};
189+
#[doc(hidden)]
190+
pub use std::pin::Pin;
191+
}
192+
182193
/// Asynchronous stream
183194
///
184195
/// See [crate](index.html) documentation for more details.

async-stream/src/next.rs

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
use futures_core::Stream;
2+
use std::future::Future;
3+
use std::pin::Pin;
4+
use std::task::{Context, Poll};
5+
6+
// This is equivalent to the `futures::StreamExt::next` method.
7+
// But we want to make this crate dependency as small as possible, so we define our `next` function.
8+
#[doc(hidden)]
9+
pub fn next<S>(stream: &mut S) -> impl Future<Output = Option<S::Item>> + '_
10+
where
11+
S: Stream + Unpin,
12+
{
13+
Next { stream }
14+
}
15+
16+
#[derive(Debug)]
17+
struct Next<'a, S> {
18+
stream: &'a mut S,
19+
}
20+
21+
impl<S> Unpin for Next<'_, S> where S: Unpin {}
22+
23+
impl<S> Future for Next<'_, S>
24+
where
25+
S: Stream + Unpin,
26+
{
27+
type Output = Option<S::Item>;
28+
29+
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
30+
Pin::new(&mut self.stream).poll_next(cx)
31+
}
32+
}

async-stream/tests/for_await.rs

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
#![feature(async_await)]
2+
3+
use tokio::prelude::*;
4+
5+
use async_stream::stream;
6+
7+
#[tokio::test]
8+
async fn test() {
9+
let s = stream! {
10+
yield "hello";
11+
yield "world";
12+
};
13+
14+
let s = stream! {
15+
for await x in s {
16+
yield x.to_owned() + "!";
17+
}
18+
};
19+
20+
let values: Vec<_> = s.collect().await;
21+
22+
assert_eq!(2, values.len());
23+
assert_eq!("hello!", values[0]);
24+
assert_eq!("world!", values[1]);
25+
}

0 commit comments

Comments
 (0)