Skip to content

Commit

Permalink
[ruby] Nested ruby control flow structures in assign and return #4337 (
Browse files Browse the repository at this point in the history
…#4381)

* Add unless conditional

* Add nested conditionals for return and assignment.

* Scalafmt

* More accurate spans

* Tests for conditional operators.

* scalafmt

* Remove TODO comment

* PR changes
  • Loading branch information
badly-drawn-wizards authored Mar 22, 2024
1 parent 0f8d476 commit 7fe68a2
Show file tree
Hide file tree
Showing 5 changed files with 220 additions and 51 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) {
case node: HashLiteral => astForHashLiteral(node)
case node: Association => astForAssociation(node)
case node: IfExpression => astForIfExpression(node)
case node: UnlessExpression => astForUnlessExpression(node)
case node: RescueExpression => astForRescueExpression(node)
case node: MandatoryParameter => astForMandatoryParameter(node)
case node: SplattingRubyNode => astForSplattingRubyNode(node)
Expand Down Expand Up @@ -257,25 +258,25 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) {
case Some(op) =>
node.rhs match {
case cfNode: ControlFlowExpression =>
def elseAssignNil = Option {
def elseAssignNil(span: TextSpan) = Option {
ElseClause(
StatementList(
SingleAssignment(
node.lhs,
node.op,
StaticLiteral(getBuiltInType(Defines.NilClass))(node.span.spanStart("nil"))
)(node.span.spanStart(s"${node.lhs.span.text} ${node.op} nil")) :: Nil
)(node.span.spanStart(s"${node.lhs.span.text} ${node.op} nil"))
)(node.span.spanStart(s"else\n\t${node.lhs.span.text} ${node.op} nil\nend"))
StaticLiteral(getBuiltInType(Defines.NilClass))(span.spanStart("nil"))
)(span.spanStart(s"${node.lhs.span.text} ${node.op} nil")) :: Nil
)(span.spanStart(s"${node.lhs.span.text} ${node.op} nil"))
)(span.spanStart(s"else\n\t${node.lhs.span.text} ${node.op} nil\nend"))
}

astForExpression(
def transform(e: RubyNode with ControlFlowExpression): RubyNode =
transformLastRubyNodeInControlFlowExpressionBody(
cfNode,
x => reassign(node.lhs, node.op, x),
e,
x => reassign(node.lhs, node.op, x, transform),
elseAssignNil
)
)
astForExpression(transform(cfNode))
case _ =>
val lhsAst = astForExpression(node.lhs)
val rhsAst = astForExpression(node.rhs)
Expand Down Expand Up @@ -303,30 +304,38 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) {
}
}

