Skip to content

Commit

Permalink
[rubysrc2cpg] Do-Block Function as Conditional (#3729)
Browse files Browse the repository at this point in the history
Fixed bug regarding do-block functions as control structure conditionals
  • Loading branch information
DavidBakerEffendi authored Oct 10, 2023
1 parent cd00f2d commit c726bb5
Show file tree
Hide file tree
Showing 5 changed files with 98 additions and 61 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -200,12 +200,12 @@ class AstCreator(
case ctx: ProcDefinitionPrimaryContext => astForProcDefinitionContext(ctx.procDefinition())
case ctx: YieldWithOptionalArgumentPrimaryContext =>
astForYieldCall(ctx, Option(ctx.yieldWithOptionalArgument().arguments()))
case ctx: IfExpressionPrimaryContext => Seq(astForIfExpression(ctx.ifExpression()))
case ctx: UnlessExpressionPrimaryContext => Seq(astForUnlessExpression(ctx.unlessExpression()))
case ctx: IfExpressionPrimaryContext => astForIfExpression(ctx.ifExpression())
case ctx: UnlessExpressionPrimaryContext => astForUnlessExpression(ctx.unlessExpression())
case ctx: CaseExpressionPrimaryContext => astForCaseExpressionPrimaryContext(ctx)
case ctx: WhileExpressionPrimaryContext => Seq(astForWhileExpression(ctx.whileExpression()))
case ctx: UntilExpressionPrimaryContext => Seq(astForUntilExpression(ctx.untilExpression()))
case ctx: ForExpressionPrimaryContext => Seq(astForForExpression(ctx.forExpression()))
case ctx: WhileExpressionPrimaryContext => astForWhileExpression(ctx.whileExpression())
case ctx: UntilExpressionPrimaryContext => astForUntilExpression(ctx.untilExpression())
case ctx: ForExpressionPrimaryContext => astForForExpression(ctx.forExpression())
case ctx: ReturnWithParenthesesPrimaryContext =>
Seq(returnAst(returnNode(ctx, text(ctx)), astForArgumentsWithParenthesesContext(ctx.argumentsWithParentheses())))
case ctx: JumpExpressionPrimaryContext => astForJumpExpressionPrimaryContext(ctx)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,21 @@ trait AstCreatorHelper(implicit withSchemaValidation: ValidationMode) { this: As
(as.toSeq, bs.toSeq)
}

/** Partitions a sequence of Ast objects into the boilerplate for do-block functions and the call node at the end.
*
* @return
* a tuple where the first element is the closure boilerplate and the latter is the last expression.
*/
def partitionClosureFromExpr: (Seq[Ast], Option[Ast]) = {
val (as, bs) = a.partition(_.root match
case Some(_: NewMethod) => true
case Some(_: NewTypeDecl) => true
case Some(x: NewCall) if x.name.startsWith(Operators.assignment) => true
case _ => false
)
(as.toSeq, bs.lastOption)
}

}

}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -132,4 +132,55 @@ trait AstForControlStructuresCreator(implicit withSchemaValidation: ValidationMo
tryCatchAst(tryNode, tryBodyAst, catchAsts, finallyAst)
}

protected def astForUntilExpression(ctx: UntilExpressionContext): Seq[Ast] = {
val (boilerplate, exprAst) = astForExpressionOrCommand(ctx.expressionOrCommand()).partitionClosureFromExpr
val bodyAst = astForCompoundStatement(ctx.doClause().compoundStatement())
// TODO: testAst should be negated if it's going to be modelled as a while stmt.
boilerplate :+ whileAst(exprAst, bodyAst, Some(text(ctx)), line(ctx), column(ctx))
}

protected def astForForExpression(ctx: ForExpressionContext): Seq[Ast] = {
val forVarAst = astForForVariableContext(ctx.forVariable())
val (boilerplate, forExprAst) = astForExpressionOrCommand(ctx.expressionOrCommand()).partitionClosureFromExpr
val forBodyAst = astForCompoundStatement(ctx.doClause().compoundStatement())
// TODO: for X in Y is not properly modelled by while Y
val forRootAst = whileAst(forExprAst, forBodyAst, Some(text(ctx)), line(ctx), column(ctx))
boilerplate :+ forVarAst.headOption.map(forRootAst.withChild).getOrElse(forRootAst)
}

