Skip to content

Try to make ObligationForest more efficient #77908

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Oct 19, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
116 changes: 64 additions & 52 deletions compiler/rustc_data_structures/src/obligation_forest/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -149,8 +149,8 @@ pub struct ObligationForest<O: ForestObligation> {
/// comments in `process_obligation` for details.
active_cache: FxHashMap<O::CacheKey, usize>,

/// A vector reused in compress(), to avoid allocating new vectors.
node_rewrites: Vec<usize>,
/// A vector reused in compress() and find_cycles_from_node(), to avoid allocating new vectors.
reused_node_vec: Vec<usize>,

obligation_tree_id_generator: ObligationTreeIdGenerator,

Expand Down Expand Up @@ -251,12 +251,22 @@ enum NodeState {
Error,
}

/// This trait allows us to have two different Outcome types:
/// - the normal one that does as little as possible
/// - one for tests that does some additional work and checking
pub trait OutcomeTrait {
type Error;
type Obligation;

fn new() -> Self;
fn mark_not_stalled(&mut self);
fn is_stalled(&self) -> bool;
fn record_completed(&mut self, outcome: &Self::Obligation);
fn record_error(&mut self, error: Self::Error);
}

#[derive(Debug)]
pub struct Outcome<O, E> {
/// Obligations that were completely evaluated, including all
/// (transitive) subobligations. Only computed if requested.
pub completed: Option<Vec<O>>,

/// Backtrace of obligations that were found to be in error.
pub errors: Vec<Error<O, E>>,

Expand All @@ -269,12 +279,29 @@ pub struct Outcome<O, E> {
pub stalled: bool,
}

/// Should `process_obligations` compute the `Outcome::completed` field of its
/// result?
#[derive(PartialEq)]
pub enum DoCompleted {
No,
Yes,
impl<O, E> OutcomeTrait for Outcome<O, E> {
type Error = Error<O, E>;
type Obligation = O;

fn new() -> Self {
Self { stalled: true, errors: vec![] }
}

fn mark_not_stalled(&mut self) {
self.stalled = false;
}

fn is_stalled(&self) -> bool {
self.stalled
}

fn record_completed(&mut self, _outcome: &Self::Obligation) {
// do nothing
}

fn record_error(&mut self, error: Self::Error) {
self.errors.push(error)
}
}

#[derive(Debug, PartialEq, Eq)]
Expand All @@ -289,7 +316,7 @@ impl<O: ForestObligation> ObligationForest<O> {
nodes: vec![],
done_cache: Default::default(),
active_cache: Default::default(),
node_rewrites: vec![],
reused_node_vec: vec![],
obligation_tree_id_generator: (0..).map(ObligationTreeId),
error_cache: Default::default(),
}
Expand Down Expand Up @@ -363,8 +390,7 @@ impl<O: ForestObligation> ObligationForest<O> {
.map(|(index, _node)| Error { error: error.clone(), backtrace: self.error_at(index) })
.collect();

let successful_obligations = self.compress(DoCompleted::Yes);
assert!(successful_obligations.unwrap().is_empty());
self.compress(|_| assert!(false));
errors
}

Expand Down Expand Up @@ -392,16 +418,12 @@ impl<O: ForestObligation> ObligationForest<O> {
/// be called in a loop until `outcome.stalled` is false.
///
/// This _cannot_ be unrolled (presently, at least).
pub fn process_obligations<P>(
&mut self,
processor: &mut P,
do_completed: DoCompleted,
) -> Outcome<O, P::Error>
pub fn process_obligations<P, OUT>(&mut self, processor: &mut P) -> OUT
where
P: ObligationProcessor<Obligation = O>,
OUT: OutcomeTrait<Obligation = O, Error = Error<O, P::Error>>,
{
let mut errors = vec![];
let mut stalled = true;
let mut outcome = OUT::new();

// Note that the loop body can append new nodes, and those new nodes
// will then be processed by subsequent iterations of the loop.
Expand Down Expand Up @@ -429,7 +451,7 @@ impl<O: ForestObligation> ObligationForest<O> {
}
ProcessResult::Changed(children) => {
// We are not (yet) stalled.
stalled = false;
outcome.mark_not_stalled();
node.state.set(NodeState::Success);

for child in children {
Expand All @@ -442,28 +464,22 @@ impl<O: ForestObligation> ObligationForest<O> {
}
}
ProcessResult::Error(err) => {
stalled = false;
errors.push(Error { error: err, backtrace: self.error_at(index) });
outcome.mark_not_stalled();
outcome.record_error(Error { error: err, backtrace: self.error_at(index) });
}
}
index += 1;
}

if stalled {
// There's no need to perform marking, cycle processing and compression when nothing
// changed.
return Outcome {
completed: if do_completed == DoCompleted::Yes { Some(vec![]) } else { None },
errors,
stalled,
};
// There's no need to perform marking, cycle processing and compression when nothing
// changed.
if !outcome.is_stalled() {
self.mark_successes();
self.process_cycles(processor);
self.compress(|obl| outcome.record_completed(obl));
}

self.mark_successes();
self.process_cycles(processor);
let completed = self.compress(do_completed);

Outcome { completed, errors, stalled }
outcome
}

/// Returns a vector of obligations for `p` and all of its
Expand Down Expand Up @@ -526,7 +542,6 @@ impl<O: ForestObligation> ObligationForest<O> {
let node = &self.nodes[index];
let state = node.state.get();
if state == NodeState::Success {
node.state.set(NodeState::Waiting);
// This call site is cold.
self.uninlined_mark_dependents_as_waiting(node);
} else {
Expand All @@ -538,17 +553,18 @@ impl<O: ForestObligation> ObligationForest<O> {
// This never-inlined function is for the cold call site.
#[inline(never)]
fn uninlined_mark_dependents_as_waiting(&self, node: &Node<O>) {
// Mark node Waiting in the cold uninlined code instead of the hot inlined
node.state.set(NodeState::Waiting);
self.inlined_mark_dependents_as_waiting(node)
}

/// Report cycles between all `Success` nodes, and convert all `Success`
/// nodes to `Done`. This must be called after `mark_successes`.
fn process_cycles<P>(&self, processor: &mut P)
fn process_cycles<P>(&mut self, processor: &mut P)
where
P: ObligationProcessor<Obligation = O>,
{
let mut stack = vec![];

let mut stack = std::mem::take(&mut self.reused_node_vec);
for (index, node) in self.nodes.iter().enumerate() {
// For some benchmarks this state test is extremely hot. It's a win
// to handle the no-op cases immediately to avoid the cost of the
Expand All @@ -559,6 +575,7 @@ impl<O: ForestObligation> ObligationForest<O> {
}

debug_assert!(stack.is_empty());
self.reused_node_vec = stack;
}

fn find_cycles_from_node<P>(&self, stack: &mut Vec<usize>, processor: &mut P, index: usize)
Expand Down Expand Up @@ -591,13 +608,12 @@ impl<O: ForestObligation> ObligationForest<O> {
/// indices and hence invalidates any outstanding indices. `process_cycles`
/// must be run beforehand to remove any cycles on `Success` nodes.
#[inline(never)]
fn compress(&mut self, do_completed: DoCompleted) -> Option<Vec<O>> {
fn compress(&mut self, mut outcome_cb: impl FnMut(&O)) {
let orig_nodes_len = self.nodes.len();
let mut node_rewrites: Vec<_> = std::mem::take(&mut self.node_rewrites);
let mut node_rewrites: Vec<_> = std::mem::take(&mut self.reused_node_vec);
debug_assert!(node_rewrites.is_empty());
node_rewrites.extend(0..orig_nodes_len);
let mut dead_nodes = 0;
let mut removed_done_obligations: Vec<O> = vec![];

// Move removable nodes to the end, preserving the order of the
// remaining nodes.
Expand Down Expand Up @@ -627,10 +643,8 @@ impl<O: ForestObligation> ObligationForest<O> {
} else {
self.done_cache.insert(node.obligation.as_cache_key().clone());
}
if do_completed == DoCompleted::Yes {
// Extract the success stories.
removed_done_obligations.push(node.obligation.clone());
}
// Extract the success stories.
outcome_cb(&node.obligation);
node_rewrites[index] = orig_nodes_len;
dead_nodes += 1;
}
Expand All @@ -654,9 +668,7 @@ impl<O: ForestObligation> ObligationForest<O> {
}

node_rewrites.truncate(0);
self.node_rewrites = node_rewrites;

if do_completed == DoCompleted::Yes { Some(removed_done_obligations) } else { None }
self.reused_node_vec = node_rewrites;
}

fn apply_rewrites(&mut self, node_rewrites: &[usize]) {
Expand Down
Loading