Skip to content

NLL infer #45155

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Oct 14, 2017
222 changes: 222 additions & 0 deletions src/librustc_mir/transform/nll/infer.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,222 @@
// Copyright 2017 The Rust Project Developers. See the COPYRIGHT
// file at the top-level directory of this distribution and at
// http://rust-lang.org/COPYRIGHT.
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
// option. This file may not be copied, modified, or distributed
// except according to those terms.

use super::{Region, RegionIndex};
use std::mem;
use rustc::infer::InferCtxt;
use rustc::mir::{Location, Mir};
use rustc_data_structures::indexed_vec::{Idx, IndexVec};
use rustc_data_structures::fx::FxHashSet;

pub struct InferenceContext {
definitions: IndexVec<RegionIndex, VarDefinition>,
constraints: IndexVec<ConstraintIndex, Constraint>,
errors: IndexVec<InferenceErrorIndex, InferenceError>,
}

pub struct InferenceError {
pub constraint_point: Location,
pub name: (), // FIXME(nashenas88) RegionName
}

newtype_index!(InferenceErrorIndex);

struct VarDefinition {
name: (), // FIXME(nashenas88) RegionName
value: Region,
capped: bool,
}

impl VarDefinition {
pub fn new(value: Region) -> Self {
Self {
name: (),
value,
capped: false,
}
}
}

#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub struct Constraint {
sub: RegionIndex,
sup: RegionIndex,
point: Location,
}

newtype_index!(ConstraintIndex);

impl InferenceContext {
pub fn new(values: IndexVec<RegionIndex, Region>) -> Self {
Self {
definitions: values.into_iter().map(VarDefinition::new).collect(),
constraints: IndexVec::new(),
errors: IndexVec::new(),
}
}

#[allow(dead_code)]
pub fn cap_var(&mut self, v: RegionIndex) {
self.definitions[v].capped = true;
}

#[allow(dead_code)]
pub fn add_live_point(&mut self, v: RegionIndex, point: Location) {
debug!("add_live_point({:?}, {:?})", v, point);
let definition = &mut self.definitions[v];
if definition.value.add_point(point) {
if definition.capped {
self.errors.push(InferenceError {
constraint_point: point,
name: definition.name,
});
}
}
}

#[allow(dead_code)]
pub fn add_outlives(&mut self, sup: RegionIndex, sub: RegionIndex, point: Location) {
debug!("add_outlives({:?}: {:?} @ {:?}", sup, sub, point);
self.constraints.push(Constraint { sup, sub, point });
}

#[allow(dead_code)]
pub fn region(&self, v: RegionIndex) -> &Region {
&self.definitions[v].value
}

pub fn solve<'a, 'gcx, 'tcx>(
&mut self,
infcx: &'a InferCtxt<'a, 'gcx, 'tcx>,
mir: &'a Mir<'tcx>,
) -> IndexVec<InferenceErrorIndex, InferenceError>
where
'gcx: 'tcx + 'a,
'tcx: 'a,
{
let mut changed = true;
let mut dfs = Dfs::new(infcx, mir);
while changed {
changed = false;
for constraint in &self.constraints {
let sub = &self.definitions[constraint.sub].value.clone();
let sup_def = &mut self.definitions[constraint.sup];
debug!("constraint: {:?}", constraint);
debug!(" sub (before): {:?}", sub);
debug!(" sup (before): {:?}", sup_def.value);

if dfs.copy(sub, &mut sup_def.value, constraint.point) {
changed = true;
if sup_def.capped {
// This is kind of a hack, but when we add a
// constraint, the "point" is always the point
// AFTER the action that induced the
// constraint. So report the error on the
// action BEFORE that.
assert!(constraint.point.statement_index > 0);
let p = Location {
block: constraint.point.block,
statement_index: constraint.point.statement_index - 1,
};

self.errors.push(InferenceError {
constraint_point: p,
name: sup_def.name,
});
}
}

debug!(" sup (after) : {:?}", sup_def.value);
debug!(" changed : {:?}", changed);
}
debug!("\n");
}

mem::replace(&mut self.errors, IndexVec::new())
}
}

struct Dfs<'a, 'gcx: 'tcx + 'a, 'tcx: 'a> {
#[allow(dead_code)]
infcx: &'a InferCtxt<'a, 'gcx, 'tcx>,
mir: &'a Mir<'tcx>,
}

