diff --git a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstCreator.scala b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstCreator.scala index 1d0e54eb6e52..199116b715e9 100644 --- a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstCreator.scala +++ b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstCreator.scala @@ -240,23 +240,23 @@ class AstCreator( def astForExpressionContext(ctx: ExpressionContext): Seq[Ast] = ctx match { case ctx: PrimaryExpressionContext => astForPrimaryContext(ctx.primary()) - case ctx: UnaryExpressionContext => astForUnaryExpression(ctx) - case ctx: PowerExpressionContext => astForPowerExpression(ctx) - case ctx: UnaryMinusExpressionContext => astForUnaryMinusExpression(ctx) - case ctx: MultiplicativeExpressionContext => astForMultiplicativeExpression(ctx) - case ctx: AdditiveExpressionContext => astForAdditiveExpression(ctx) - case ctx: BitwiseShiftExpressionContext => astForBitwiseShiftExpression(ctx) - case ctx: BitwiseAndExpressionContext => astForBitwiseAndExpression(ctx) - case ctx: BitwiseOrExpressionContext => astForBitwiseOrExpression(ctx) - case ctx: RelationalExpressionContext => astForRelationalExpression(ctx) - case ctx: EqualityExpressionContext => astForEqualityExpression(ctx) - case ctx: OperatorAndExpressionContext => astForAndExpression(ctx) - case ctx: OperatorOrExpressionContext => astForOrExpression(ctx) + case ctx: UnaryExpressionContext => Seq(astForUnaryExpression(ctx)) + case ctx: PowerExpressionContext => Seq(astForPowerExpression(ctx)) + case ctx: UnaryMinusExpressionContext => Seq(astForUnaryMinusExpression(ctx)) + case ctx: MultiplicativeExpressionContext => Seq(astForMultiplicativeExpression(ctx)) + case ctx: AdditiveExpressionContext => Seq(astForAdditiveExpression(ctx)) + case ctx: BitwiseShiftExpressionContext => Seq(astForBitwiseShiftExpression(ctx)) + case ctx: BitwiseAndExpressionContext => Seq(astForBitwiseAndExpression(ctx)) + case ctx: BitwiseOrExpressionContext => Seq(astForBitwiseOrExpression(ctx)) + case ctx: RelationalExpressionContext => Seq(astForRelationalExpression(ctx)) + case ctx: EqualityExpressionContext => Seq(astForEqualityExpression(ctx)) + case ctx: OperatorAndExpressionContext => Seq(astForAndExpression(ctx)) + case ctx: OperatorOrExpressionContext => Seq(astForOrExpression(ctx)) case ctx: RangeExpressionContext => astForRangeExpressionContext(ctx) case ctx: ConditionalOperatorExpressionContext => Seq(astForTernaryConditionalOperator(ctx)) case ctx: SingleAssignmentExpressionContext => astForSingleAssignmentExpressionContext(ctx) case ctx: MultipleAssignmentExpressionContext => astForMultipleAssignmentExpressionContext(ctx) - case ctx: IsDefinedExpressionContext => astForIsDefinedExpression(ctx) + case ctx: IsDefinedExpressionContext => Seq(astForIsDefinedExpression(ctx)) case _ => logger.error(s"astForExpressionContext() $relativeFilename, ${text(ctx)} All contexts mismatched.") Seq(Ast()) diff --git a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForExpressionsCreator.scala b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForExpressionsCreator.scala index 62ac83d54958..51f9917d7f20 100644 --- a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForExpressionsCreator.scala +++ b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForExpressionsCreator.scala @@ -4,14 +4,7 @@ import io.joern.rubysrc2cpg.parser.RubyParser.* import io.joern.rubysrc2cpg.passes.Defines import io.joern.rubysrc2cpg.passes.Defines.getBuiltInType import io.joern.x2cpg.{Ast, ValidationMode} -import io.shiftleft.codepropertygraph.generated.nodes.{ - AstNodeNew, - NewCall, - NewIdentifier, - NewMethod, - NewType, - NewTypeDecl -} +import io.shiftleft.codepropertygraph.generated.nodes.{AstNodeNew, NewCall, NewIdentifier, NewMethod} import io.shiftleft.codepropertygraph.generated.{ControlStructureTypes, DispatchTypes, ModifierTypes, Operators} import org.antlr.v4.runtime.ParserRuleContext import org.slf4j.LoggerFactory @@ -24,45 +17,45 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) { private val logger = LoggerFactory.getLogger(this.getClass) protected var lastModifier: Option[String] = None - protected def astForPowerExpression(ctx: PowerExpressionContext): Seq[Ast] = + protected def astForPowerExpression(ctx: PowerExpressionContext): Ast = astForBinaryOperatorExpression(ctx, Operators.exponentiation, ctx.expression().asScala) - protected def astForOrExpression(ctx: OperatorOrExpressionContext): Seq[Ast] = + protected def astForOrExpression(ctx: OperatorOrExpressionContext): Ast = astForBinaryOperatorExpression(ctx, Operators.or, ctx.expression().asScala) - protected def astForAndExpression(ctx: OperatorAndExpressionContext): Seq[Ast] = + protected def astForAndExpression(ctx: OperatorAndExpressionContext): Ast = astForBinaryOperatorExpression(ctx, Operators.and, ctx.expression().asScala) - protected def astForUnaryExpression(ctx: UnaryExpressionContext): Seq[Ast] = ctx.op.getType match { + protected def astForUnaryExpression(ctx: UnaryExpressionContext): Ast = ctx.op.getType match { case TILDE => astForBinaryOperatorExpression(ctx, Operators.not, Seq(ctx.expression())) case PLUS => astForBinaryOperatorExpression(ctx, Operators.plus, Seq(ctx.expression())) case EMARK => astForBinaryOperatorExpression(ctx, Operators.not, Seq(ctx.expression())) } - protected def astForUnaryMinusExpression(ctx: UnaryMinusExpressionContext): Seq[Ast] = + protected def astForUnaryMinusExpression(ctx: UnaryMinusExpressionContext): Ast = astForBinaryOperatorExpression(ctx, Operators.minus, Seq(ctx.expression())) - protected def astForAdditiveExpression(ctx: AdditiveExpressionContext): Seq[Ast] = ctx.op.getType match { + protected def astForAdditiveExpression(ctx: AdditiveExpressionContext): Ast = ctx.op.getType match { case PLUS => astForBinaryOperatorExpression(ctx, Operators.addition, ctx.expression().asScala) case MINUS => astForBinaryOperatorExpression(ctx, Operators.subtraction, ctx.expression().asScala) } - protected def astForMultiplicativeExpression(ctx: MultiplicativeExpressionContext): Seq[Ast] = ctx.op.getType match { + protected def astForMultiplicativeExpression(ctx: MultiplicativeExpressionContext): Ast = ctx.op.getType match { case STAR => astForMultiplicativeStarExpression(ctx) case SLASH => astForMultiplicativeSlashExpression(ctx) case PERCENT => astForMultiplicativePercentExpression(ctx) } - protected def astForMultiplicativeStarExpression(ctx: MultiplicativeExpressionContext): Seq[Ast] = + protected def astForMultiplicativeStarExpression(ctx: MultiplicativeExpressionContext): Ast = astForBinaryOperatorExpression(ctx, Operators.multiplication, ctx.expression().asScala) - protected def astForMultiplicativeSlashExpression(ctx: MultiplicativeExpressionContext): Seq[Ast] = + protected def astForMultiplicativeSlashExpression(ctx: MultiplicativeExpressionContext): Ast = astForBinaryOperatorExpression(ctx, Operators.division, ctx.expression().asScala) - protected def astForMultiplicativePercentExpression(ctx: MultiplicativeExpressionContext): Seq[Ast] = + protected def astForMultiplicativePercentExpression(ctx: MultiplicativeExpressionContext): Ast = astForBinaryOperatorExpression(ctx, Operators.modulo, ctx.expression().asScala) - protected def astForEqualityExpression(ctx: EqualityExpressionContext): Seq[Ast] = ctx.op.getType match { + protected def astForEqualityExpression(ctx: EqualityExpressionContext): Ast = ctx.op.getType match { case LTEQGT => astForBinaryOperatorExpression(ctx, Operators.compare, ctx.expression().asScala) case EQ2 => astForBinaryOperatorExpression(ctx, Operators.equals, ctx.expression().asScala) case EQ3 => astForBinaryOperatorExpression(ctx, Operators.is, ctx.expression().asScala) @@ -71,22 +64,22 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) { case EMARKTILDE => astForBinaryOperatorExpression(ctx, RubyOperators.notPatternMatch, ctx.expression().asScala) } - protected def astForRelationalExpression(ctx: RelationalExpressionContext): Seq[Ast] = ctx.op.getType match { + protected def astForRelationalExpression(ctx: RelationalExpressionContext): Ast = ctx.op.getType match { case GT => astForBinaryOperatorExpression(ctx, Operators.greaterThan, ctx.expression().asScala) case GTEQ => astForBinaryOperatorExpression(ctx, Operators.greaterEqualsThan, ctx.expression().asScala) case LT => astForBinaryOperatorExpression(ctx, Operators.lessThan, ctx.expression().asScala) case LTEQ => astForBinaryOperatorExpression(ctx, Operators.lessEqualsThan, ctx.expression().asScala) } - protected def astForBitwiseOrExpression(ctx: BitwiseOrExpressionContext): Seq[Ast] = ctx.op.getType match { + protected def astForBitwiseOrExpression(ctx: BitwiseOrExpressionContext): Ast = ctx.op.getType match { case BAR => astForBinaryOperatorExpression(ctx, Operators.logicalOr, ctx.expression().asScala) case CARET => astForBinaryOperatorExpression(ctx, Operators.logicalOr, ctx.expression().asScala) } - protected def astForBitwiseAndExpression(ctx: BitwiseAndExpressionContext): Seq[Ast] = + protected def astForBitwiseAndExpression(ctx: BitwiseAndExpressionContext): Ast = astForBinaryOperatorExpression(ctx, Operators.logicalAnd, ctx.expression().asScala) - protected def astForBitwiseShiftExpression(ctx: BitwiseShiftExpressionContext): Seq[Ast] = ctx.op.getType match { + protected def astForBitwiseShiftExpression(ctx: BitwiseShiftExpressionContext): Ast = ctx.op.getType match { case LT2 => astForBinaryOperatorExpression(ctx, Operators.shiftLeft, ctx.expression().asScala) case GT2 => astForBinaryOperatorExpression(ctx, Operators.logicalShiftRight, ctx.expression().asScala) } @@ -95,19 +88,13 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) { ctx: ParserRuleContext, name: String, arguments: Iterable[ExpressionContext] - ): Seq[Ast] = { - val (argsAst, otherAst) = arguments - .flatMap(astForExpressionContext) - .partition(_.root match - case Some(_: NewMethod) => false - case Some(_: NewTypeDecl) => false - case _ => true - ) - val call = callNode(ctx, text(ctx), name, name, DispatchTypes.STATIC_DISPATCH) - otherAst.toSeq :+ callAst(call, argsAst.toList) + ): Ast = { + val argsAst = arguments.flatMap(astForExpressionContext) + val call = callNode(ctx, text(ctx), name, name, DispatchTypes.STATIC_DISPATCH) + callAst(call, argsAst.toList) } - protected def astForIsDefinedExpression(ctx: IsDefinedExpressionContext): Seq[Ast] = + protected def astForIsDefinedExpression(ctx: IsDefinedExpressionContext): Ast = astForBinaryOperatorExpression(ctx, RubyOperators.defined, Seq(ctx.expression())) // TODO: Maybe merge (in RubyParser.g4) isDefinedExpression with isDefinedPrimaryExpression? @@ -375,7 +362,7 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) { } def astForRangeExpressionContext(ctx: RangeExpressionContext): Seq[Ast] = - astForBinaryOperatorExpression(ctx, Operators.range, ctx.expression().asScala) + Seq(astForBinaryOperatorExpression(ctx, Operators.range, ctx.expression().asScala)) protected def astForSuperExpression(ctx: SuperExpressionPrimaryContext): Ast = { val argsAst = Option(ctx.argumentsWithParentheses()) match diff --git a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/passes/ast/DoBlockTest.scala b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/passes/ast/DoBlockTest.scala index de3cac24421f..4361d21718c0 100644 --- a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/passes/ast/DoBlockTest.scala +++ b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/passes/ast/DoBlockTest.scala @@ -31,23 +31,4 @@ class DoBlockTest extends RubyCode2CpgFixture { } } - "a do-block function used as a higher-order function" should { - val cpg = code("""class TransactionsController < ApplicationController - | def permitted_column_name(column_name) - | %w[trx_date description amount].find { |permitted| column_name == permitted } || 'trx_date' - | end - |end - | - |""".stripMargin) - - "create a do-block method named from the surrounding function" in { - val findMethod :: _ = cpg.method.name("find.*").l: @unchecked - findMethod.name should startWith("find") - findMethod.parameter.size shouldBe 1 - val permitParam :: _ = findMethod.parameter.l: @unchecked - permitParam.name shouldBe "permitted" - } - - } - }