private def reassign(lhs: RubyNode, op: String, rhs: RubyNode): RubyNode = {
private def reassign(
lhs: RubyNode,
op: String,
rhs: RubyNode,
transform: (RubyNode with ControlFlowExpression) => RubyNode
): RubyNode = {
def stmtListAssigningLastExpression(stmts: List[RubyNode]): List[RubyNode] = stmts match {
case (head: ControlFlowClause) :: Nil => clauseAssigningLastExpression(head) :: Nil
case (head: ControlFlowClause) :: Nil => clauseAssigningLastExpression(head) :: Nil
case (head: ControlFlowExpression) :: Nil => transform(head) :: Nil
case head :: Nil =>
SingleAssignment(lhs, op, head)(lhs.span.spanStart(s"${lhs.span.text} $op ${head.span.text}")) :: Nil
SingleAssignment(lhs, op, head)(rhs.span.spanStart(s"${lhs.span.text} $op ${head.span.text}")) :: Nil
case Nil => List.empty
case head :: tail => head :: stmtListAssigningLastExpression(tail)
}

def clauseAssigningLastExpression(x: RubyNode with ControlFlowClause): RubyNode = x match {
case RescueClause(exceptionClassList, assignment, thenClause) =>
RescueClause(exceptionClassList, assignment, reassign(lhs, op, thenClause))(x.span)
case EnsureClause(thenClause) => EnsureClause(reassign(lhs, op, thenClause))(x.span)
case ElsIfClause(condition, thenClause) => ElsIfClause(condition, reassign(lhs, op, thenClause))(x.span)
case ElseClause(thenClause) => ElseClause(reassign(lhs, op, thenClause))(x.span)
RescueClause(exceptionClassList, assignment, reassign(lhs, op, thenClause, transform))(x.span)
case EnsureClause(thenClause) => EnsureClause(reassign(lhs, op, thenClause, transform))(x.span)
case ElsIfClause(condition, thenClause) =>
ElsIfClause(condition, reassign(lhs, op, thenClause, transform))(x.span)
case ElseClause(thenClause) => ElseClause(reassign(lhs, op, thenClause, transform))(x.span)
case WhenClause(matchExpressions, matchSplatExpression, thenClause) =>
WhenClause(matchExpressions, matchSplatExpression, reassign(lhs, op, thenClause))(x.span)
WhenClause(matchExpressions, matchSplatExpression, reassign(lhs, op, thenClause, transform))(x.span)
}

rhs match {
case StatementList(statements) => StatementList(stmtListAssigningLastExpression(statements))(rhs.span)
case clause: ControlFlowClause => clauseAssigningLastExpression(clause)
case StatementList(statements) => StatementList(stmtListAssigningLastExpression(statements))(rhs.span)
case clause: ControlFlowClause => clauseAssigningLastExpression(clause)
case expr: ControlFlowExpression => transform(expr)
case _ =>
SingleAssignment(lhs, op, rhs)(lhs.span.spanStart(s"${lhs.span.text} $op ${rhs.span.text}"))
SingleAssignment(lhs, op, rhs)(rhs.span.spanStart(s"${lhs.span.text} $op ${rhs.span.text}"))
}
}

Expand Down Expand Up @@ -570,6 +579,11 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) {
foldIfExpression(builder)(node)
}

protected def astForUnlessExpression(node: UnlessExpression): Ast = {
val notConditionAst = UnaryExpression("!", node.condition)(node.condition.span)
astForExpression(IfExpression(notConditionAst, node.trueBranch, List(), node.falseBranch)(node.span))
}

