Skip to content

Commit 6ce595a

Browse files
committed
proc_macro: add an optimized CrossThread execution strategy, and a debug flag to use it
This new strategy supports avoiding waiting for a reply for noblock messages. It requires channel-like communication between the threads (similar to what the previous CrossThread1 strategy used). The new CrossThread execution strategy takes a type parameter for the channel to use, allowing rustc to supply a more efficient channel type that the proc_macro crate could not declare as a dependency.
1 parent b6f8dc9 commit 6ce595a

File tree

4 files changed

+141
-88
lines changed

4 files changed

+141
-88
lines changed

compiler/rustc_expand/src/proc_macro.rs

+31-21
Original file line number | Diff line number | Diff line change
@@ -12,7 +12,11 @@ use rustc_parse::parser::ForceCollect;
1212
use rustc_span::def_id::CrateNum;
1313
use rustc_span::{Span, DUMMY_SP};
1414

15-
const EXEC_STRATEGY: pm::bridge::server::SameThread = pm::bridge::server::SameThread;
15+
fn exec_strategy(ecx: &ExtCtxt<'_>) -> impl pm::bridge::server::ExecutionStrategy {
16+
<pm::bridge::server::MaybeCrossThread<pm::bridge::server::StdMessagePipe<_>>>::new(
17+
ecx.sess.opts.debugging_opts.proc_macro_cross_thread,
18+
)
19+
}
1620

1721
pub struct BangProcMacro {
1822
pub client: pm::bridge::client::Client<fn(pm::TokenStream) -> pm::TokenStream>,
@@ -27,14 +31,16 @@ impl base::ProcMacro for BangProcMacro {
2731
input: TokenStream,
2832
) -> Result<TokenStream, ErrorReported> {
2933
let server = proc_macro_server::Rustc::new(ecx, self.krate);
30-
self.client.run(&EXEC_STRATEGY, server, input, ecx.ecfg.proc_macro_backtrace).map_err(|e| {
31-
let mut err = ecx.struct_span_err(span, "proc macro panicked");
32-
if let Some(s) = e.as_str() {
33-
err.help(&format!("message: {}", s));
34-
}
35-
err.emit();
36-
ErrorReported
37-
})
34+
self.client.run(&exec_strategy(ecx), server, input, ecx.ecfg.proc_macro_backtrace).map_err(
35+
|e| {
36+
let mut err = ecx.struct_span_err(span, "proc macro panicked");
37+
if let Some(s) = e.as_str() {
38+
err.help(&format!("message: {}", s));
39+
}
40+
err.emit();
41+
ErrorReported
42+
},
43+
)
3844
}
3945
}
4046

