Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[kie-issues#986] Coerce object to String in executable model codegen #5769

Merged
merged 1 commit into from
Mar 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import java.math.BigDecimal;
import java.util.HashSet;
import java.util.Optional;
import java.util.Set;

import com.github.javaparser.ast.expr.BinaryExpr.Operator;
Expand All @@ -27,6 +28,7 @@
import com.github.javaparser.ast.expr.NameExpr;
import org.drools.model.codegen.execmodel.errors.InvalidExpressionErrorResult;
import org.drools.model.codegen.execmodel.generator.TypedExpression;
import org.drools.util.Pair;

import static com.github.javaparser.ast.NodeList.nodeList;
import static com.github.javaparser.ast.expr.BinaryExpr.Operator.DIVIDE;
Expand All @@ -37,35 +39,44 @@
import static java.util.Arrays.asList;
import static org.drools.util.ClassUtils.isNumericClass;

public class ArithmeticCoercedExpression {
public final class NumberAndStringArithmeticOperationCoercion {

private final TypedExpression left;
private final TypedExpression right;
private final Operator operator;
private NumberAndStringArithmeticOperationCoercion() {
}

private static final Set<Operator> arithmeticOperators = new HashSet<>(asList(PLUS, MINUS, MULTIPLY, DIVIDE, REMAINDER));

public ArithmeticCoercedExpression(TypedExpression left, TypedExpression right, Operator operator) {
this.left = left;
this.right = right;
this.operator = operator;
public static Pair<TypedExpression, TypedExpression> coerceIfNeeded(final Operator operator, final TypedExpression left, final TypedExpression right) {
if (requiresCoercion(operator, left, right)) {
return coerce(operator, left, right);
} else {
return new Pair<>(null, null);
}
}

public static boolean requiresCoercion(final Operator operator, final TypedExpression left, final TypedExpression right) {
if (!arithmeticOperators.contains(operator)) {
return false;
}
return canCoerce(left.getRawClass(), right.getRawClass());
}

private static boolean canCoerce(Class<?> leftClass, Class<?> rightClass) {
return leftClass == String.class && isNumericClass(rightClass) ||
rightClass == String.class && isNumericClass(leftClass);
}

/*
* This coercion only deals with String vs Numeric types.
* BigDecimal arithmetic operation is handled by ExpressionTyper.convertArithmeticBinaryToMethodCall()
*/
public ArithmeticCoercedExpressionResult coerce() {

if (!requiresCoercion()) {
return new ArithmeticCoercedExpressionResult(left, right); // do not coerce
}
private static Pair<TypedExpression, TypedExpression> coerce(final Operator operator, final TypedExpression left, final TypedExpression right) {

final Class<?> leftClass = left.getRawClass();
final Class<?> rightClass = right.getRawClass();

if (!canCoerce(leftClass, rightClass)) {
throw new ArithmeticCoercedExpressionException(new InvalidExpressionErrorResult("Arithmetic operation requires compatible types. Found " + leftClass + " and " + rightClass));
throw new NumberAndStringArithmeticOperationCoercionException(new InvalidExpressionErrorResult("Arithmetic operation requires compatible types. Found " + leftClass + " and " + rightClass));
}

TypedExpression coercedLeft = left;
Expand All @@ -89,65 +100,26 @@ public ArithmeticCoercedExpressionResult coerce() {
}
}

return new ArithmeticCoercedExpressionResult(coercedLeft, coercedRight);
return new Pair<>(coercedLeft, coercedRight);
}

private boolean requiresCoercion() {
if (!arithmeticOperators.contains(operator)) {
return false;
}
final Class<?> leftClass = left.getRawClass();
final Class<?> rightClass = right.getRawClass();
if (leftClass == rightClass) {
return false;
}
if (isNumericClass(leftClass) && isNumericClass(rightClass)) {
return false;
}
return true;
}

private boolean canCoerce(Class<?> leftClass, Class<?> rightClass) {
return leftClass == String.class && isNumericClass(rightClass) ||
rightClass == String.class && isNumericClass(leftClass);
}

private TypedExpression coerceToDouble(TypedExpression typedExpression) {
private static TypedExpression coerceToDouble(TypedExpression typedExpression) {
final Expression expression = typedExpression.getExpression();
TypedExpression coercedExpression = typedExpression.cloneWithNewExpression(new MethodCallExpr(new NameExpr("Double"), "valueOf", nodeList(expression)));
return coercedExpression.setType(BigDecimal.class);
}

private TypedExpression coerceToString(TypedExpression typedExpression) {
private static TypedExpression coerceToString(TypedExpression typedExpression) {
final Expression expression = typedExpression.getExpression();
TypedExpression coercedExpression = typedExpression.cloneWithNewExpression(new MethodCallExpr(new NameExpr("String"), "valueOf", nodeList(expression)));
return coercedExpression.setType(String.class);
}

public static class ArithmeticCoercedExpressionResult {

private final TypedExpression coercedLeft;
private final TypedExpression coercedRight;

public ArithmeticCoercedExpressionResult(TypedExpression left, TypedExpression coercedRight) {
this.coercedLeft = left;
this.coercedRight = coercedRight;
}

public TypedExpression getCoercedLeft() {
return coercedLeft;
}

public TypedExpression getCoercedRight() {
return coercedRight;
}
}

public static class ArithmeticCoercedExpressionException extends RuntimeException {
public static class NumberAndStringArithmeticOperationCoercionException extends RuntimeException {

private final transient InvalidExpressionErrorResult invalidExpressionErrorResult;

ArithmeticCoercedExpressionException(InvalidExpressionErrorResult invalidExpressionErrorResult) {
NumberAndStringArithmeticOperationCoercionException(InvalidExpressionErrorResult invalidExpressionErrorResult) {
this.invalidExpressionErrorResult = invalidExpressionErrorResult;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@
import org.drools.model.codegen.execmodel.generator.RuleContext;
import org.drools.model.codegen.execmodel.generator.TypedExpression;
import org.drools.model.codegen.execmodel.generator.UnificationTypedExpression;
import org.drools.model.codegen.execmodel.generator.drlxparse.ArithmeticCoercedExpression;
import org.drools.model.codegen.execmodel.generator.drlxparse.NumberAndStringArithmeticOperationCoercion;
import org.drools.model.codegen.execmodel.generator.operatorspec.CustomOperatorSpec;
import org.drools.model.codegen.execmodel.generator.operatorspec.NativeOperatorSpec;
import org.drools.model.codegen.execmodel.generator.operatorspec.OperatorSpec;
Expand All @@ -95,6 +95,7 @@
import org.drools.mvelcompiler.ConstraintCompiler;
import org.drools.mvelcompiler.util.BigDecimalArgumentCoercion;
import org.drools.util.MethodUtils;
import org.drools.util.Pair;
import org.drools.util.TypeResolver;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
Expand Down Expand Up @@ -223,18 +224,16 @@ private Optional<TypedExpression> toTypedExpressionRec(Expression drlxExpr) {
TypedExpression left = optLeft.get();
TypedExpression right = optRight.get();

ArithmeticCoercedExpression.ArithmeticCoercedExpressionResult coerced;
try {
coerced = new ArithmeticCoercedExpression(left, right, operator).coerce();
} catch (ArithmeticCoercedExpression.ArithmeticCoercedExpressionException e) {
logger.error("Failed to coerce : {}", e.getInvalidExpressionErrorResult());
return empty();
final BinaryExpr combo;
final Pair<TypedExpression, TypedExpression> numberAndStringCoercionResult =
NumberAndStringArithmeticOperationCoercion.coerceIfNeeded(operator, left, right);
if (numberAndStringCoercionResult.hasLeft()) {
left = numberAndStringCoercionResult.getLeft();
}

left = coerced.getCoercedLeft();
right = coerced.getCoercedRight();

final BinaryExpr combo = new BinaryExpr(left.getExpression(), right.getExpression(), operator);
if (numberAndStringCoercionResult.hasRight()) {
right = numberAndStringCoercionResult.getRight();
}
combo = new BinaryExpr(left.getExpression(), right.getExpression(), operator);

if (shouldConvertArithmeticBinaryToMethodCall(operator, left.getType(), right.getType())) {
Expression expression = convertArithmeticBinaryToMethodCall(combo, of(typeCursor), ruleContext);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,9 @@

import static org.assertj.core.api.Assertions.assertThat;

public class ArithmeticCoecionTest extends BaseModelTest {
public class NumberAndStringArithmeticOperationCoercionTest extends BaseModelTest {

public ArithmeticCoecionTest(RUN_TYPE testRunType) {
public NumberAndStringArithmeticOperationCoercionTest(RUN_TYPE testRunType) {
super(testRunType);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -610,4 +610,30 @@ public void testFloatOperation() {
assertThat(list.size()).isEqualTo(1);
assertThat(list.get(0)).isEqualTo("Mario");
}

@Test
public void testCoerceObjectToString() {
String str = "package constraintexpression\n" +
"\n" +
"import " + Person.class.getCanonicalName() + "\n" +
"import java.util.List; \n" +
"rule \"r1\"\n" +
"when \n" +
" $p: Person() \n" +
" String(this == \"someString\" + $p)\n" +
"then \n" +
" System.out.println($p); \n" +
"end \n";

KieSession ksession = getKieSession(str);
try {
Person person = new Person("someName");
ksession.insert(person);
ksession.insert(new String("someStringsomeName"));
int rulesFired = ksession.fireAllRules();
assertThat(rulesFired).isEqualTo(1);
} finally {
ksession.dispose();
}
}
}
28 changes: 28 additions & 0 deletions drools-util/src/main/java/org/drools/util/Pair.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
package org.drools.util;

public class Pair<K, V> {

private final K left;
private final V right;

public Pair(K k, V v) {
this.left = k;
this.right = v;
}

public K getLeft() {
return left;
}

public V getRight() {
return right;
}

public boolean hasLeft() {
return left != null;
}

public boolean hasRight() {
return right != null;
}
}