private def astForForVariableContext(ctx: ForVariableContext): Seq[Ast] = {
if (ctx.singleLeftHandSide() != null) {
astForSingleLeftHandSideContext(ctx.singleLeftHandSide())
} else if (ctx.multipleLeftHandSide() != null) {
astForMultipleLeftHandSideContext(ctx.multipleLeftHandSide())
} else {
Seq(Ast())
}
}

protected def astForWhileExpression(ctx: WhileExpressionContext): Seq[Ast] = {
val (boilerplate, exprAst) = astForExpressionOrCommand(ctx.expressionOrCommand()).partitionClosureFromExpr
val bodyAst = astForCompoundStatement(ctx.doClause().compoundStatement())
boilerplate :+ whileAst(exprAst, bodyAst, Some(text(ctx)), line(ctx), column(ctx))
}

protected def astForIfExpression(ctx: IfExpressionContext): Seq[Ast] = {
val (boilerplate, exprAst) = astForExpressionOrCommand(ctx.expressionOrCommand()).partitionClosureFromExpr
val thenAst = astForCompoundStatement(ctx.thenClause().compoundStatement())
val elsifAsts = Option(ctx.elsifClause).map(_.asScala).getOrElse(Seq()).flatMap(astForElsifClause)
val elseAst = Option(ctx.elseClause()).map(ctx => astForCompoundStatement(ctx.compoundStatement())).getOrElse(Seq())
val ifNode = controlStructureNode(ctx, ControlStructureTypes.IF, text(ctx))
boilerplate :+ controlStructureAst(ifNode, exprAst)
.withChildren(thenAst)
.withChildren(elsifAsts.toSeq)
.withChildren(elseAst)
}

private def astForElsifClause(ctx: ElsifClauseContext): Seq[Ast] = {
val ifNode = controlStructureNode(ctx, ControlStructureTypes.IF, text(ctx))
val (boilerplate, exprAst) = astForExpressionOrCommand(ctx.expressionOrCommand()).partitionClosureFromExpr
val bodyAst = astForCompoundStatement(ctx.thenClause().compoundStatement())
boilerplate :+ controlStructureAst(ifNode, exprAst, bodyAst)
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -395,57 +395,6 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) {
otherAst :+ callAst(call, argsAst)
}

protected def astForUntilExpression(ctx: UntilExpressionContext): Ast = {
val testAst = astForExpressionOrCommand(ctx.expressionOrCommand()).headOption
val bodyAst = astForCompoundStatement(ctx.doClause().compoundStatement())
// TODO: testAst should be negated if it's going to be modelled as a while stmt.
whileAst(testAst, bodyAst, Some(text(ctx)), line(ctx), column(ctx))
}

protected def astForForExpression(ctx: ForExpressionContext): Ast = {
val forVarAst = astForForVariableContext(ctx.forVariable())
val forExprAst = astForExpressionOrCommand(ctx.expressionOrCommand())
val forBodyAst = astForCompoundStatement(ctx.doClause().compoundStatement())
// TODO: for X in Y is not properly modelled by while Y
val forRootAst = whileAst(forExprAst.headOption, forBodyAst, Some(text(ctx)), line(ctx), column(ctx))
forVarAst.headOption.map(forRootAst.withChild).getOrElse(forRootAst)
}

private def astForForVariableContext(ctx: ForVariableContext): Seq[Ast] = {
if (ctx.singleLeftHandSide() != null) {
astForSingleLeftHandSideContext(ctx.singleLeftHandSide())
} else if (ctx.multipleLeftHandSide() != null) {
astForMultipleLeftHandSideContext(ctx.multipleLeftHandSide())
} else {
Seq(Ast())
}
}

protected def astForWhileExpression(ctx: WhileExpressionContext): Ast = {
val testAst = astForExpressionOrCommand(ctx.expressionOrCommand())
val bodyAst = astForCompoundStatement(ctx.doClause().compoundStatement())
whileAst(testAst.headOption, bodyAst, Some(text(ctx)), line(ctx), column(ctx))
}

