diff --git a/drools-model/drools-model-codegen/src/main/java/org/drools/model/codegen/execmodel/generator/drlxparse/ArithmeticCoercedExpression.java b/drools-model/drools-model-codegen/src/main/java/org/drools/model/codegen/execmodel/generator/drlxparse/NumberAndStringArithmeticOperationCoercion.java similarity index 65% rename from drools-model/drools-model-codegen/src/main/java/org/drools/model/codegen/execmodel/generator/drlxparse/ArithmeticCoercedExpression.java rename to drools-model/drools-model-codegen/src/main/java/org/drools/model/codegen/execmodel/generator/drlxparse/NumberAndStringArithmeticOperationCoercion.java index dfd241615fe..8af6d878134 100644 --- a/drools-model/drools-model-codegen/src/main/java/org/drools/model/codegen/execmodel/generator/drlxparse/ArithmeticCoercedExpression.java +++ b/drools-model/drools-model-codegen/src/main/java/org/drools/model/codegen/execmodel/generator/drlxparse/NumberAndStringArithmeticOperationCoercion.java @@ -19,6 +19,7 @@ import java.math.BigDecimal; import java.util.HashSet; +import java.util.Optional; import java.util.Set; import com.github.javaparser.ast.expr.BinaryExpr.Operator; @@ -27,6 +28,7 @@ import com.github.javaparser.ast.expr.NameExpr; import org.drools.model.codegen.execmodel.errors.InvalidExpressionErrorResult; import org.drools.model.codegen.execmodel.generator.TypedExpression; +import org.drools.util.Pair; import static com.github.javaparser.ast.NodeList.nodeList; import static com.github.javaparser.ast.expr.BinaryExpr.Operator.DIVIDE; @@ -37,35 +39,44 @@ import static java.util.Arrays.asList; import static org.drools.util.ClassUtils.isNumericClass; -public class ArithmeticCoercedExpression { +public final class NumberAndStringArithmeticOperationCoercion { - private final TypedExpression left; - private final TypedExpression right; - private final Operator operator; + private NumberAndStringArithmeticOperationCoercion() { + } private static final Set arithmeticOperators = new HashSet<>(asList(PLUS, MINUS, MULTIPLY, DIVIDE, REMAINDER)); - public ArithmeticCoercedExpression(TypedExpression left, TypedExpression right, Operator operator) { - this.left = left; - this.right = right; - this.operator = operator; + public static Pair coerceIfNeeded(final Operator operator, final TypedExpression left, final TypedExpression right) { + if (requiresCoercion(operator, left, right)) { + return coerce(operator, left, right); + } else { + return new Pair<>(null, null); + } + } + + public static boolean requiresCoercion(final Operator operator, final TypedExpression left, final TypedExpression right) { + if (!arithmeticOperators.contains(operator)) { + return false; + } + return canCoerce(left.getRawClass(), right.getRawClass()); + } + + private static boolean canCoerce(Class leftClass, Class rightClass) { + return leftClass == String.class && isNumericClass(rightClass) || + rightClass == String.class && isNumericClass(leftClass); } /* * This coercion only deals with String vs Numeric types. * BigDecimal arithmetic operation is handled by ExpressionTyper.convertArithmeticBinaryToMethodCall() */ - public ArithmeticCoercedExpressionResult coerce() { - - if (!requiresCoercion()) { - return new ArithmeticCoercedExpressionResult(left, right); // do not coerce - } + private static Pair coerce(final Operator operator, final TypedExpression left, final TypedExpression right) { final Class leftClass = left.getRawClass(); final Class rightClass = right.getRawClass(); if (!canCoerce(leftClass, rightClass)) { - throw new ArithmeticCoercedExpressionException(new InvalidExpressionErrorResult("Arithmetic operation requires compatible types. Found " + leftClass + " and " + rightClass)); + throw new NumberAndStringArithmeticOperationCoercionException(new InvalidExpressionErrorResult("Arithmetic operation requires compatible types. Found " + leftClass + " and " + rightClass)); } TypedExpression coercedLeft = left; @@ -89,65 +100,26 @@ public ArithmeticCoercedExpressionResult coerce() { } } - return new ArithmeticCoercedExpressionResult(coercedLeft, coercedRight); + return new Pair<>(coercedLeft, coercedRight); } - private boolean requiresCoercion() { - if (!arithmeticOperators.contains(operator)) { - return false; - } - final Class leftClass = left.getRawClass(); - final Class rightClass = right.getRawClass(); - if (leftClass == rightClass) { - return false; - } - if (isNumericClass(leftClass) && isNumericClass(rightClass)) { - return false; - } - return true; - } - - private boolean canCoerce(Class leftClass, Class rightClass) { - return leftClass == String.class && isNumericClass(rightClass) || - rightClass == String.class && isNumericClass(leftClass); - } - - private TypedExpression coerceToDouble(TypedExpression typedExpression) { + private static TypedExpression coerceToDouble(TypedExpression typedExpression) { final Expression expression = typedExpression.getExpression(); TypedExpression coercedExpression = typedExpression.cloneWithNewExpression(new MethodCallExpr(new NameExpr("Double"), "valueOf", nodeList(expression))); return coercedExpression.setType(BigDecimal.class); } - private TypedExpression coerceToString(TypedExpression typedExpression) { + private static TypedExpression coerceToString(TypedExpression typedExpression) { final Expression expression = typedExpression.getExpression(); TypedExpression coercedExpression = typedExpression.cloneWithNewExpression(new MethodCallExpr(new NameExpr("String"), "valueOf", nodeList(expression))); return coercedExpression.setType(String.class); } - public static class ArithmeticCoercedExpressionResult { - - private final TypedExpression coercedLeft; - private final TypedExpression coercedRight; - - public ArithmeticCoercedExpressionResult(TypedExpression left, TypedExpression coercedRight) { - this.coercedLeft = left; - this.coercedRight = coercedRight; - } - - public TypedExpression getCoercedLeft() { - return coercedLeft; - } - - public TypedExpression getCoercedRight() { - return coercedRight; - } - } - - public static class ArithmeticCoercedExpressionException extends RuntimeException { + public static class NumberAndStringArithmeticOperationCoercionException extends RuntimeException { private final transient InvalidExpressionErrorResult invalidExpressionErrorResult; - ArithmeticCoercedExpressionException(InvalidExpressionErrorResult invalidExpressionErrorResult) { + NumberAndStringArithmeticOperationCoercionException(InvalidExpressionErrorResult invalidExpressionErrorResult) { this.invalidExpressionErrorResult = invalidExpressionErrorResult; } diff --git a/drools-model/drools-model-codegen/src/main/java/org/drools/model/codegen/execmodel/generator/expressiontyper/ExpressionTyper.java b/drools-model/drools-model-codegen/src/main/java/org/drools/model/codegen/execmodel/generator/expressiontyper/ExpressionTyper.java index ce4e8f91331..0a951354ca9 100644 --- a/drools-model/drools-model-codegen/src/main/java/org/drools/model/codegen/execmodel/generator/expressiontyper/ExpressionTyper.java +++ b/drools-model/drools-model-codegen/src/main/java/org/drools/model/codegen/execmodel/generator/expressiontyper/ExpressionTyper.java @@ -71,7 +71,7 @@ import org.drools.model.codegen.execmodel.generator.RuleContext; import org.drools.model.codegen.execmodel.generator.TypedExpression; import org.drools.model.codegen.execmodel.generator.UnificationTypedExpression; -import org.drools.model.codegen.execmodel.generator.drlxparse.ArithmeticCoercedExpression; +import org.drools.model.codegen.execmodel.generator.drlxparse.NumberAndStringArithmeticOperationCoercion; import org.drools.model.codegen.execmodel.generator.operatorspec.CustomOperatorSpec; import org.drools.model.codegen.execmodel.generator.operatorspec.NativeOperatorSpec; import org.drools.model.codegen.execmodel.generator.operatorspec.OperatorSpec; @@ -95,6 +95,7 @@ import org.drools.mvelcompiler.ConstraintCompiler; import org.drools.mvelcompiler.util.BigDecimalArgumentCoercion; import org.drools.util.MethodUtils; +import org.drools.util.Pair; import org.drools.util.TypeResolver; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -223,18 +224,16 @@ private Optional toTypedExpressionRec(Expression drlxExpr) { TypedExpression left = optLeft.get(); TypedExpression right = optRight.get(); - ArithmeticCoercedExpression.ArithmeticCoercedExpressionResult coerced; - try { - coerced = new ArithmeticCoercedExpression(left, right, operator).coerce(); - } catch (ArithmeticCoercedExpression.ArithmeticCoercedExpressionException e) { - logger.error("Failed to coerce : {}", e.getInvalidExpressionErrorResult()); - return empty(); + final BinaryExpr combo; + final Pair numberAndStringCoercionResult = + NumberAndStringArithmeticOperationCoercion.coerceIfNeeded(operator, left, right); + if (numberAndStringCoercionResult.hasLeft()) { + left = numberAndStringCoercionResult.getLeft(); } - - left = coerced.getCoercedLeft(); - right = coerced.getCoercedRight(); - - final BinaryExpr combo = new BinaryExpr(left.getExpression(), right.getExpression(), operator); + if (numberAndStringCoercionResult.hasRight()) { + right = numberAndStringCoercionResult.getRight(); + } + combo = new BinaryExpr(left.getExpression(), right.getExpression(), operator); if (shouldConvertArithmeticBinaryToMethodCall(operator, left.getType(), right.getType())) { Expression expression = convertArithmeticBinaryToMethodCall(combo, of(typeCursor), ruleContext); diff --git a/drools-model/drools-model-codegen/src/test/java/org/drools/model/codegen/execmodel/ArithmeticCoecionTest.java b/drools-model/drools-model-codegen/src/test/java/org/drools/model/codegen/execmodel/NumberAndStringArithmeticOperationCoercionTest.java similarity index 97% rename from drools-model/drools-model-codegen/src/test/java/org/drools/model/codegen/execmodel/ArithmeticCoecionTest.java rename to drools-model/drools-model-codegen/src/test/java/org/drools/model/codegen/execmodel/NumberAndStringArithmeticOperationCoercionTest.java index 8123d6ec26c..e7f11dbb95a 100644 --- a/drools-model/drools-model-codegen/src/test/java/org/drools/model/codegen/execmodel/ArithmeticCoecionTest.java +++ b/drools-model/drools-model-codegen/src/test/java/org/drools/model/codegen/execmodel/NumberAndStringArithmeticOperationCoercionTest.java @@ -24,9 +24,9 @@ import static org.assertj.core.api.Assertions.assertThat; -public class ArithmeticCoecionTest extends BaseModelTest { +public class NumberAndStringArithmeticOperationCoercionTest extends BaseModelTest { - public ArithmeticCoecionTest(RUN_TYPE testRunType) { + public NumberAndStringArithmeticOperationCoercionTest(RUN_TYPE testRunType) { super(testRunType); } diff --git a/drools-model/drools-model-codegen/src/test/java/org/drools/model/codegen/execmodel/TypeCoercionTest.java b/drools-model/drools-model-codegen/src/test/java/org/drools/model/codegen/execmodel/TypeCoercionTest.java index 1e9921709b4..c69e249cbf0 100644 --- a/drools-model/drools-model-codegen/src/test/java/org/drools/model/codegen/execmodel/TypeCoercionTest.java +++ b/drools-model/drools-model-codegen/src/test/java/org/drools/model/codegen/execmodel/TypeCoercionTest.java @@ -610,4 +610,30 @@ public void testFloatOperation() { assertThat(list.size()).isEqualTo(1); assertThat(list.get(0)).isEqualTo("Mario"); } + + @Test + public void testCoerceObjectToString() { + String str = "package constraintexpression\n" + + "\n" + + "import " + Person.class.getCanonicalName() + "\n" + + "import java.util.List; \n" + + "rule \"r1\"\n" + + "when \n" + + " $p: Person() \n" + + " String(this == \"someString\" + $p)\n" + + "then \n" + + " System.out.println($p); \n" + + "end \n"; + + KieSession ksession = getKieSession(str); + try { + Person person = new Person("someName"); + ksession.insert(person); + ksession.insert(new String("someStringsomeName")); + int rulesFired = ksession.fireAllRules(); + assertThat(rulesFired).isEqualTo(1); + } finally { + ksession.dispose(); + } + } } \ No newline at end of file diff --git a/drools-util/src/main/java/org/drools/util/Pair.java b/drools-util/src/main/java/org/drools/util/Pair.java new file mode 100644 index 00000000000..4408263db50 --- /dev/null +++ b/drools-util/src/main/java/org/drools/util/Pair.java @@ -0,0 +1,28 @@ +package org.drools.util; + +public class Pair { + + private final K left; + private final V right; + + public Pair(K k, V v) { + this.left = k; + this.right = v; + } + + public K getLeft() { + return left; + } + + public V getRight() { + return right; + } + + public boolean hasLeft() { + return left != null; + } + + public boolean hasRight() { + return right != null; + } +}