impl<'a, 'gcx: 'tcx, 'tcx: 'a> Dfs<'a, 'gcx, 'tcx> {
fn new(infcx: &'a InferCtxt<'a, 'gcx, 'tcx>, mir: &'a Mir<'tcx>) -> Self {
Self { infcx, mir }
}

fn copy(
&mut self,
from_region: &Region,
to_region: &mut Region,
start_point: Location,
) -> bool {
let mut changed = false;

let mut stack = vec![];
let mut visited = FxHashSet();

stack.push(start_point);
while let Some(p) = stack.pop() {
debug!(" dfs: p={:?}", p);

if !from_region.may_contain(p) {
debug!(" not in from-region");
continue;
}

if !visited.insert(p) {
debug!(" already visited");
continue;
}

changed |= to_region.add_point(p);

let block_data = &self.mir[p.block];
let successor_points = if p.statement_index < block_data.statements.len() {
vec![Location {
statement_index: p.statement_index + 1,
..p
}]
} else {
block_data.terminator()
.successors()
.iter()
.map(|&basic_block| Location {
statement_index: 0,
block: basic_block,
})
.collect::<Vec<_>>()
};

if successor_points.is_empty() {
// FIXME handle free regions
// If we reach the END point in the graph, then copy
// over any skolemized end points in the `from_region`
// and make sure they are included in the `to_region`.
// for region_decl in self.infcx.tcx.tables.borrow().free_region_map() {
// // FIXME(nashenas88) figure out skolemized_end points
// let block = self.env.graph.skolemized_end(region_decl.name);
// let skolemized_end_point = Location {
// block,
// statement_index: 0,
// };
// changed |= to_region.add_point(skolemized_end_point);
// }
} else {
stack.extend(successor_points);
}
}

changed
}
}
32 changes: 23 additions & 9 deletions src/librustc_mir/transform/nll/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,14 @@
// option. This file may not be copied, modified, or distributed
// except according to those terms.

use self::infer::InferenceContext;
use rustc::ty::TypeFoldable;
use rustc::ty::subst::{Kind, Substs};
use rustc::ty::{Ty, TyCtxt, ClosureSubsts, RegionVid, RegionKind};
use rustc::mir::{Mir, Location, Rvalue, BasicBlock, Statement, StatementKind};
use rustc::mir::visit::{MutVisitor, Lookup};
use rustc::mir::transform::{MirPass, MirSource};
use rustc::infer::{self, InferCtxt};
use rustc::infer::{self as rustc_infer, InferCtxt};
use rustc::util::nodemap::FxHashSet;
use rustc_data_structures::indexed_vec::{IndexVec, Idx};
use syntax_pos::DUMMY_SP;
Expand All @@ -24,30 +25,33 @@ use std::fmt;
use util as mir_util;
use self::mir_util::PassWhere;

mod infer;

#[allow(dead_code)]
struct NLLVisitor<'a, 'gcx: 'a + 'tcx, 'tcx: 'a> {
lookup_map: HashMap<RegionVid, Lookup>,
regions: IndexVec<RegionIndex, Region>,
infcx: InferCtxt<'a, 'gcx, 'tcx>,
#[allow(dead_code)]
infcx: &'a InferCtxt<'a, 'gcx, 'tcx>,
}

impl<'a, 'gcx, 'tcx> NLLVisitor<'a, 'gcx, 'tcx> {
pub fn new(infcx: InferCtxt<'a, 'gcx, 'tcx>) -> Self {
pub fn new(infcx: &'a InferCtxt<'a, 'gcx, 'tcx>) -> Self {
NLLVisitor {
infcx,
lookup_map: HashMap::new(),
regions: IndexVec::new(),
}
}

pub fn into_results(self) -> HashMap<RegionVid, Lookup> {
self.lookup_map
pub fn into_results(self) -> (HashMap<RegionVid, Lookup>, IndexVec<RegionIndex, Region>) {
(self.lookup_map, self.regions)
}

fn renumber_regions<T>(&mut self, value: &T) -> T where T: TypeFoldable<'tcx> {
self.infcx.tcx.fold_regions(value, &mut false, |_region, _depth| {
self.regions.push(Region::default());
self.infcx.next_region_var(infer::MiscVariable(DUMMY_SP))
self.infcx.next_region_var(rustc_infer::MiscVariable(DUMMY_SP))
})
}

Expand Down Expand Up @@ -147,7 +151,7 @@ impl MirPass for NLL {
tcx.infer_ctxt().enter(|infcx| {
// Clone mir so we can mutate it without disturbing the rest of the compiler
let mut renumbered_mir = mir.clone();
let mut visitor = NLLVisitor::new(infcx);
let mut visitor = NLLVisitor::new(&infcx);
visitor.visit_mir(&mut renumbered_mir);
mir_util::dump_mir(tcx, None, "nll", &0, source, mir, |pass_where, out| {
if let PassWhere::BeforeCFG = pass_where {
Expand All @@ -157,13 +161,15 @@ impl MirPass for NLL {
}
Ok(())
});
let _results = visitor.into_results();
let (_lookup_map, regions) = visitor.into_results();
let mut inference_context = InferenceContext::new(regions);
inference_context.solve(&infcx, &renumbered_mir);
})
}
}

#[derive(Clone, Default, PartialEq, Eq)]
struct Region {
pub struct Region {
points: FxHashSet<Location>,
}

Expand All @@ -173,6 +179,14 @@ impl fmt::Debug for Region {
}
}

impl Region {
pub fn add_point(&mut self, point: Location) -> bool {
self.points.insert(point)
}

pub fn may_contain(&self, point: Location) -> bool {
self.points.contains(&point)
}
}

newtype_index!(RegionIndex);