@@ -53,7 +59,7 @@ impl base::AttrProcMacro for AttrProcMacro {
5359
) -> Result<TokenStream, ErrorReported> {
5460
let server = proc_macro_server::Rustc::new(ecx, self.krate);
5561
self.client
56-
.run(&EXEC_STRATEGY, server, annotation, annotated, ecx.ecfg.proc_macro_backtrace)
62+
.run(&exec_strategy(ecx), server, annotation, annotated, ecx.ecfg.proc_macro_backtrace)
5763
.map_err(|e| {
5864
let mut err = ecx.struct_span_err(span, "custom attribute panicked");
5965
if let Some(s) = e.as_str() {
@@ -102,18 +108,22 @@ impl MultiItemModifier for ProcMacroDerive {
102108
};
103109

104110
let server = proc_macro_server::Rustc::new(ecx, self.krate);
105-
let stream =
106-
match self.client.run(&EXEC_STRATEGY, server, input, ecx.ecfg.proc_macro_backtrace) {
107-
Ok(stream) => stream,
108-
Err(e) => {
109-
let mut err = ecx.struct_span_err(span, "proc-macro derive panicked");
110-
if let Some(s) = e.as_str() {
111-
err.help(&format!("message: {}", s));
112-
}
113-
err.emit();
114-
return ExpandResult::Ready(vec![]);
111+
let stream = match self.client.run(
112+
&exec_strategy(ecx),
113+
server,
114+
input,
115+
ecx.ecfg.proc_macro_backtrace,
116+
) {
117+
Ok(stream) => stream,
118+
Err(e) => {
119+
let mut err = ecx.struct_span_err(span, "proc-macro derive panicked");
120+
if let Some(s) = e.as_str() {
121+
err.help(&format!("message: {}", s));
115122
}
116-
};
123+
err.emit();
124+
return ExpandResult::Ready(vec![]);
125+
}
126+
};
117127

118128
let error_count_before = ecx.sess.parse_sess.span_diagnostic.err_count();
119129
let mut parser =

compiler/rustc_session/src/options.rs

+2
Original file line numberDiff line numberDiff line change
@@ -1207,6 +1207,8 @@ options! {
12071207
"print layout information for each type encountered (default: no)"),
12081208
proc_macro_backtrace: bool = (false, parse_bool, [UNTRACKED],
12091209
"show backtraces for panics during proc-macro execution (default: no)"),
1210+
proc_macro_cross_thread: bool = (false, parse_bool, [UNTRACKED],
1211+
"run proc-macro code on a separate thread (default: no)"),
12101212
profile: bool = (false, parse_bool, [TRACKED],
12111213
"insert profiling code (default: no)"),
12121214
profile_closures: bool = (false, parse_no_flag, [UNTRACKED],

library/proc_macro/src/bridge/client.rs

+5-1
Original file line numberDiff line numberDiff line change
@@ -312,7 +312,11 @@ macro_rules! client_send_impl {
312312

313313
b = bridge.dispatch.call(b);
314314

315-
let r = Result::<(), PanicMessage>::decode(&mut &b[..], &mut ());
315+
let r = if b.len() > 0 {
316+
Result::<(), PanicMessage>::decode(&mut &b[..], &mut ())
317+
} else {
318+
Ok(())
319+
};
316320

317321
bridge.cached_buffer = b;
318322

library/proc_macro/src/bridge/server.rs

+103-66
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
33
use super::*;
44

5+
use std::marker::PhantomData;
6+
57
// FIXME(eddyb) generate the definition of `HandleStore` in `server.rs`.
68
use super::client::HandleStore;
79

@@ -174,6 +176,50 @@ pub trait ExecutionStrategy {
174176
) -> Buffer<u8>;
175177
}
176178

179+
pub struct MaybeCrossThread<P> {
180+
cross_thread: bool,
181+
marker: PhantomData<P>,
182+
}
183+
184+
impl<P> MaybeCrossThread<P> {
185+
pub const fn new(cross_thread: bool) -> Self {
186+
MaybeCrossThread { cross_thread, marker: PhantomData }
187+
}
188+
}
189+
190+
impl<P> ExecutionStrategy for MaybeCrossThread<P>
191+
where
192+
P: MessagePipe<Buffer<u8>> + Send + 'static,
193+
{
194+
fn run_bridge_and_client<D: Copy + Send + 'static>(
195+
&self,
196+
dispatcher: &mut impl DispatcherTrait,
197+
input: Buffer<u8>,
198+
run_client: extern "C" fn(BridgeConfig<'_>, D) -> Buffer<u8>,
199+
client_data: D,
200+
force_show_panics: bool,
201+
) -> Buffer<u8> {
202+
if self.cross_thread {
203+
<CrossThread<P>>::new().run_bridge_and_client(
204+
dispatcher,
205+
input,
206+
run_client,
207+
client_data,
208+
force_show_panics,
209+
)
210+
} else {
211+
SameThread.run_bridge_and_client(
212+
dispatcher,
213+
input,
214+
run_client,
215+
client_data,
216+
force_show_panics,
217+
)
218+
}
219+
}
220+
}
221+
222+
#[derive(Default)]
177223
pub struct SameThread;
178224

179225
impl ExecutionStrategy for SameThread {
@@ -194,12 +240,18 @@ impl ExecutionStrategy for SameThread {
194240
}
195241
}
196242

197-
// NOTE(eddyb) Two implementations are provided, the second one is a bit
198-
// faster but neither is anywhere near as fast as same-thread execution.
243+
pub struct CrossThread<P>(PhantomData<P>);
199244

200-
pub struct CrossThread1;
245+
impl<P> CrossThread<P> {
246+
pub const fn new() -> Self {
247+
CrossThread(PhantomData)
248+
}
249+
}
201250

202-
impl ExecutionStrategy for CrossThread1 {
251+
impl<P> ExecutionStrategy for CrossThread<P>
252+
where
253+
P: MessagePipe<Buffer<u8>> + Send + 'static,
254+
{
203255
fn run_bridge_and_client<D: Copy + Send + 'static>(
204256
&self,
205257
dispatcher: &mut impl DispatcherTrait,
@@ -208,15 +260,18 @@ impl ExecutionStrategy for CrossThread1 {
208260
client_data: D,
209261
force_show_panics: bool,
210262
) -> Buffer<u8> {
211-
use std::sync::mpsc::channel;
212-
213-
let (req_tx, req_rx) = channel();
214-
let (res_tx, res_rx) = channel();
263+
let (mut server, mut client) = P::new();
215264

216265
let join_handle = thread::spawn(move || {
217-
let mut dispatch = |b| {
218-
req_tx.send(b).unwrap();
219-
res_rx.recv().unwrap()
266+
let mut dispatch = |b: Buffer<u8>| -> Buffer<u8> {
267+
let method_tag = api_tags::Method::decode(&mut &b[..], &mut ());
268+
client.send(b);
269+
270+
if method_tag.should_wait() {
271+
client.recv().expect("server died while client waiting for reply")
272+
} else {
273+
Buffer::new()
274+
}
220275
};
221276

222277
run_client(
@@ -225,73 +280,55 @@ impl ExecutionStrategy for CrossThread1 {
225280
)
226281
});
227282

228-
for b in req_rx {
229-
res_tx.send(dispatcher.dispatch(b)).unwrap();
283+
while let Some(b) = server.recv() {
284+
let method_tag = api_tags::Method::decode(&mut &b[..], &mut ());
285+
let b = dispatcher.dispatch(b);
286+
287+
if method_tag.should_wait() {
288+
server.send(b);
289+
} else if let Err(err) = <Result<(), PanicMessage>>::decode(&mut &b[..], &mut ()) {
290+
panic::resume_unwind(err.into());
291+
}
230292
}
231293

232294
join_handle.join().unwrap()
233295
}
234296
}
235297

236-
pub struct CrossThread2;
237-
238-
impl ExecutionStrategy for CrossThread2 {
239-
fn run_bridge_and_client<D: Copy + Send + 'static>(
240-
&self,
241-
dispatcher: &mut impl DispatcherTrait,
242-
input: Buffer<u8>,
243-
run_client: extern "C" fn(BridgeConfig<'_>, D) -> Buffer<u8>,
244-
client_data: D,
245-
force_show_panics: bool,
246-
) -> Buffer<u8> {
247-
use std::sync::{Arc, Mutex};
248-
249-
enum State<T> {
250-
Req(T),
251-
Res(T),
252-
}
253-
254-
let mut state = Arc::new(Mutex::new(State::Res(Buffer::new())));
298+
/// A message pipe used for communicating between server and client threads.
299+
pub trait MessagePipe<T>: Sized {
300+
/// Create a new pair of endpoints for the message pipe.
301+
fn new() -> (Self, Self);
255302

256-
let server_thread = thread::current();
257-
let state2 = state.clone();
258-
let join_handle = thread::spawn(move || {
259-
let mut dispatch = |b| {
260-
*state2.lock().unwrap() = State::Req(b);
261-
server_thread.unpark();
262-
loop {
263-
thread::park();
264-
if let State::Res(b) = &mut *state2.lock().unwrap() {
265-
break b.take();
266-
}
267-
}
268-
};
303+
/// Send a message to the other endpoint of this pipe.
304+
fn send(&mut self, value: T);
269305

270-
let r = run_client(
271-
BridgeConfig { input, dispatch: (&mut dispatch).into(), force_show_panics },
272-
client_data,
273-
);
306+
/// Receive a message from the other endpoint of this pipe.
307+
///
308+
/// Returns `None` if the other end of the pipe has been destroyed, and no
309+
/// message was received.
310+
fn recv(&mut self) -> Option<T>;
311+
}
274312

275-
// Wake up the server so it can exit the dispatch loop.
276-
drop(state2);
277-
server_thread.unpark();
313+
/// Implementation of `MessagePipe` using `std::sync::mpsc`
314+
pub struct StdMessagePipe<T> {
315+
tx: std::sync::mpsc::Sender<T>,
316+
rx: std::sync::mpsc::Receiver<T>,
317+
}
278318

279-
r
280-
});
319+
impl<T> MessagePipe<T> for StdMessagePipe<T> {
320+
fn new() -> (Self, Self) {
321+
let (tx1, rx1) = std::sync::mpsc::channel();
322+
let (tx2, rx2) = std::sync::mpsc::channel();
323+
(StdMessagePipe { tx: tx1, rx: rx2 }, StdMessagePipe { tx: tx2, rx: rx1 })
324+
}
281325

282-
// Check whether `state2` was dropped, to know when to stop.
283-
while Arc::get_mut(&mut state).is_none() {
284-
thread::park();
285-
let mut b = match &mut *state.lock().unwrap() {
286-
State::Req(b) => b.take(),
287-
_ => continue,
288-
};
289-
b = dispatcher.dispatch(b.take());
290-
*state.lock().unwrap() = State::Res(b);
291-
join_handle.thread().unpark();
292-
}
326+
fn send(&mut self, v: T) {
327+
self.tx.send(v).unwrap();
328+
}
293329

294-
join_handle.join().unwrap()
330+
fn recv(&mut self) -> Option<T> {
331+
self.rx.recv().ok()
295332
}
296333
}
297334

0 commit comments

Comments (0)