protected def astForIfExpression(ctx: IfExpressionContext): Ast = {
val testAst = astForExpressionOrCommand(ctx.expressionOrCommand())
val thenAst = astForCompoundStatement(ctx.thenClause().compoundStatement())
val elsifAsts = Option(ctx.elsifClause).map(_.asScala).getOrElse(Seq()).map(astForElsifClause)
val elseAst = Option(ctx.elseClause()).map(ctx => astForCompoundStatement(ctx.compoundStatement())).getOrElse(Seq())
val ifNode = controlStructureNode(ctx, ControlStructureTypes.IF, text(ctx))
controlStructureAst(ifNode, testAst.headOption)
.withChildren(thenAst)
.withChildren(elsifAsts.toSeq)
.withChildren(elseAst)
}

private def astForElsifClause(ctx: ElsifClauseContext): Ast = {
val ifNode = controlStructureNode(ctx, ControlStructureTypes.IF, text(ctx))
val testAst = astForExpressionOrCommand(ctx.expressionOrCommand())
val bodyAst = astForCompoundStatement(ctx.thenClause().compoundStatement())
controlStructureAst(ifNode, testAst.headOption, bodyAst)
}

protected def astForVariableReference(ctx: VariableReferenceContext): Ast = ctx match {
case ctx: VariableIdentifierVariableReferenceContext => astForVariableIdentifierHelper(ctx.variableIdentifier())
case ctx: PseudoVariableIdentifierVariableReferenceContext =>
Expand Down Expand Up @@ -501,13 +450,13 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) {
}
}

protected def astForUnlessExpression(ctx: UnlessExpressionContext): Ast = {
val testAst = astForExpressionOrCommand(ctx.expressionOrCommand())
val thenAst = astForCompoundStatement(ctx.thenClause().compoundStatement())
protected def astForUnlessExpression(ctx: UnlessExpressionContext): Seq[Ast] = {
val (exprAst, otherAst) = astForExpressionOrCommand(ctx.expressionOrCommand()).partitionExprAst
val thenAst = astForCompoundStatement(ctx.thenClause().compoundStatement())
val elseAst =
Option(ctx.elseClause()).map(_.compoundStatement()).map(st => astForCompoundStatement(st)).getOrElse(Seq())
val ifNode = controlStructureNode(ctx, ControlStructureTypes.IF, text(ctx))
controlStructureAst(ifNode, testAst.headOption, thenAst ++ elseAst)
otherAst :+ controlStructureAst(ifNode, exprAst.headOption, thenAst ++ elseAst)
}

protected def astForQuotedStringExpression(ctx: QuotedStringExpressionContext): Seq[Ast] = ctx match
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
package io.joern.rubysrc2cpg.passes.ast

import io.joern.rubysrc2cpg.testfixtures.RubyCode2CpgFixture
import io.shiftleft.codepropertygraph.generated.nodes.{Call, Identifier, MethodRef}
import io.shiftleft.codepropertygraph.generated.nodes.{Call, ControlStructure, Identifier, MethodRef}
import io.shiftleft.semanticcpg.language.*
import io.shiftleft.codepropertygraph.generated.{ControlStructureTypes, NodeTypes}

class DoBlockTest extends RubyCode2CpgFixture {

Expand Down Expand Up @@ -107,4 +108,25 @@ class DoBlockTest extends RubyCode2CpgFixture {
}
}

"a boolean do-block function as a conditional argument" should {
val cpg = code("""
|if @items.any? { |x| x > 1 }
| puts "foo"
|else
| puts "bar"
|end
|""".stripMargin)

"be defined outside of the control structure" in {
val anyMethod = cpg.method.name("any.*").head
anyMethod.astParent.label shouldBe NodeTypes.BLOCK
}

"have the call to the method ref as the conditional argument" in {
val ifStmt = cpg.controlStructure.controlStructureType(ControlStructureTypes.IF).head
val ::(anyCall: Call, _) = ifStmt.condition.l: @unchecked
anyCall.name should startWith("any")
}
}

}

0 comments on commit c726bb5

Please sign in to comment.