protected def astForRescueExpression(node: RescueExpression): Ast = {
val tryAst = astForStatementList(node.body.asStatementList)
val rescueAsts = node.rescueClauses
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -236,19 +236,21 @@ trait AstForStatementsCreator(implicit withSchemaValidation: ValidationMode) { t
}

private def astsForImplicitReturnStatement(node: RubyNode): Seq[Ast] = {
def elseReturnNil = Option {
def elseReturnNil(span: TextSpan) = Option {
ElseClause(
StatementList(
ReturnExpression(StaticLiteral(getBuiltInType(Defines.NilClass))(node.span.spanStart("nil")) :: Nil)(
node.span.spanStart("return nil")
ReturnExpression(StaticLiteral(getBuiltInType(Defines.NilClass))(span.spanStart("nil")) :: Nil)(
span.spanStart("return nil")
) :: Nil
)(node.span.spanStart("return nil"))
)(node.span.spanStart("else\n\treturn nil\nend"))
)(span.spanStart("return nil"))
)(span.spanStart("else\n\treturn nil\nend"))
}

node match
case expr: ControlFlowExpression =>
astsForStatement(transformLastRubyNodeInControlFlowExpressionBody(expr, returnLastNode, elseReturnNil))
def transform(e: RubyNode with ControlFlowExpression): RubyNode =
transformLastRubyNodeInControlFlowExpressionBody(e, returnLastNode(_, transform), elseReturnNil)
astsForStatement(transform(expr))
case node: MemberCallWithBlock => returnAstForRubyCall(node)
case node: SimpleCallWithBlock => returnAstForRubyCall(node)
case _: (LiteralExpr | BinaryExpression | UnaryExpression | SimpleIdentifier | IndexAccess | Association |
Expand Down Expand Up @@ -309,51 +311,53 @@ trait AstForStatementsCreator(implicit withSchemaValidation: ValidationMode) { t
* @return
* the RubyNode with an explicit expression
*/
private def returnLastNode(x: RubyNode): RubyNode = {
private def returnLastNode(x: RubyNode, transform: (RubyNode with ControlFlowExpression) => RubyNode): RubyNode = {
def statementListReturningLastExpression(stmts: List[RubyNode]): List[RubyNode] = stmts match {
case (head: ControlFlowClause) :: Nil => clauseReturningLastExpression(head) :: Nil
case (head: ReturnExpression) :: Nil => head :: Nil
case head :: Nil => ReturnExpression(head :: Nil)(head.span) :: Nil
case Nil => List.empty
case head :: tail => head :: statementListReturningLastExpression(tail)
case (head: ControlFlowClause) :: Nil => clauseReturningLastExpression(head) :: Nil
case (head: ControlFlowExpression) :: Nil => transform(head) :: Nil
case (head: ReturnExpression) :: Nil => head :: Nil
case head :: Nil => ReturnExpression(head :: Nil)(head.span) :: Nil
case Nil => List.empty
case head :: tail => head :: statementListReturningLastExpression(tail)
}

def clauseReturningLastExpression(x: RubyNode with ControlFlowClause): RubyNode = x match {
case RescueClause(exceptionClassList, assignment, thenClause) =>
RescueClause(exceptionClassList, assignment, returnLastNode(thenClause))(x.span)
case EnsureClause(thenClause) => EnsureClause(returnLastNode(thenClause))(x.span)
case ElsIfClause(condition, thenClause) => ElsIfClause(condition, returnLastNode(thenClause))(x.span)
case ElseClause(thenClause) => ElseClause(returnLastNode(thenClause))(x.span)
RescueClause(exceptionClassList, assignment, returnLastNode(thenClause, transform))(x.span)
case EnsureClause(thenClause) => EnsureClause(returnLastNode(thenClause, transform))(x.span)
case ElsIfClause(condition, thenClause) => ElsIfClause(condition, returnLastNode(thenClause, transform))(x.span)
case ElseClause(thenClause) => ElseClause(returnLastNode(thenClause, transform))(x.span)
case WhenClause(matchExpressions, matchSplatExpression, thenClause) =>
WhenClause(matchExpressions, matchSplatExpression, returnLastNode(thenClause))(x.span)
WhenClause(matchExpressions, matchSplatExpression, returnLastNode(thenClause, transform))(x.span)
}

x match {
case StatementList(statements) => StatementList(statementListReturningLastExpression(statements))(x.span)
case clause: ControlFlowClause => clauseReturningLastExpression(clause)
case _ => ReturnExpression(x :: Nil)(x.span)
case StatementList(statements) => StatementList(statementListReturningLastExpression(statements))(x.span)
case clause: ControlFlowClause => clauseReturningLastExpression(clause)
case node: ControlFlowExpression => transform(node)
case _ => ReturnExpression(x :: Nil)(x.span)
}
}

/** @param node
* \- Control Flow Expression RubyNode
* @param transform
* \- RubyNode => RubyNode function for transformation on last ruby node
* \- RubyNode => RubyNode function for transformation on the clauses of the ControlFlowExpression
* @return
* RubyNode with transform function applied
*/
protected def transformLastRubyNodeInControlFlowExpressionBody(
node: RubyNode with ControlFlowExpression,
transform: RubyNode => RubyNode,
defaultElseBranch: Option[ElseClause]
defaultElseBranch: TextSpan => Option[ElseClause]
): RubyNode = {
node match {
case RescueExpression(body, rescueClauses, elseClause, ensureClause) =>
// Ensure never returns a value, only the main body, rescue & else clauses
RescueExpression(
transform(body),
rescueClauses.map(transform),
elseClause.map(transform).orElse(defaultElseBranch),
elseClause.map(transform).orElse(defaultElseBranch(node.span)),
ensureClause
)(node.span)
case WhileExpression(condition, body) => WhileExpression(condition, transform(body))(node.span)
Expand All @@ -363,18 +367,22 @@ trait AstForStatementsCreator(implicit withSchemaValidation: ValidationMode) { t
condition,
transform(thenClause),
elsifClauses.map(transform),
elseClause.map(transform).orElse(defaultElseBranch)
elseClause.map(transform).orElse(defaultElseBranch(node.span))
)(node.span)
case UnlessExpression(condition, trueBranch, falseBranch) =>
UnlessExpression(condition, transform(trueBranch), falseBranch.map(transform).orElse(defaultElseBranch))(
node.span
)
UnlessExpression(
condition,
transform(trueBranch),
falseBranch.map(transform).orElse(defaultElseBranch(node.span))
)(node.span)
case ForExpression(forVariable, iterableVariable, doBlock) =>
ForExpression(forVariable, iterableVariable, transform(doBlock))(node.span)
case CaseExpression(expression, whenClauses, elseClause) =>
CaseExpression(expression, whenClauses.map(transform), elseClause.map(transform).orElse(defaultElseBranch))(
node.span
)
CaseExpression(
expression,
whenClauses.map(transform),
elseClause.map(transform).orElse(defaultElseBranch(node.span))
)(node.span)
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import io.joern.rubysrc2cpg.testfixtures.RubyCode2CpgFixture
import io.shiftleft.codepropertygraph.generated.Operators
import io.shiftleft.codepropertygraph.generated.nodes.{Identifier, Local}
import io.shiftleft.semanticcpg.language.*
import io.shiftleft.codepropertygraph.generated.nodes.Call

class ConditionalTests extends RubyCode2CpgFixture {

Expand Down Expand Up @@ -39,4 +40,38 @@ class ConditionalTests extends RubyCode2CpgFixture {
}
}

"`f(x ? y : z)` is lowered into conditional operator" in {
val cpg = code("""
|f(x ? y : z)
|""".stripMargin)
inside(cpg.call(Operators.conditional).l) {
case cond :: Nil =>
inside(cond.argument.l) {
case x :: y :: z :: Nil => {
x.code shouldBe "x"
List(y, z).isBlock.astChildren.isIdentifier.code.l shouldBe List("y", "z")
}
case xs => fail(s"Expected exactly three arguments to conditional, got [${xs.code.mkString(",")}]")
}
case xs => fail(s"Expected exactly one conditional, got [${xs.code.mkString(",")}]")
}
}

"`f(unless x then y else z end)` is lowered into conditional operator" in {
val cpg = code("""
|f(unless x then y else z end)
|""".stripMargin)
inside(cpg.call(Operators.conditional).l) {
case cond :: Nil =>
inside(cond.argument.l) {
case x :: y :: z :: Nil => {
List(x).isCall.name(Operators.logicalNot).argument.code.l shouldBe List("x")
List(y, z).isBlock.astChildren.isIdentifier.code.l shouldBe List("y", "z")
}
case xs => fail(s"Expected exactly three arguments to conditional, got [${xs.code.mkString(",")}]")
}
case xs => fail(s"Expected exactly one conditional, got [${xs.code.mkString(",")}]")
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,47 @@ class MethodReturnTests extends RubyCode2CpgFixture(withDataFlow = true) {
}
}

"implicit return of nested control flow" in {
val cpg = code("""
| def f
| if true
| if true
| 1
| else
| 2
| end
| else
| if true
| 3
| else
| 4
| end
| end
| end
|""".stripMargin)

inside(cpg.method.name("f").l) {
case f :: Nil =>
inside(cpg.methodReturn.toReturn.l) {
case return1 :: return2 :: return3 :: return4 :: Nil =>
return1.code shouldBe "1"
return1.lineNumber shouldBe Some(5)

return2.code shouldBe "2"
return2.lineNumber shouldBe Some(7)

return3.code shouldBe "3"
return3.lineNumber shouldBe Some(11)

return4.code shouldBe "4"
return4.lineNumber shouldBe Some(13)

case xs => fail(s"Expected 4 returns, instead got [${xs.code.mkString(",")}]")
}
case xs => fail(s"Expected exactly one method with the name `f`, instead got [${xs.code.mkString(",")}]")
}
}

"implicit RETURN node for ternary expression" in {
val cpg = code("""
|def f(x) = x ? 20 : 40
Expand Down
Loading

0 comments on commit 7fe68a2

Please sign in to comment.