diff --git a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForExpressionsCreator.scala b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForExpressionsCreator.scala index 2c0b5ac40285..bdd5d24bc293 100644 --- a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForExpressionsCreator.scala +++ b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForExpressionsCreator.scala @@ -35,6 +35,7 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) { case node: HashLiteral => astForHashLiteral(node) case node: Association => astForAssociation(node) case node: IfExpression => astForIfExpression(node) + case node: UnlessExpression => astForUnlessExpression(node) case node: RescueExpression => astForRescueExpression(node) case node: MandatoryParameter => astForMandatoryParameter(node) case node: SplattingRubyNode => astForSplattingRubyNode(node) @@ -257,25 +258,25 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) { case Some(op) => node.rhs match { case cfNode: ControlFlowExpression => - def elseAssignNil = Option { + def elseAssignNil(span: TextSpan) = Option { ElseClause( StatementList( SingleAssignment( node.lhs, node.op, - StaticLiteral(getBuiltInType(Defines.NilClass))(node.span.spanStart("nil")) - )(node.span.spanStart(s"${node.lhs.span.text} ${node.op} nil")) :: Nil - )(node.span.spanStart(s"${node.lhs.span.text} ${node.op} nil")) - )(node.span.spanStart(s"else\n\t${node.lhs.span.text} ${node.op} nil\nend")) + StaticLiteral(getBuiltInType(Defines.NilClass))(span.spanStart("nil")) + )(span.spanStart(s"${node.lhs.span.text} ${node.op} nil")) :: Nil + )(span.spanStart(s"${node.lhs.span.text} ${node.op} nil")) + )(span.spanStart(s"else\n\t${node.lhs.span.text} ${node.op} nil\nend")) } - astForExpression( + def transform(e: RubyNode with ControlFlowExpression): RubyNode = transformLastRubyNodeInControlFlowExpressionBody( - cfNode, - x => reassign(node.lhs, node.op, x), + e, + x => reassign(node.lhs, node.op, x, transform), elseAssignNil ) - ) + astForExpression(transform(cfNode)) case _ => val lhsAst = astForExpression(node.lhs) val rhsAst = astForExpression(node.rhs) @@ -303,30 +304,38 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) { } } - private def reassign(lhs: RubyNode, op: String, rhs: RubyNode): RubyNode = { + private def reassign( + lhs: RubyNode, + op: String, + rhs: RubyNode, + transform: (RubyNode with ControlFlowExpression) => RubyNode + ): RubyNode = { def stmtListAssigningLastExpression(stmts: List[RubyNode]): List[RubyNode] = stmts match { - case (head: ControlFlowClause) :: Nil => clauseAssigningLastExpression(head) :: Nil + case (head: ControlFlowClause) :: Nil => clauseAssigningLastExpression(head) :: Nil + case (head: ControlFlowExpression) :: Nil => transform(head) :: Nil case head :: Nil => - SingleAssignment(lhs, op, head)(lhs.span.spanStart(s"${lhs.span.text} $op ${head.span.text}")) :: Nil + SingleAssignment(lhs, op, head)(rhs.span.spanStart(s"${lhs.span.text} $op ${head.span.text}")) :: Nil case Nil => List.empty case head :: tail => head :: stmtListAssigningLastExpression(tail) } def clauseAssigningLastExpression(x: RubyNode with ControlFlowClause): RubyNode = x match { case RescueClause(exceptionClassList, assignment, thenClause) => - RescueClause(exceptionClassList, assignment, reassign(lhs, op, thenClause))(x.span) - case EnsureClause(thenClause) => EnsureClause(reassign(lhs, op, thenClause))(x.span) - case ElsIfClause(condition, thenClause) => ElsIfClause(condition, reassign(lhs, op, thenClause))(x.span) - case ElseClause(thenClause) => ElseClause(reassign(lhs, op, thenClause))(x.span) + RescueClause(exceptionClassList, assignment, reassign(lhs, op, thenClause, transform))(x.span) + case EnsureClause(thenClause) => EnsureClause(reassign(lhs, op, thenClause, transform))(x.span) + case ElsIfClause(condition, thenClause) => + ElsIfClause(condition, reassign(lhs, op, thenClause, transform))(x.span) + case ElseClause(thenClause) => ElseClause(reassign(lhs, op, thenClause, transform))(x.span) case WhenClause(matchExpressions, matchSplatExpression, thenClause) => - WhenClause(matchExpressions, matchSplatExpression, reassign(lhs, op, thenClause))(x.span) + WhenClause(matchExpressions, matchSplatExpression, reassign(lhs, op, thenClause, transform))(x.span) } rhs match { - case StatementList(statements) => StatementList(stmtListAssigningLastExpression(statements))(rhs.span) - case clause: ControlFlowClause => clauseAssigningLastExpression(clause) + case StatementList(statements) => StatementList(stmtListAssigningLastExpression(statements))(rhs.span) + case clause: ControlFlowClause => clauseAssigningLastExpression(clause) + case expr: ControlFlowExpression => transform(expr) case _ => - SingleAssignment(lhs, op, rhs)(lhs.span.spanStart(s"${lhs.span.text} $op ${rhs.span.text}")) + SingleAssignment(lhs, op, rhs)(rhs.span.spanStart(s"${lhs.span.text} $op ${rhs.span.text}")) } } @@ -570,6 +579,11 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) { foldIfExpression(builder)(node) } + protected def astForUnlessExpression(node: UnlessExpression): Ast = { + val notConditionAst = UnaryExpression("!", node.condition)(node.condition.span) + astForExpression(IfExpression(notConditionAst, node.trueBranch, List(), node.falseBranch)(node.span)) + } + protected def astForRescueExpression(node: RescueExpression): Ast = { val tryAst = astForStatementList(node.body.asStatementList) val rescueAsts = node.rescueClauses diff --git a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForStatementsCreator.scala b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForStatementsCreator.scala index 83d554c3b9be..6b048bcd95fb 100644 --- a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForStatementsCreator.scala +++ b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForStatementsCreator.scala @@ -236,19 +236,21 @@ trait AstForStatementsCreator(implicit withSchemaValidation: ValidationMode) { t } private def astsForImplicitReturnStatement(node: RubyNode): Seq[Ast] = { - def elseReturnNil = Option { + def elseReturnNil(span: TextSpan) = Option { ElseClause( StatementList( - ReturnExpression(StaticLiteral(getBuiltInType(Defines.NilClass))(node.span.spanStart("nil")) :: Nil)( - node.span.spanStart("return nil") + ReturnExpression(StaticLiteral(getBuiltInType(Defines.NilClass))(span.spanStart("nil")) :: Nil)( + span.spanStart("return nil") ) :: Nil - )(node.span.spanStart("return nil")) - )(node.span.spanStart("else\n\treturn nil\nend")) + )(span.spanStart("return nil")) + )(span.spanStart("else\n\treturn nil\nend")) } node match case expr: ControlFlowExpression => - astsForStatement(transformLastRubyNodeInControlFlowExpressionBody(expr, returnLastNode, elseReturnNil)) + def transform(e: RubyNode with ControlFlowExpression): RubyNode = + transformLastRubyNodeInControlFlowExpressionBody(e, returnLastNode(_, transform), elseReturnNil) + astsForStatement(transform(expr)) case node: MemberCallWithBlock => returnAstForRubyCall(node) case node: SimpleCallWithBlock => returnAstForRubyCall(node) case _: (LiteralExpr | BinaryExpression | UnaryExpression | SimpleIdentifier | IndexAccess | Association | @@ -309,43 +311,45 @@ trait AstForStatementsCreator(implicit withSchemaValidation: ValidationMode) { t * @return * the RubyNode with an explicit expression */ - private def returnLastNode(x: RubyNode): RubyNode = { + private def returnLastNode(x: RubyNode, transform: (RubyNode with ControlFlowExpression) => RubyNode): RubyNode = { def statementListReturningLastExpression(stmts: List[RubyNode]): List[RubyNode] = stmts match { - case (head: ControlFlowClause) :: Nil => clauseReturningLastExpression(head) :: Nil - case (head: ReturnExpression) :: Nil => head :: Nil - case head :: Nil => ReturnExpression(head :: Nil)(head.span) :: Nil - case Nil => List.empty - case head :: tail => head :: statementListReturningLastExpression(tail) + case (head: ControlFlowClause) :: Nil => clauseReturningLastExpression(head) :: Nil + case (head: ControlFlowExpression) :: Nil => transform(head) :: Nil + case (head: ReturnExpression) :: Nil => head :: Nil + case head :: Nil => ReturnExpression(head :: Nil)(head.span) :: Nil + case Nil => List.empty + case head :: tail => head :: statementListReturningLastExpression(tail) } def clauseReturningLastExpression(x: RubyNode with ControlFlowClause): RubyNode = x match { case RescueClause(exceptionClassList, assignment, thenClause) => - RescueClause(exceptionClassList, assignment, returnLastNode(thenClause))(x.span) - case EnsureClause(thenClause) => EnsureClause(returnLastNode(thenClause))(x.span) - case ElsIfClause(condition, thenClause) => ElsIfClause(condition, returnLastNode(thenClause))(x.span) - case ElseClause(thenClause) => ElseClause(returnLastNode(thenClause))(x.span) + RescueClause(exceptionClassList, assignment, returnLastNode(thenClause, transform))(x.span) + case EnsureClause(thenClause) => EnsureClause(returnLastNode(thenClause, transform))(x.span) + case ElsIfClause(condition, thenClause) => ElsIfClause(condition, returnLastNode(thenClause, transform))(x.span) + case ElseClause(thenClause) => ElseClause(returnLastNode(thenClause, transform))(x.span) case WhenClause(matchExpressions, matchSplatExpression, thenClause) => - WhenClause(matchExpressions, matchSplatExpression, returnLastNode(thenClause))(x.span) + WhenClause(matchExpressions, matchSplatExpression, returnLastNode(thenClause, transform))(x.span) } x match { - case StatementList(statements) => StatementList(statementListReturningLastExpression(statements))(x.span) - case clause: ControlFlowClause => clauseReturningLastExpression(clause) - case _ => ReturnExpression(x :: Nil)(x.span) + case StatementList(statements) => StatementList(statementListReturningLastExpression(statements))(x.span) + case clause: ControlFlowClause => clauseReturningLastExpression(clause) + case node: ControlFlowExpression => transform(node) + case _ => ReturnExpression(x :: Nil)(x.span) } } /** @param node * \- Control Flow Expression RubyNode * @param transform - * \- RubyNode => RubyNode function for transformation on last ruby node + * \- RubyNode => RubyNode function for transformation on the clauses of the ControlFlowExpression * @return * RubyNode with transform function applied */ protected def transformLastRubyNodeInControlFlowExpressionBody( node: RubyNode with ControlFlowExpression, transform: RubyNode => RubyNode, - defaultElseBranch: Option[ElseClause] + defaultElseBranch: TextSpan => Option[ElseClause] ): RubyNode = { node match { case RescueExpression(body, rescueClauses, elseClause, ensureClause) => @@ -353,7 +357,7 @@ trait AstForStatementsCreator(implicit withSchemaValidation: ValidationMode) { t RescueExpression( transform(body), rescueClauses.map(transform), - elseClause.map(transform).orElse(defaultElseBranch), + elseClause.map(transform).orElse(defaultElseBranch(node.span)), ensureClause )(node.span) case WhileExpression(condition, body) => WhileExpression(condition, transform(body))(node.span) @@ -363,18 +367,22 @@ trait AstForStatementsCreator(implicit withSchemaValidation: ValidationMode) { t condition, transform(thenClause), elsifClauses.map(transform), - elseClause.map(transform).orElse(defaultElseBranch) + elseClause.map(transform).orElse(defaultElseBranch(node.span)) )(node.span) case UnlessExpression(condition, trueBranch, falseBranch) => - UnlessExpression(condition, transform(trueBranch), falseBranch.map(transform).orElse(defaultElseBranch))( - node.span - ) + UnlessExpression( + condition, + transform(trueBranch), + falseBranch.map(transform).orElse(defaultElseBranch(node.span)) + )(node.span) case ForExpression(forVariable, iterableVariable, doBlock) => ForExpression(forVariable, iterableVariable, transform(doBlock))(node.span) case CaseExpression(expression, whenClauses, elseClause) => - CaseExpression(expression, whenClauses.map(transform), elseClause.map(transform).orElse(defaultElseBranch))( - node.span - ) + CaseExpression( + expression, + whenClauses.map(transform), + elseClause.map(transform).orElse(defaultElseBranch(node.span)) + )(node.span) } } } diff --git a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/ConditionalTests.scala b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/ConditionalTests.scala index 8159757c2395..3a5c1619a2b9 100644 --- a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/ConditionalTests.scala +++ b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/ConditionalTests.scala @@ -4,6 +4,7 @@ import io.joern.rubysrc2cpg.testfixtures.RubyCode2CpgFixture import io.shiftleft.codepropertygraph.generated.Operators import io.shiftleft.codepropertygraph.generated.nodes.{Identifier, Local} import io.shiftleft.semanticcpg.language.* +import io.shiftleft.codepropertygraph.generated.nodes.Call class ConditionalTests extends RubyCode2CpgFixture { @@ -39,4 +40,38 @@ class ConditionalTests extends RubyCode2CpgFixture { } } + "`f(x ? y : z)` is lowered into conditional operator" in { + val cpg = code(""" + |f(x ? y : z) + |""".stripMargin) + inside(cpg.call(Operators.conditional).l) { + case cond :: Nil => + inside(cond.argument.l) { + case x :: y :: z :: Nil => { + x.code shouldBe "x" + List(y, z).isBlock.astChildren.isIdentifier.code.l shouldBe List("y", "z") + } + case xs => fail(s"Expected exactly three arguments to conditional, got [${xs.code.mkString(",")}]") + } + case xs => fail(s"Expected exactly one conditional, got [${xs.code.mkString(",")}]") + } + } + + "`f(unless x then y else z end)` is lowered into conditional operator" in { + val cpg = code(""" + |f(unless x then y else z end) + |""".stripMargin) + inside(cpg.call(Operators.conditional).l) { + case cond :: Nil => + inside(cond.argument.l) { + case x :: y :: z :: Nil => { + List(x).isCall.name(Operators.logicalNot).argument.code.l shouldBe List("x") + List(y, z).isBlock.astChildren.isIdentifier.code.l shouldBe List("y", "z") + } + case xs => fail(s"Expected exactly three arguments to conditional, got [${xs.code.mkString(",")}]") + } + case xs => fail(s"Expected exactly one conditional, got [${xs.code.mkString(",")}]") + } + } + } diff --git a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/MethodReturnTests.scala b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/MethodReturnTests.scala index 7c42a4784025..b89f8fad8105 100644 --- a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/MethodReturnTests.scala +++ b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/MethodReturnTests.scala @@ -240,6 +240,47 @@ class MethodReturnTests extends RubyCode2CpgFixture(withDataFlow = true) { } } + "implicit return of nested control flow" in { + val cpg = code(""" + | def f + | if true + | if true + | 1 + | else + | 2 + | end + | else + | if true + | 3 + | else + | 4 + | end + | end + | end + |""".stripMargin) + + inside(cpg.method.name("f").l) { + case f :: Nil => + inside(cpg.methodReturn.toReturn.l) { + case return1 :: return2 :: return3 :: return4 :: Nil => + return1.code shouldBe "1" + return1.lineNumber shouldBe Some(5) + + return2.code shouldBe "2" + return2.lineNumber shouldBe Some(7) + + return3.code shouldBe "3" + return3.lineNumber shouldBe Some(11) + + return4.code shouldBe "4" + return4.lineNumber shouldBe Some(13) + + case xs => fail(s"Expected 4 returns, instead got [${xs.code.mkString(",")}]") + } + case xs => fail(s"Expected exactly one method with the name `f`, instead got [${xs.code.mkString(",")}]") + } + } + "implicit RETURN node for ternary expression" in { val cpg = code(""" |def f(x) = x ? 20 : 40 diff --git a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/SingleAssignmentTests.scala b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/SingleAssignmentTests.scala index 908a09cdb98b..8b0bae392100 100644 --- a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/SingleAssignmentTests.scala +++ b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/SingleAssignmentTests.scala @@ -120,7 +120,7 @@ class SingleAssignmentTests extends RubyCode2CpgFixture { two.code shouldBe "2" } - "`if-else-end` on the RHS of an assignment is represented by a `conditional` operator call" in { + "`if-else-end` on the RHS of an assignment" in { val cpg = code(""" |x = if true then 20 else 40 end |""".stripMargin) @@ -142,4 +142,75 @@ class SingleAssignmentTests extends RubyCode2CpgFixture { rhsElseBranchValue.code shouldBe "40" } + "nested if-else-end on the RHS of an assignment" in { + val cpg = code(""" + |x = if true + | if true + | 1 + | else + | 2 + | end + |else + | if true + | 3 + | else + | 4 + | end + |end + | + |""".stripMargin) + + inside(cpg.assignment.l) { + case assign1 :: assign2 :: assign3 :: assign4 :: Nil => + assign1.lineNumber shouldBe Some(4) + assign1.argument(1).code shouldBe "x" + assign1.argument(2).code shouldBe "1" + assign1.argument(2).lineNumber shouldBe Some(4) + + assign2.lineNumber shouldBe Some(5) + assign2.argument(1).code shouldBe "x" + assign2.argument(2).code shouldBe "2" + assign2.argument(2).lineNumber shouldBe Some(6) + + assign3.lineNumber shouldBe Some(10) + assign3.argument(1).code shouldBe "x" + assign3.argument(2).code shouldBe "3" + assign3.argument(2).lineNumber shouldBe Some(10) + + assign4.lineNumber shouldBe Some(11) + assign4.argument(1).code shouldBe "x" + assign4.argument(2).code shouldBe "4" + assign4.argument(2).lineNumber shouldBe Some(12) + case xs => fail(s"Expected 4 assignments, instead got [${xs.code.mkString(",")}]") + } + + } + + "nested if-end should have implicit elses" in { + val cpg = code(""" + |x = if true + | if true + | 1 + | end + |end + |""".stripMargin) + + val assigns = cpg.assignment.l + inside(cpg.assignment.l) { + case assign1 :: assignNil1 :: assignNil2 :: Nil => + assign1.argument(1).code shouldBe "x" + assign1.argument(2).code shouldBe "1" + assign1.lineNumber shouldBe Some(4) + + assignNil1.argument(1).code shouldBe "x" + assignNil1.argument(2).code shouldBe "nil" + assignNil1.lineNumber shouldBe Some(3) + + assignNil2.argument(1).code shouldBe "x" + assignNil2.argument(2).code shouldBe "nil" + assignNil2.lineNumber shouldBe Some(2) + case xs => fail(s"Expected 3 assignments, instead got [${xs.code.mkString(",")}]") + } + } + }