diff --git a/ExpressionUtils/PartialEvaluator.cs b/ExpressionUtils/PartialEvaluator.cs index 9d9b03f..f542f64 100644 --- a/ExpressionUtils/PartialEvaluator.cs +++ b/ExpressionUtils/PartialEvaluator.cs @@ -20,29 +20,9 @@ public static Expression PartialEval(Expression expression, IExpressionEvaluator /// /// The root of the expression tree. /// A new tree with sub-trees evaluated and replaced. - public static Expression PartialEval(Expression expression, IExpressionEvaluator evaluator) - => PartialEval(expression, evaluator, canBeEvaluated); - - /// - /// Performs evaluation & replacement of independent sub-trees in the body of . - /// - public static LambdaExpression PartialEval(LambdaExpression expression, IExpressionEvaluator evaluator) - => (LambdaExpression)PartialEval(expression, evaluator, canBeEvaluated); - - /// - /// Performs evaluation & replacement of independent sub-trees in the body - /// of an typed . - /// - /// The lambda expression whichs body to - /// partially evaluate. - /// A new typed with sub-trees in the - /// body evaluated and replaced. - /// - /// Call to is very important here. - /// It allows expression to keep its original type, even if its body was replaced with call. - /// - public static Expression PartialEval(Expression expFunc, IExpressionEvaluator evaluator) - => expFunc.Update(PartialEval(expFunc.Body, evaluator), expFunc.Parameters); + public static TExpression PartialEval(TExpression expression, IExpressionEvaluator evaluator) where TExpression : Expression { + return (TExpression)PartialEval(expression, evaluator, canBeEvaluated); + } private static bool canBeEvaluated(Expression expression) { if (expression.NodeType == ExpressionType.Parameter) { @@ -72,22 +52,47 @@ internal Expression Eval(Expression exp) { return this.Visit(exp); } + private int depth = -1; + public override Expression Visit(Expression exp) { if (exp == null) { return null; } - if (candidates.Contains(exp)) { + + // In case we visit lambda expression, we want to return lambda expression as a result of visit, + // so we don't want to replace root lambda node with constant expression node, so we don't do anything here. + if (candidates.Contains(exp) && !(depth == -1 && exp is LambdaExpression)) { if (exp is ConstantExpression) { return exp; } - try { - return Expression.Constant(evaluator.Evaluate(exp), exp.Type); - } catch (Exception exception) { - return ExceptionClosure.MakeExceptionClosureCall(exception, exp.Type); - } + return evaluate(exp); + } + + depth++; + var newNode = base.Visit(exp); + depth--; + return newNode; + } + + protected override Expression VisitLambda(Expression node) { + // This is root lambda node that we want to evaluate. Since we want to preserve type + // of input expression, we will update the body, but still keep the lambda. + if (candidates.Contains(node) && depth == 0) { + var constant = evaluate(node.Body); + return node.Update(constant, node.Parameters); + } + + return base.VisitLambda(node); + } + + private Expression evaluate(Expression exp) { + try { + return Expression.Constant(evaluator.Evaluate(exp), exp.Type); + } + catch (Exception exception) { + return ExceptionClosure.MakeExceptionClosureCall(exception, exp.Type); } - return base.Visit(exp); } } diff --git a/README.md b/README.md index 6110661..045b531 100644 --- a/README.md +++ b/README.md @@ -56,7 +56,7 @@ int modul = 0; Expression> modulExpression = x => x % modul == 0; modul = 2; -var isMultipleOfTwoExpression = PartialEvaluator.PartialEvalBody( +var isMultipleOfTwoExpression = PartialEvaluator.PartialEval( modulExpression, ExpressionInterpreter.Instance); @@ -66,7 +66,7 @@ Console.Write(isEvenExpression.StructuralIdentical(modulExpression)); // false Console.Write(isEvenExpression.StructuralIdentical(isMultipleOfTwoExpression)); // true modul = 3; -var isMultipleOfThreeExpression = PartialEvaluator.PartialEvalBody( +var isMultipleOfThreeExpression = PartialEvaluator.PartialEval( modulExpression, ExpressionInterpreter.Instance);