diff --git a/ExpressionUtils/PartialEvaluator.cs b/ExpressionUtils/PartialEvaluator.cs
index 9d9b03f..f542f64 100644
--- a/ExpressionUtils/PartialEvaluator.cs
+++ b/ExpressionUtils/PartialEvaluator.cs
@@ -20,29 +20,9 @@ public static Expression PartialEval(Expression expression, IExpressionEvaluator
///
/// The root of the expression tree.
/// A new tree with sub-trees evaluated and replaced.
- public static Expression PartialEval(Expression expression, IExpressionEvaluator evaluator)
- => PartialEval(expression, evaluator, canBeEvaluated);
-
- ///
- /// Performs evaluation & replacement of independent sub-trees in the body of .
- ///
- public static LambdaExpression PartialEval(LambdaExpression expression, IExpressionEvaluator evaluator)
- => (LambdaExpression)PartialEval(expression, evaluator, canBeEvaluated);
-
- ///
- /// Performs evaluation & replacement of independent sub-trees in the body
- /// of an typed .
- ///
- /// The lambda expression whichs body to
- /// partially evaluate.
- /// A new typed with sub-trees in the
- /// body evaluated and replaced.
- ///
- /// Call to is very important here.
- /// It allows expression to keep its original type, even if its body was replaced with call.
- ///
- public static Expression PartialEval(Expression expFunc, IExpressionEvaluator evaluator)
- => expFunc.Update(PartialEval(expFunc.Body, evaluator), expFunc.Parameters);
+ public static TExpression PartialEval(TExpression expression, IExpressionEvaluator evaluator) where TExpression : Expression {
+ return (TExpression)PartialEval(expression, evaluator, canBeEvaluated);
+ }
private static bool canBeEvaluated(Expression expression) {
if (expression.NodeType == ExpressionType.Parameter) {
@@ -72,22 +52,47 @@ internal Expression Eval(Expression exp) {
return this.Visit(exp);
}
+ private int depth = -1;
+
public override Expression Visit(Expression exp) {
if (exp == null) {
return null;
}
- if (candidates.Contains(exp)) {
+
+ // In case we visit lambda expression, we want to return lambda expression as a result of visit,
+ // so we don't want to replace root lambda node with constant expression node, so we don't do anything here.
+ if (candidates.Contains(exp) && !(depth == -1 && exp is LambdaExpression)) {
if (exp is ConstantExpression) {
return exp;
}
- try {
- return Expression.Constant(evaluator.Evaluate(exp), exp.Type);
- } catch (Exception exception) {
- return ExceptionClosure.MakeExceptionClosureCall(exception, exp.Type);
- }
+ return evaluate(exp);
+ }
+
+ depth++;
+ var newNode = base.Visit(exp);
+ depth--;
+ return newNode;
+ }
+
+ protected override Expression VisitLambda(Expression node) {
+ // This is root lambda node that we want to evaluate. Since we want to preserve type
+ // of input expression, we will update the body, but still keep the lambda.
+ if (candidates.Contains(node) && depth == 0) {
+ var constant = evaluate(node.Body);
+ return node.Update(constant, node.Parameters);
+ }
+
+ return base.VisitLambda(node);
+ }
+
+ private Expression evaluate(Expression exp) {
+ try {
+ return Expression.Constant(evaluator.Evaluate(exp), exp.Type);
+ }
+ catch (Exception exception) {
+ return ExceptionClosure.MakeExceptionClosureCall(exception, exp.Type);
}
- return base.Visit(exp);
}
}
diff --git a/README.md b/README.md
index 6110661..045b531 100644
--- a/README.md
+++ b/README.md
@@ -56,7 +56,7 @@ int modul = 0;
Expression> modulExpression = x => x % modul == 0;
modul = 2;
-var isMultipleOfTwoExpression = PartialEvaluator.PartialEvalBody(
+var isMultipleOfTwoExpression = PartialEvaluator.PartialEval(
modulExpression,
ExpressionInterpreter.Instance);
@@ -66,7 +66,7 @@ Console.Write(isEvenExpression.StructuralIdentical(modulExpression)); // false
Console.Write(isEvenExpression.StructuralIdentical(isMultipleOfTwoExpression)); // true
modul = 3;
-var isMultipleOfThreeExpression = PartialEvaluator.PartialEvalBody(
+var isMultipleOfThreeExpression = PartialEvaluator.PartialEval(
modulExpression,
ExpressionInterpreter.Instance);