diff --git a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstCreator.scala b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstCreator.scala index 4f1a94d46e0f..e35f05e8730a 100644 --- a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstCreator.scala +++ b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstCreator.scala @@ -200,12 +200,12 @@ class AstCreator( case ctx: ProcDefinitionPrimaryContext => astForProcDefinitionContext(ctx.procDefinition()) case ctx: YieldWithOptionalArgumentPrimaryContext => astForYieldCall(ctx, Option(ctx.yieldWithOptionalArgument().arguments())) - case ctx: IfExpressionPrimaryContext => Seq(astForIfExpression(ctx.ifExpression())) - case ctx: UnlessExpressionPrimaryContext => Seq(astForUnlessExpression(ctx.unlessExpression())) + case ctx: IfExpressionPrimaryContext => astForIfExpression(ctx.ifExpression()) + case ctx: UnlessExpressionPrimaryContext => astForUnlessExpression(ctx.unlessExpression()) case ctx: CaseExpressionPrimaryContext => astForCaseExpressionPrimaryContext(ctx) - case ctx: WhileExpressionPrimaryContext => Seq(astForWhileExpression(ctx.whileExpression())) - case ctx: UntilExpressionPrimaryContext => Seq(astForUntilExpression(ctx.untilExpression())) - case ctx: ForExpressionPrimaryContext => Seq(astForForExpression(ctx.forExpression())) + case ctx: WhileExpressionPrimaryContext => astForWhileExpression(ctx.whileExpression()) + case ctx: UntilExpressionPrimaryContext => astForUntilExpression(ctx.untilExpression()) + case ctx: ForExpressionPrimaryContext => astForForExpression(ctx.forExpression()) case ctx: ReturnWithParenthesesPrimaryContext => Seq(returnAst(returnNode(ctx, text(ctx)), astForArgumentsWithParenthesesContext(ctx.argumentsWithParentheses()))) case ctx: JumpExpressionPrimaryContext => astForJumpExpressionPrimaryContext(ctx) diff --git a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstCreatorHelper.scala b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstCreatorHelper.scala index e3da27b4c80d..231bb6d6ad4b 100644 --- a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstCreatorHelper.scala +++ b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstCreatorHelper.scala @@ -223,6 +223,21 @@ trait AstCreatorHelper(implicit withSchemaValidation: ValidationMode) { this: As (as.toSeq, bs.toSeq) } + /** Partitions a sequence of Ast objects into the boilerplate for do-block functions and the call node at the end. + * + * @return + * a tuple where the first element is the closure boilerplate and the latter is the last expression. + */ + def partitionClosureFromExpr: (Seq[Ast], Option[Ast]) = { + val (as, bs) = a.partition(_.root match + case Some(_: NewMethod) => true + case Some(_: NewTypeDecl) => true + case Some(x: NewCall) if x.name.startsWith(Operators.assignment) => true + case _ => false + ) + (as.toSeq, bs.lastOption) + } + } } diff --git a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForControlStructuresCreator.scala b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForControlStructuresCreator.scala index 717235dc7fa1..84d9ed6baf2b 100644 --- a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForControlStructuresCreator.scala +++ b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForControlStructuresCreator.scala @@ -132,4 +132,55 @@ trait AstForControlStructuresCreator(implicit withSchemaValidation: ValidationMo tryCatchAst(tryNode, tryBodyAst, catchAsts, finallyAst) } + protected def astForUntilExpression(ctx: UntilExpressionContext): Seq[Ast] = { + val (boilerplate, exprAst) = astForExpressionOrCommand(ctx.expressionOrCommand()).partitionClosureFromExpr + val bodyAst = astForCompoundStatement(ctx.doClause().compoundStatement()) + // TODO: testAst should be negated if it's going to be modelled as a while stmt. + boilerplate :+ whileAst(exprAst, bodyAst, Some(text(ctx)), line(ctx), column(ctx)) + } + + protected def astForForExpression(ctx: ForExpressionContext): Seq[Ast] = { + val forVarAst = astForForVariableContext(ctx.forVariable()) + val (boilerplate, forExprAst) = astForExpressionOrCommand(ctx.expressionOrCommand()).partitionClosureFromExpr + val forBodyAst = astForCompoundStatement(ctx.doClause().compoundStatement()) + // TODO: for X in Y is not properly modelled by while Y + val forRootAst = whileAst(forExprAst, forBodyAst, Some(text(ctx)), line(ctx), column(ctx)) + boilerplate :+ forVarAst.headOption.map(forRootAst.withChild).getOrElse(forRootAst) + } + + private def astForForVariableContext(ctx: ForVariableContext): Seq[Ast] = { + if (ctx.singleLeftHandSide() != null) { + astForSingleLeftHandSideContext(ctx.singleLeftHandSide()) + } else if (ctx.multipleLeftHandSide() != null) { + astForMultipleLeftHandSideContext(ctx.multipleLeftHandSide()) + } else { + Seq(Ast()) + } + } + + protected def astForWhileExpression(ctx: WhileExpressionContext): Seq[Ast] = { + val (boilerplate, exprAst) = astForExpressionOrCommand(ctx.expressionOrCommand()).partitionClosureFromExpr + val bodyAst = astForCompoundStatement(ctx.doClause().compoundStatement()) + boilerplate :+ whileAst(exprAst, bodyAst, Some(text(ctx)), line(ctx), column(ctx)) + } + + protected def astForIfExpression(ctx: IfExpressionContext): Seq[Ast] = { + val (boilerplate, exprAst) = astForExpressionOrCommand(ctx.expressionOrCommand()).partitionClosureFromExpr + val thenAst = astForCompoundStatement(ctx.thenClause().compoundStatement()) + val elsifAsts = Option(ctx.elsifClause).map(_.asScala).getOrElse(Seq()).flatMap(astForElsifClause) + val elseAst = Option(ctx.elseClause()).map(ctx => astForCompoundStatement(ctx.compoundStatement())).getOrElse(Seq()) + val ifNode = controlStructureNode(ctx, ControlStructureTypes.IF, text(ctx)) + boilerplate :+ controlStructureAst(ifNode, exprAst) + .withChildren(thenAst) + .withChildren(elsifAsts.toSeq) + .withChildren(elseAst) + } + + private def astForElsifClause(ctx: ElsifClauseContext): Seq[Ast] = { + val ifNode = controlStructureNode(ctx, ControlStructureTypes.IF, text(ctx)) + val (boilerplate, exprAst) = astForExpressionOrCommand(ctx.expressionOrCommand()).partitionClosureFromExpr + val bodyAst = astForCompoundStatement(ctx.thenClause().compoundStatement()) + boilerplate :+ controlStructureAst(ifNode, exprAst, bodyAst) + } + } diff --git a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForExpressionsCreator.scala b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForExpressionsCreator.scala index add72fd7ea5f..87d670eeb3c3 100644 --- a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForExpressionsCreator.scala +++ b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForExpressionsCreator.scala @@ -395,57 +395,6 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) { otherAst :+ callAst(call, argsAst) } - protected def astForUntilExpression(ctx: UntilExpressionContext): Ast = { - val testAst = astForExpressionOrCommand(ctx.expressionOrCommand()).headOption - val bodyAst = astForCompoundStatement(ctx.doClause().compoundStatement()) - // TODO: testAst should be negated if it's going to be modelled as a while stmt. - whileAst(testAst, bodyAst, Some(text(ctx)), line(ctx), column(ctx)) - } - - protected def astForForExpression(ctx: ForExpressionContext): Ast = { - val forVarAst = astForForVariableContext(ctx.forVariable()) - val forExprAst = astForExpressionOrCommand(ctx.expressionOrCommand()) - val forBodyAst = astForCompoundStatement(ctx.doClause().compoundStatement()) - // TODO: for X in Y is not properly modelled by while Y - val forRootAst = whileAst(forExprAst.headOption, forBodyAst, Some(text(ctx)), line(ctx), column(ctx)) - forVarAst.headOption.map(forRootAst.withChild).getOrElse(forRootAst) - } - - private def astForForVariableContext(ctx: ForVariableContext): Seq[Ast] = { - if (ctx.singleLeftHandSide() != null) { - astForSingleLeftHandSideContext(ctx.singleLeftHandSide()) - } else if (ctx.multipleLeftHandSide() != null) { - astForMultipleLeftHandSideContext(ctx.multipleLeftHandSide()) - } else { - Seq(Ast()) - } - } - - protected def astForWhileExpression(ctx: WhileExpressionContext): Ast = { - val testAst = astForExpressionOrCommand(ctx.expressionOrCommand()) - val bodyAst = astForCompoundStatement(ctx.doClause().compoundStatement()) - whileAst(testAst.headOption, bodyAst, Some(text(ctx)), line(ctx), column(ctx)) - } - - protected def astForIfExpression(ctx: IfExpressionContext): Ast = { - val testAst = astForExpressionOrCommand(ctx.expressionOrCommand()) - val thenAst = astForCompoundStatement(ctx.thenClause().compoundStatement()) - val elsifAsts = Option(ctx.elsifClause).map(_.asScala).getOrElse(Seq()).map(astForElsifClause) - val elseAst = Option(ctx.elseClause()).map(ctx => astForCompoundStatement(ctx.compoundStatement())).getOrElse(Seq()) - val ifNode = controlStructureNode(ctx, ControlStructureTypes.IF, text(ctx)) - controlStructureAst(ifNode, testAst.headOption) - .withChildren(thenAst) - .withChildren(elsifAsts.toSeq) - .withChildren(elseAst) - } - - private def astForElsifClause(ctx: ElsifClauseContext): Ast = { - val ifNode = controlStructureNode(ctx, ControlStructureTypes.IF, text(ctx)) - val testAst = astForExpressionOrCommand(ctx.expressionOrCommand()) - val bodyAst = astForCompoundStatement(ctx.thenClause().compoundStatement()) - controlStructureAst(ifNode, testAst.headOption, bodyAst) - } - protected def astForVariableReference(ctx: VariableReferenceContext): Ast = ctx match { case ctx: VariableIdentifierVariableReferenceContext => astForVariableIdentifierHelper(ctx.variableIdentifier()) case ctx: PseudoVariableIdentifierVariableReferenceContext => @@ -501,13 +450,13 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) { } } - protected def astForUnlessExpression(ctx: UnlessExpressionContext): Ast = { - val testAst = astForExpressionOrCommand(ctx.expressionOrCommand()) - val thenAst = astForCompoundStatement(ctx.thenClause().compoundStatement()) + protected def astForUnlessExpression(ctx: UnlessExpressionContext): Seq[Ast] = { + val (exprAst, otherAst) = astForExpressionOrCommand(ctx.expressionOrCommand()).partitionExprAst + val thenAst = astForCompoundStatement(ctx.thenClause().compoundStatement()) val elseAst = Option(ctx.elseClause()).map(_.compoundStatement()).map(st => astForCompoundStatement(st)).getOrElse(Seq()) val ifNode = controlStructureNode(ctx, ControlStructureTypes.IF, text(ctx)) - controlStructureAst(ifNode, testAst.headOption, thenAst ++ elseAst) + otherAst :+ controlStructureAst(ifNode, exprAst.headOption, thenAst ++ elseAst) } protected def astForQuotedStringExpression(ctx: QuotedStringExpressionContext): Seq[Ast] = ctx match diff --git a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/passes/ast/DoBlockTest.scala b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/passes/ast/DoBlockTest.scala index 1b93a7f687d3..ac378c86f80c 100644 --- a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/passes/ast/DoBlockTest.scala +++ b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/passes/ast/DoBlockTest.scala @@ -1,8 +1,9 @@ package io.joern.rubysrc2cpg.passes.ast import io.joern.rubysrc2cpg.testfixtures.RubyCode2CpgFixture -import io.shiftleft.codepropertygraph.generated.nodes.{Call, Identifier, MethodRef} +import io.shiftleft.codepropertygraph.generated.nodes.{Call, ControlStructure, Identifier, MethodRef} import io.shiftleft.semanticcpg.language.* +import io.shiftleft.codepropertygraph.generated.{ControlStructureTypes, NodeTypes} class DoBlockTest extends RubyCode2CpgFixture { @@ -107,4 +108,25 @@ class DoBlockTest extends RubyCode2CpgFixture { } } + "a boolean do-block function as a conditional argument" should { + val cpg = code(""" + |if @items.any? { |x| x > 1 } + | puts "foo" + |else + | puts "bar" + |end + |""".stripMargin) + + "be defined outside of the control structure" in { + val anyMethod = cpg.method.name("any.*").head + anyMethod.astParent.label shouldBe NodeTypes.BLOCK + } + + "have the call to the method ref as the conditional argument" in { + val ifStmt = cpg.controlStructure.controlStructureType(ControlStructureTypes.IF).head + val ::(anyCall: Call, _) = ifStmt.condition.l: @unchecked + anyCall.name should startWith("any") + } + } + }