Skip to content

Add for_await to process streams using a for loop #2

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Aug 19, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 76 additions & 6 deletions async-stream-impl/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
extern crate proc_macro;

use proc_macro::{TokenStream, TokenTree};
use proc_macro2::Span;
use proc_macro2::{Group, Span, TokenStream as TokenStream2, TokenTree as TokenTree2};
use quote::quote;
use syn::visit_mut::VisitMut;

Expand All @@ -19,7 +19,7 @@ struct AsyncStreamEnumHack {
}

impl AsyncStreamEnumHack {
fn parse(input: TokenStream) -> Self {
fn parse(input: TokenStream) -> syn::Result<Self> {
macro_rules! n {
($i:ident) => {
$i.next().unwrap()
Expand All @@ -44,14 +44,15 @@ impl AsyncStreamEnumHack {
n!(braces); // !

let inner = n!(braces);
let syn::Block { stmts, .. } = syn::parse(inner.clone().into()).unwrap();
let inner = replace_for_await(TokenStream2::from(TokenStream::from(inner)));
let syn::Block { stmts, .. } = syn::parse2(inner.clone())?;

let macro_ident = syn::Ident::new(
&format!("stream_{}", count_bangs(inner.into())),
Span::call_site(),
);

AsyncStreamEnumHack { stmts, macro_ident }
Ok(AsyncStreamEnumHack { stmts, macro_ident })
}
}

Expand Down Expand Up @@ -100,6 +101,42 @@ impl VisitMut for Scrub {
syn::visit_mut::visit_expr_mut(self, i);
self.is_xforming = prev;
}
syn::Expr::ForLoop(expr) => {
syn::visit_mut::visit_expr_for_loop_mut(self, expr);
// TODO: Should we allow other attributes?
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What about just #[await]?

Also, since this is a macro anyway, we could potentially get more aggressive and transform:

for await foo in my_stream {
}

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What about just #[await]?

Sounds good to me.

for await foo in my_stream {
}

I think syn doesn't support this syntax yet.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(By the way, that TODO comment is about the current implementation confirming that the number of attributes is exactly one.)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

syn doesn't support it... but it may not be too hard to impl. You would have to manually walk the TokenStream and find for followed by await, transform that to #[await] for then run ^^.

That can be follow up work though...

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done in c1d4aef

if expr.attrs.len() != 1 || !expr.attrs[0].path.is_ident("await") {
return;
}
let syn::ExprForLoop {
attrs,
label,
pat,
expr,
body,
..
} = expr;

let attr = attrs.pop().unwrap();
if let Err(e) = syn::parse2::<syn::parse::Nothing>(attr.tokens) {
*i = syn::parse2(e.to_compile_error()).unwrap();
return;
}

*i = syn::parse_quote! {{
let mut __pinned = #expr;
let mut __pinned = unsafe {
::async_stream::reexport::Pin::new_unchecked(&mut __pinned)
};
#label
loop {
let #pat = match ::async_stream::reexport::next(&mut __pinned).await {
::async_stream::reexport::Some(e) => e,
::async_stream::reexport::None => break,
};
#body
}
}}
}
_ => syn::visit_mut::visit_expr_mut(self, i),
}
}
Expand All @@ -117,7 +154,10 @@ pub fn async_stream_impl(input: TokenStream) -> TokenStream {
let AsyncStreamEnumHack {
macro_ident,
mut stmts,
} = AsyncStreamEnumHack::parse(input);
} = match AsyncStreamEnumHack::parse(input) {
Ok(x) => x,
Err(e) => return e.to_compile_error().into(),
};

let mut scrub = Scrub {
is_xforming: true,
Expand Down Expand Up @@ -156,7 +196,10 @@ pub fn async_try_stream_impl(input: TokenStream) -> TokenStream {
let AsyncStreamEnumHack {
macro_ident,
mut stmts,
} = AsyncStreamEnumHack::parse(input);
} = match AsyncStreamEnumHack::parse(input) {
Ok(x) => x,
Err(e) => return e.to_compile_error().into(),
};

let mut scrub = Scrub {
is_xforming: true,
Expand Down Expand Up @@ -209,3 +252,30 @@ fn count_bangs(input: TokenStream) -> usize {

count
}

fn replace_for_await(input: TokenStream2) -> TokenStream2 {
let mut input = input.into_iter().peekable();
let mut tokens = Vec::new();

while let Some(token) = input.next() {
match token {
TokenTree2::Ident(ident) => {
match input.peek() {
Some(TokenTree2::Ident(next)) if ident == "for" && next == "await" => {
tokens.extend(quote!(#[#next]));
let _ = input.next();
}
_ => {}
}
tokens.push(ident.into());
}
TokenTree2::Group(group) => {
let stream = replace_for_await(group.stream());
tokens.push(Group::new(group.delimiter(), stream).into());
}
_ => tokens.push(token),
}
}

tokens.into_iter().collect()
}
2 changes: 1 addition & 1 deletion async-stream/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ repository = "https://github.com/tokio-rs/async-stream"
readme = "README.md"

[dependencies]
async-stream-impl = { version = "0.1.0" }
async-stream-impl = { version = "0.1.0", path = "../async-stream-impl" }
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This must be revert before merging.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's fine to keep the path in here IMO.

futures-core-preview = "=0.3.0-alpha.18"

[dev-dependencies]
Expand Down
11 changes: 11 additions & 0 deletions async-stream/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,7 @@
//! [`proc-macro-hack`]: https://github.com/dtolnay/proc-macro-hack/

mod async_stream;
mod next;
#[doc(hidden)]
pub mod yielder;

Expand All @@ -179,6 +180,16 @@ pub use crate::async_stream::AsyncStream;
#[doc(hidden)]
pub use async_stream_impl::{AsyncStreamHack, AsyncTryStreamHack};

#[doc(hidden)]
pub mod reexport {
#[doc(hidden)]
pub use crate::next::next;
#[doc(hidden)]
pub use std::option::Option::{None, Some};
#[doc(hidden)]
pub use std::pin::Pin;
}

/// Asynchronous stream
///
/// See [crate](index.html) documentation for more details.
Expand Down
32 changes: 32 additions & 0 deletions async-stream/src/next.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
use futures_core::Stream;
use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};

// This is equivalent to the `futures::StreamExt::next` method.
// But we want to make this crate dependency as small as possible, so we define our `next` function.
#[doc(hidden)]
pub fn next<S>(stream: &mut S) -> impl Future<Output = Option<S::Item>> + '_
where
S: Stream + Unpin,
{
Next { stream }
}

#[derive(Debug)]
struct Next<'a, S> {
stream: &'a mut S,
}

impl<S> Unpin for Next<'_, S> where S: Unpin {}

impl<S> Future for Next<'_, S>
where
S: Stream + Unpin,
{
type Output = Option<S::Item>;

fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
Pin::new(&mut self.stream).poll_next(cx)
}
}
25 changes: 25 additions & 0 deletions async-stream/tests/for_await.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
#![feature(async_await)]

use tokio::prelude::*;

use async_stream::stream;

#[tokio::test]
async fn test() {
let s = stream! {
yield "hello";
yield "world";
};

let s = stream! {
for await x in s {
yield x.to_owned() + "!";
}
};

let values: Vec<_> = s.collect().await;

assert_eq!(2, values.len());
assert_eq!("hello!", values[0]);
assert_eq!("world!", values[1]);
}