Skip to content

Support constants in const eval #11772

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Mar 24, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 3 additions & 18 deletions crates/hir/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ pub mod symbols;

mod display;

use std::{collections::HashMap, iter, ops::ControlFlow, sync::Arc};
use std::{iter, ops::ControlFlow, sync::Arc};

use arrayvec::ArrayVec;
use base_db::{CrateDisplayName, CrateId, CrateOrigin, Edition, FileId, ProcMacroKind};
Expand All @@ -55,9 +55,7 @@ use hir_def::{
use hir_expand::{name::name, MacroCallKind};
use hir_ty::{
autoderef,
consteval::{
eval_const, unknown_const_as_generic, ComputedExpr, ConstEvalCtx, ConstEvalError, ConstExt,
},
consteval::{unknown_const_as_generic, ComputedExpr, ConstEvalError, ConstExt},
diagnostics::BodyValidationDiagnostic,
method_resolution::{self, TyFingerprint},
primitive::UintTy,
Expand Down Expand Up @@ -1602,20 +1600,7 @@ impl Const {
}

pub fn eval(self, db: &dyn HirDatabase) -> Result<ComputedExpr, ConstEvalError> {
let body = db.body(self.id.into());
let root = &body.exprs[body.body_expr];
let infer = db.infer_query(self.id.into());
let infer = infer.as_ref();
let result = eval_const(
root,
&mut ConstEvalCtx {
exprs: &body.exprs,
pats: &body.pats,
local_data: HashMap::default(),
infer: &mut |x| infer[x].clone(),
},
);
result
db.const_eval(self.id)
}
}

Expand Down
162 changes: 118 additions & 44 deletions crates/hir_ty/src/consteval.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,22 +8,19 @@ use std::{

use chalk_ir::{BoundVar, DebruijnIndex, GenericArgData, IntTy, Scalar};
use hir_def::{
expr::{ArithOp, BinaryOp, Expr, Literal, Pat},
expr::{ArithOp, BinaryOp, Expr, ExprId, Literal, Pat, PatId},
path::ModPath,
resolver::{Resolver, ValueNs},
resolver::{resolver_for_expr, ResolveValueResult, Resolver, ValueNs},
type_ref::ConstScalar,
ConstId, DefWithBodyId,
};
use hir_expand::name::Name;
use la_arena::{Arena, Idx};
use stdx::never;

use crate::{
db::HirDatabase,
infer::{Expectation, InferenceContext},
lower::ParamLoweringMode,
to_placeholder_idx,
utils::Generics,
Const, ConstData, ConstValue, GenericArg, Interner, Ty, TyKind,
db::HirDatabase, infer::InferenceContext, lower::ParamLoweringMode, to_placeholder_idx,
utils::Generics, Const, ConstData, ConstValue, GenericArg, InferenceResult, Interner, Ty,
TyKind,
};

/// Extension trait for [`Const`]
Expand Down Expand Up @@ -55,21 +52,30 @@ impl ConstExt for Const {
}

pub struct ConstEvalCtx<'a> {
pub db: &'a dyn HirDatabase,
pub owner: DefWithBodyId,
pub exprs: &'a Arena<Expr>,
pub pats: &'a Arena<Pat>,
pub local_data: HashMap<Name, ComputedExpr>,
pub infer: &'a mut dyn FnMut(Idx<Expr>) -> Ty,
pub local_data: HashMap<PatId, ComputedExpr>,
infer: &'a InferenceResult,
}

#[derive(Debug, Clone)]
impl ConstEvalCtx<'_> {
fn expr_ty(&mut self, expr: ExprId) -> Ty {
self.infer[expr].clone()
}
}

