Skip to content

Commit

Permalink
[kie-issues#986] Coerce object to String in executable model codegen (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
baldimir authored Mar 7, 2024
1 parent 342dcb2 commit f4ef82c
Show file tree
Hide file tree
Showing 5 changed files with 97 additions and 72 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import java.math.BigDecimal;
import java.util.HashSet;
import java.util.Optional;
import java.util.Set;

import com.github.javaparser.ast.expr.BinaryExpr.Operator;
Expand All @@ -27,6 +28,7 @@
import com.github.javaparser.ast.expr.NameExpr;
import org.drools.model.codegen.execmodel.errors.InvalidExpressionErrorResult;
import org.drools.model.codegen.execmodel.generator.TypedExpression;
import org.drools.util.Pair;

import static com.github.javaparser.ast.NodeList.nodeList;
import static com.github.javaparser.ast.expr.BinaryExpr.Operator.DIVIDE;
Expand All @@ -37,35 +39,44 @@
import static java.util.Arrays.asList;
import static org.drools.util.ClassUtils.isNumericClass;

public class ArithmeticCoercedExpression {
public final class NumberAndStringArithmeticOperationCoercion {

private final TypedExpression left;
private final TypedExpression right;
private final Operator operator;
private NumberAndStringArithmeticOperationCoercion() {
}

private static final Set<Operator> arithmeticOperators = new HashSet<>(asList(PLUS, MINUS, MULTIPLY, DIVIDE, REMAINDER));

public ArithmeticCoercedExpression(TypedExpression left, TypedExpression right, Operator operator) {
this.left = left;
this.right = right;
this.operator = operator;
public static Pair<TypedExpression, TypedExpression> coerceIfNeeded(final Operator operator, final TypedExpression left, final TypedExpression right) {
if (requiresCoercion(operator, left, right)) {
return coerce(operator, left, right);
} else {
return new Pair<>(null, null);
}
}

public static boolean requiresCoercion(final Operator operator, final TypedExpression left, final TypedExpression right) {
if (!arithmeticOperators.contains(operator)) {
return false;
}
return canCoerce(left.getRawClass(), right.getRawClass());
}

private static boolean canCoerce(Class<?> leftClass, Class<?> rightClass) {
return leftClass == String.class && isNumericClass(rightClass) ||
rightClass == String.class && isNumericClass(leftClass);
}

/*
* This coercion only deals with String vs Numeric types.
* BigDecimal arithmetic operation is handled by ExpressionTyper.convertArithmeticBinaryToMethodCall()
*/
public ArithmeticCoercedExpressionResult coerce() {

if (!requiresCoercion()) {
return new ArithmeticCoercedExpressionResult(left, right); // do not coerce
}
private static Pair<TypedExpression, TypedExpression> coerce(final Operator operator, final TypedExpression left, final TypedExpression right) {

final Class<?> leftClass = left.getRawClass();
final Class<?> rightClass = right.getRawClass();

if (!canCoerce(leftClass, rightClass)) {
throw new ArithmeticCoercedExpressionException(new InvalidExpressionErrorResult("Arithmetic operation requires compatible types. Found " + leftClass + " and " + rightClass));
throw new NumberAndStringArithmeticOperationCoercionException(new InvalidExpressionErrorResult("Arithmetic operation requires compatible types. Found " + leftClass + " and " + rightClass));
}

TypedExpression coercedLeft = left;
Expand All @@ -89,65 +100,26 @@ public ArithmeticCoercedExpressionResult coerce() {
}
}

return new ArithmeticCoercedExpressionResult(coercedLeft, coercedRight);
return new Pair<>(coercedLeft, coercedRight);
}

private boolean requiresCoercion() {
if (!arithmeticOperators.contains(operator)) {
return false;
}
final Class<?> leftClass = left.getRawClass();
final Class<?> rightClass = right.getRawClass();
if (leftClass == rightClass) {
return false;
}
if (isNumericClass(leftClass) && isNumericClass(rightClass)) {
return false;
}
return true;
}

private boolean canCoerce(Class<?> leftClass, Class<?> rightClass) {
return leftClass == String.class && isNumericClass(rightClass) ||
rightClass == String.class && isNumericClass(leftClass);
}

private TypedExpression coerceToDouble(TypedExpression typedExpression) {
private static TypedExpression coerceToDouble(TypedExpression typedExpression) {
final Expression expression = typedExpression.getExpression();
TypedExpression coercedExpression = typedExpression.cloneWithNewExpression(new MethodCallExpr(new NameExpr("Double"), "valueOf", nodeList(expression)));
return coercedExpression.setType(BigDecimal.class);
}

private TypedExpression coerceToString(TypedExpression typedExpression) {
private static TypedExpression coerceToString(TypedExpression typedExpression) {
final Expression expression = typedExpression.getExpression();
TypedExpression coercedExpression = typedExpression.cloneWithNewExpression(new MethodCallExpr(new NameExpr("String"), "valueOf", nodeList(expression)));
return coercedExpression.setType(String.class);
}

public static class ArithmeticCoercedExpressionResult {

private final TypedExpression coercedLeft;
private final TypedExpression coercedRight;

public ArithmeticCoercedExpressionResult(TypedExpression left, TypedExpression coercedRight) {
this.coercedLeft = left;
this.coercedRight = coercedRight;
}

public TypedExpression getCoercedLeft() {
return coercedLeft;
}

public TypedExpression getCoercedRight() {
return coercedRight;
}
}

public static class ArithmeticCoercedExpressionException extends RuntimeException {
public static class NumberAndStringArithmeticOperationCoercionException extends RuntimeException {

private final transient InvalidExpressionErrorResult invalidExpressionErrorResult;

ArithmeticCoercedExpressionException(InvalidExpressionErrorResult invalidExpressionErrorResult) {
NumberAndStringArithmeticOperationCoercionException(InvalidExpressionErrorResult invalidExpressionErrorResult) {
this.invalidExpressionErrorResult = invalidExpressionErrorResult;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@
import org.drools.model.codegen.execmodel.generator.RuleContext;
import org.drools.model.codegen.execmodel.generator.TypedExpression;
import org.drools.model.codegen.execmodel.generator.UnificationTypedExpression;
import org.drools.model.codegen.execmodel.generator.drlxparse.ArithmeticCoercedExpression;
import org.drools.model.codegen.execmodel.generator.drlxparse.NumberAndStringArithmeticOperationCoercion;
import org.drools.model.codegen.execmodel.generator.operatorspec.CustomOperatorSpec;
import org.drools.model.codegen.execmodel.generator.operatorspec.NativeOperatorSpec;
import org.drools.model.codegen.execmodel.generator.operatorspec.OperatorSpec;
Expand All @@ -95,6 +95,7 @@
import org.drools.mvelcompiler.ConstraintCompiler;
import org.drools.mvelcompiler.util.BigDecimalArgumentCoercion;
import org.drools.util.MethodUtils;
import org.drools.util.Pair;
import org.drools.util.TypeResolver;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
Expand Down Expand Up @@ -223,18 +224,16 @@ private Optional<TypedExpression> toTypedExpressionRec(Expression drlxExpr) {
TypedExpression left = optLeft.get();
TypedExpression right = optRight.get();

ArithmeticCoercedExpression.ArithmeticCoercedExpressionResult coerced;
try {
coerced = new ArithmeticCoercedExpression(left, right, operator).coerce();
} catch (ArithmeticCoercedExpression.ArithmeticCoercedExpressionException e) {
logger.error("Failed to coerce : {}", e.getInvalidExpressionErrorResult());
return empty();
final BinaryExpr combo;
final Pair<TypedExpression, TypedExpression> numberAndStringCoercionResult =
NumberAndStringArithmeticOperationCoercion.coerceIfNeeded(operator, left, right);
if (numberAndStringCoercionResult.hasLeft()) {
left = numberAndStringCoercionResult.getLeft();
}

left = coerced.getCoercedLeft();
right = coerced.getCoercedRight();

final BinaryExpr combo = new BinaryExpr(left.getExpression(), right.getExpression(), operator);
if (numberAndStringCoercionResult.hasRight()) {
right = numberAndStringCoercionResult.getRight();
}
combo = new BinaryExpr(left.getExpression(), right.getExpression(), operator);

if (shouldConvertArithmeticBinaryToMethodCall(operator, left.getType(), right.getType())) {
Expression expression = convertArithmeticBinaryToMethodCall(combo, of(typeCursor), ruleContext);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,9 @@

import static org.assertj.core.api.Assertions.assertThat;

public class ArithmeticCoecionTest extends BaseModelTest {
public class NumberAndStringArithmeticOperationCoercionTest extends BaseModelTest {

public ArithmeticCoecionTest(RUN_TYPE testRunType) {
public NumberAndStringArithmeticOperationCoercionTest(RUN_TYPE testRunType) {
super(testRunType);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -610,4 +610,30 @@ public void testFloatOperation() {
assertThat(list.size()).isEqualTo(1);
assertThat(list.get(0)).isEqualTo("Mario");
}

@Test
public void testCoerceObjectToString() {
String str = "package constraintexpression\n" +
"\n" +
"import " + Person.class.getCanonicalName() + "\n" +
"import java.util.List; \n" +
"rule \"r1\"\n" +
"when \n" +
" $p: Person() \n" +
" String(this == \"someString\" + $p)\n" +
"then \n" +
" System.out.println($p); \n" +
"end \n";

KieSession ksession = getKieSession(str);
try {
Person person = new Person("someName");
ksession.insert(person);
ksession.insert(new String("someStringsomeName"));
int rulesFired = ksession.fireAllRules();
assertThat(rulesFired).isEqualTo(1);
} finally {
ksession.dispose();
}
}
}
28 changes: 28 additions & 0 deletions drools-util/src/main/java/org/drools/util/Pair.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
package org.drools.util;

public class Pair<K, V> {

private final K left;
private final V right;

public Pair(K k, V v) {
this.left = k;
this.right = v;
}

public K getLeft() {
return left;
}

public V getRight() {
return right;
}

public boolean hasLeft() {
return left != null;
}

public boolean hasRight() {
return right != null;
}
}

0 comments on commit f4ef82c

Please sign in to comment.