#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ConstEvalError {
NotSupported(&'static str),
TypeError,
SemanticError(&'static str),
Loop,
IncompleteExpr,
Panic(String),
}

#[derive(Debug, Clone)]
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ComputedExpr {
Literal(Literal),
Tuple(Box<[ComputedExpr]>),
Expand All @@ -80,14 +86,14 @@ impl Display for ComputedExpr {
match self {
ComputedExpr::Literal(l) => match l {
Literal::Int(x, _) => {
if *x >= 16 {
if *x >= 10 {
write!(f, "{} ({:#X})", x, x)
} else {
x.fmt(f)
}
}
Literal::Uint(x, _) => {
if *x >= 16 {
if *x >= 10 {
write!(f, "{} ({:#X})", x, x)
} else {
x.fmt(f)
Expand Down Expand Up @@ -143,12 +149,17 @@ fn is_valid(scalar: &Scalar, value: i128) -> bool {
}
}

pub fn eval_const(expr: &Expr, ctx: &mut ConstEvalCtx<'_>) -> Result<ComputedExpr, ConstEvalError> {
pub fn eval_const(
expr_id: ExprId,
ctx: &mut ConstEvalCtx<'_>,
) -> Result<ComputedExpr, ConstEvalError> {
let expr = &ctx.exprs[expr_id];
match expr {
Expr::Missing => Err(ConstEvalError::IncompleteExpr),
Expr::Literal(l) => Ok(ComputedExpr::Literal(l.clone())),
&Expr::UnaryOp { expr, op } => {
let ty = &(ctx.infer)(expr);
let ev = eval_const(&ctx.exprs[expr], ctx)?;
let ty = &ctx.expr_ty(expr);
let ev = eval_const(expr, ctx)?;
match op {
hir_def::expr::UnaryOp::Deref => Err(ConstEvalError::NotSupported("deref")),
hir_def::expr::UnaryOp::Not => {
Expand Down Expand Up @@ -203,9 +214,9 @@ pub fn eval_const(expr: &Expr, ctx: &mut ConstEvalCtx<'_>) -> Result<ComputedExp
}
}
&Expr::BinaryOp { lhs, rhs, op } => {
let ty = &(ctx.infer)(lhs);
let lhs = eval_const(&ctx.exprs[lhs], ctx)?;
let rhs = eval_const(&ctx.exprs[rhs], ctx)?;
let ty = &ctx.expr_ty(lhs);
let lhs = eval_const(lhs, ctx)?;
let rhs = eval_const(rhs, ctx)?;
let op = op.ok_or(ConstEvalError::IncompleteExpr)?;
let v1 = match lhs {
ComputedExpr::Literal(Literal::Int(v, _)) => v,
Expand Down Expand Up @@ -249,31 +260,31 @@ pub fn eval_const(expr: &Expr, ctx: &mut ConstEvalCtx<'_>) -> Result<ComputedExp
}
Ok(ComputedExpr::Literal(Literal::Int(r, None)))
}
BinaryOp::LogicOp(_) => Err(ConstEvalError::TypeError),
BinaryOp::LogicOp(_) => Err(ConstEvalError::SemanticError("logic op on numbers")),
_ => Err(ConstEvalError::NotSupported("bin op on this operators")),
}
}
Expr::Block { statements, tail, .. } => {
let mut prev_values = HashMap::<Name, Option<ComputedExpr>>::default();
let mut prev_values = HashMap::<PatId, Option<ComputedExpr>>::default();
for statement in &**statements {
match *statement {
hir_def::expr::Statement::Let { pat, initializer, .. } => {
let pat = &ctx.pats[pat];
let name = match pat {
Pat::Bind { name, subpat, .. } if subpat.is_none() => name.clone(),
hir_def::expr::Statement::Let { pat: pat_id, initializer, .. } => {
let pat = &ctx.pats[pat_id];
match pat {
Pat::Bind { subpat, .. } if subpat.is_none() => (),
_ => {
return Err(ConstEvalError::NotSupported("complex patterns in let"))
}
};
let value = match initializer {
Some(x) => eval_const(&ctx.exprs[x], ctx)?,
Some(x) => eval_const(x, ctx)?,
None => continue,
};
if !prev_values.contains_key(&name) {
let prev = ctx.local_data.insert(name.clone(), value);
prev_values.insert(name, prev);
if !prev_values.contains_key(&pat_id) {
let prev = ctx.local_data.insert(pat_id, value);
prev_values.insert(pat_id, prev);
} else {
ctx.local_data.insert(name, value);
ctx.local_data.insert(pat_id, value);
}
}
hir_def::expr::Statement::Expr { .. } => {
Expand All @@ -282,7 +293,7 @@ pub fn eval_const(expr: &Expr, ctx: &mut ConstEvalCtx<'_>) -> Result<ComputedExp
}
}
let r = match tail {
&Some(x) => eval_const(&ctx.exprs[x], ctx),
&Some(x) => eval_const(x, ctx),
None => Ok(ComputedExpr::Tuple(Box::new([]))),
};
// clean up local data, so caller will receive the exact map that passed to us
Expand All @@ -295,19 +306,48 @@ pub fn eval_const(expr: &Expr, ctx: &mut ConstEvalCtx<'_>) -> Result<ComputedExp
r
}
Expr::Path(p) => {
let name = p.mod_path().as_ident().ok_or(ConstEvalError::NotSupported("big paths"))?;
let r = ctx
.local_data
.get(name)
.ok_or(ConstEvalError::NotSupported("Non local name resolution"))?;
Ok(r.clone())
let resolver = resolver_for_expr(ctx.db.upcast(), ctx.owner, expr_id);
let pr = resolver
.resolve_path_in_value_ns(ctx.db.upcast(), p.mod_path())
.ok_or(ConstEvalError::SemanticError("unresolved path"))?;
let pr = match pr {
ResolveValueResult::ValueNs(v) => v,
ResolveValueResult::Partial(..) => {
return match ctx
.infer
.assoc_resolutions_for_expr(expr_id)
.ok_or(ConstEvalError::SemanticError("unresolved assoc item"))?
{
hir_def::AssocItemId::FunctionId(_) => {
Err(ConstEvalError::NotSupported("assoc function"))
}
hir_def::AssocItemId::ConstId(c) => ctx.db.const_eval(c),
hir_def::AssocItemId::TypeAliasId(_) => {
Err(ConstEvalError::NotSupported("assoc type alias"))
}
}
}
};
match pr {
ValueNs::LocalBinding(pat_id) => {
let r = ctx
.local_data
.get(&pat_id)
.ok_or(ConstEvalError::NotSupported("Unexpected missing local"))?;
Ok(r.clone())
}
ValueNs::ConstId(id) => ctx.db.const_eval(id),
ValueNs::GenericParam(_) => {
Err(ConstEvalError::NotSupported("const generic without substitution"))
}
_ => Err(ConstEvalError::NotSupported("path that are not const or local")),
}
}
_ => Err(ConstEvalError::NotSupported("This kind of expression")),
}
}

pub fn eval_usize(expr: Idx<Expr>, mut ctx: ConstEvalCtx<'_>) -> Option<u64> {
let expr = &ctx.exprs[expr];
if let Ok(ce) = eval_const(expr, &mut ctx) {
match ce {
ComputedExpr::Literal(Literal::Int(x, _)) => return x.try_into().ok(),
Expand Down Expand Up @@ -380,10 +420,39 @@ pub fn usize_const(value: Option<u64>) -> Const {
.intern(Interner)
}

pub(crate) fn eval_to_const(
pub(crate) fn const_eval_recover(
_: &dyn HirDatabase,
_: &[String],
_: &ConstId,
) -> Result<ComputedExpr, ConstEvalError> {
Err(ConstEvalError::Loop)
}

pub(crate) fn const_eval_query(
db: &dyn HirDatabase,
const_id: ConstId,
) -> Result<ComputedExpr, ConstEvalError> {
let def = const_id.into();
let body = db.body(def);
let infer = &db.infer(def);
let result = eval_const(
body.body_expr,
&mut ConstEvalCtx {
db,
owner: const_id.into(),
exprs: &body.exprs,
pats: &body.pats,
local_data: HashMap::default(),
infer,
},
);
result
}

pub(crate) fn eval_to_const<'a>(
expr: Idx<Expr>,
mode: ParamLoweringMode,
ctx: &mut InferenceContext,
ctx: &mut InferenceContext<'a>,
args: impl FnOnce() -> Generics,
debruijn: DebruijnIndex,
) -> Const {
Expand All @@ -396,10 +465,15 @@ pub(crate) fn eval_to_const(
}
let body = ctx.body.clone();
let ctx = ConstEvalCtx {
db: ctx.db,
owner: ctx.owner,
exprs: &body.exprs,
pats: &body.pats,
local_data: HashMap::default(),
infer: &mut |x| ctx.infer_expr(x, &Expectation::None),
infer: &ctx.result,
};
usize_const(eval_usize(expr, ctx))
}

#[cfg(test)]
mod tests;
Loading