Skip to content

Commit

Permalink
[ruby] when Statements (Switch Statements) #3926 (#4029)
Browse files Browse the repository at this point in the history
This draft translates ruby `case` expressions into `if-elif-...-else-end` chains. 

The expression being matched against is assigned to a temporary variable `<tmp-#>`, where `#` is a number. I was not sure of the naming convention, but I wanted to be sure the variable is always fresh.

A single when can contain a list of expressions to match against, which I translate into an `or-expression` if there is more than one. When matching against an expression `mExpr`, it is turned into a condition with `mExpr.=== <tmp-#>`.

This list of expressions can contain a splat at the end, which we aren't handling yet.

It may also remain to special-case the translation into a switch AST when all match expressions are literals.

TextSpans for the generated intermediate ast are not sane and will need to be considered more carefully.

---------

Co-authored-by: David Baker Effendi <[email protected]>
  • Loading branch information
badly-drawn-wizards and DavidBakerEffendi authored Jan 12, 2024
1 parent 438a1c9 commit 21558c7
Show file tree
Hide file tree
Showing 8 changed files with 202 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ class AstCreator(protected val filename: String, parser: ResourceManagedParser,
with AstForExpressionsCreator
with AstForFunctionsCreator
with AstForTypesCreator
with FreshVariableCreator
with AstNodeBuilder[RubyNode, AstCreator] {

protected val logger: Logger = LoggerFactory.getLogger(getClass)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ import io.joern.x2cpg.{Ast, ValidationMode}
import io.shiftleft.codepropertygraph.generated.nodes.{NewBlock, NewLiteral, NewControlStructure}
import io.shiftleft.codepropertygraph.generated.{DispatchTypes, Operators, ControlStructureTypes}
import io.shiftleft.semanticcpg.language.NodeOrdering.nodeList
import scala.collection.mutable
import io.joern.rubysrc2cpg.parser.RubyParser

trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) { this: AstCreator =>

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ trait AstForStatementsCreator(implicit withSchemaValidation: ValidationMode) { t
case node: UntilExpression => astForUntilStatement(node) :: Nil
case node: IfExpression => astForIfStatement(node) :: Nil
case node: UnlessExpression => astForUnlessStatement(node) :: Nil
case node: CaseExpression => astsForCaseExpression(node)
case node: StatementList => astForStatementList(node) :: Nil
case node: SimpleCallWithBlock => astForSimpleCallWithBlock(node) :: Nil
case node: MemberCallWithBlock => astForMemberCallWithBlock(node) :: Nil
Expand Down Expand Up @@ -100,6 +101,51 @@ trait AstForStatementsCreator(implicit withSchemaValidation: ValidationMode) { t
controlStructureAst(ifNode, Some(notConditionAst), thenAst :: elseAsts)
}

/** Lowers a Ruby `case` expression into a chain of nested `if`/`else` expressions and returns the ASTs produced for
  * the generated statements.
  *
  * If the `case` has a target expression, it is first assigned to a fresh temporary identifier and every match
  * expression `m` is turned into the condition `m.===(<tmp>)`. Without a target, the match expressions are used as
  * conditions directly. Several conditions of the same `when` are joined with `||`.
  */
protected def astsForCaseExpression(node: CaseExpression): Seq[Ast] = {

  // Conditions contributed by one `when`: one per match expression, followed by the trailing splat (if any).
  def conditionsOf(whenClause: WhenClause, target: Option[SimpleIdentifier]): List[RubyNode] = {
    val matchConditions = whenClause.matchExpressions.map { mExpr =>
      // With a case target the comparison is `mExpr.===(target)`; otherwise the expression is the condition itself.
      target.fold(mExpr)(t => MemberCall(mExpr, ".", "===", List(t))(mExpr.span))
    }
    // A splat may appear as the last match expression; it is currently parsed as unknown.
    val splatCondition = whenClause.matchSplatExpression.map {
      case unknown: Unknown => unknown
      case other =>
        logger.warn("Splatting not implemented for `when` in ruby `case`")
        Unknown()(other.span)
    }
    matchConditions ++ splatCondition
  }

  // Builds the nested if-else chain, innermost (`else`) first, by folding the `when` clauses from the right.
  def buildIfElseChain(target: Option[SimpleIdentifier]): List[RubyNode] = {
    val fallback: Option[RubyNode] = node.elseClause.map(_.asInstanceOf[ElseClause].thenClause)
    val clauses                    = node.whenClauses.map(_.asInstanceOf[WhenClause])

    val chain = clauses.foldRight[Option[RubyNode]](fallback) { (whenClause, rest) =>
      val conditions = conditionsOf(whenClause, target)
      // There is always at least one match expression or a splat, so `init`/`last` are defined here.
      val condition = conditions.init.foldRight(conditions.last) { (cond, acc) =>
        BinaryExpression(cond, "||", acc)(whenClause.span)
      }
      val elseBranch = rest.map(els => ElseClause(els.asStatementList)(els.span))
      Some(IfExpression(condition, whenClause.thenClause.asStatementList, List(), elseBranch)(node.span))
    }
    chain.toList
  }

  val lowered: StatementList = node.expression match {
    case Some(target) =>
      // Bind the case target to a fresh temporary so it is evaluated once and can be referenced by every condition.
      val tmp        = SimpleIdentifier(None)(target.span.spanStart(freshName))
      val assignment = SingleAssignment(tmp, "=", target)(target.span)
      StatementList(assignment :: buildIfElseChain(Some(tmp)))(node.span)
    case None =>
      StatementList(buildIfElseChain(None))(node.span)
  }
  astsForStatement(lowered)
}

protected def astForStatementList(node: StatementList): Ast = {
val block = blockNode(node)
scope.pushNewScope(block)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ trait AstForTypesCreator(implicit withSchemaValidation: ValidationMode) { this:
node.body.asInstanceOf[StatementList] // for now (bodyStatement is a superset of stmtList)
val classBodyAsts = classBody.statements.flatMap(astsForStatement) match {
case bodyAsts if shouldGenerateDefaultConstructorStack.head =>
val bodyStart = classBody.span.spanStart
val bodyStart = classBody.span.spanStart()
val initBody = StatementList(List())(bodyStart)
val methodDecl = astForMethodDeclaration(MethodDeclaration("<init>", List(), initBody)(bodyStart))
methodDecl :: bodyAsts
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
package io.joern.rubysrc2cpg.astcreation

import io.joern.rubysrc2cpg.astcreation.AstCreator
import io.joern.rubysrc2cpg.astcreation.RubyIntermediateAst.*

/** Mixin for `AstCreator` that hands out temporary variable names of the form `<tmp-N>`, where `N` is the value of
  * a counter that is incremented on every request.
  */
trait FreshVariableCreator { this: AstCreator =>

  // This is in a single-threaded context.
  var tmpCounter: Int = 0

  /** Renders the temporary name for the given counter value. */
  private def tmpTemplate(id: Int): String = s"<tmp-${id}>"

  /** Returns the name for the current counter value and then advances the counter by one. */
  protected def freshName: String = {
    val current = tmpCounter
    tmpCounter = current + 1
    tmpTemplate(current)
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ object RubyIntermediateAst {
columnEnd: Option[Integer],
text: String
) {
def spanStart: TextSpan = TextSpan(line, column, line, column, "")
/** Builds a new `TextSpan` that passes this span's `line` and `column` for both position pairs (presumably start and
  * end, i.e. a zero-width span at the start of this one -- confirm against `TextSpan`'s parameter list).
  *
  * @param newText
  *   the text carried by the new span; defaults to the empty string.
  */
def spanStart(newText: String = ""): TextSpan = TextSpan(line, column, line, column, newText)
}

sealed class RubyNode(val span: TextSpan) {
Expand All @@ -28,6 +28,7 @@ object RubyIntermediateAst {
/** Views `node` as a `StatementList`: an existing list is returned unchanged, any other node is wrapped into a
  * single-element list that reuses the node's span.
  */
def asStatementList = node match {
  case existing: StatementList => existing
  case _                       => StatementList(List(node))(node.span)
}

}

/** Placeholder node that carries only a text span and no child nodes. */
final case class Unknown()(span: TextSpan) extends RubyNode(span)
Expand Down Expand Up @@ -119,6 +120,20 @@ object RubyIntermediateAst {
span: TextSpan
) extends RubyNode(span)

/** Intermediate AST node for a `case` expression.
  *
  * @param expression
  *   the expression being matched against, or `None` when the `case` has no target.
  * @param whenClauses
  *   the `when` branches, in source order.
  * @param elseClause
  *   the optional `else` branch.
  */
final case class CaseExpression(
expression: Option[RubyNode],
whenClauses: List[RubyNode],
elseClause: Option[RubyNode]
)(span: TextSpan)
extends RubyNode(span)

/** Intermediate AST node for a single `when` branch of a `case` expression.
  *
  * @param matchExpressions
  *   the match expressions listed after `when`.
  * @param matchSplatExpression
  *   the optional splat argument of the `when`, kept separately from the match expressions.
  * @param thenClause
  *   the body of the branch.
  */
final case class WhenClause(
matchExpressions: List[RubyNode],
matchSplatExpression: Option[RubyNode],
thenClause: RubyNode
)(span: TextSpan)
extends RubyNode(span)

/** Intermediate AST node wrapping a list of `expressions`; presumably models a Ruby `return` (inferred from the name
  * only -- confirm against its producer).
  */
final case class ReturnExpression(expressions: List[RubyNode])(span: TextSpan) extends RubyNode(span)

/** Represents an unqualified identifier e.g. `X`, `x`, `@x`, `@@x`, `$x`, `$<`, etc. */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,8 @@ import io.joern.rubysrc2cpg.passes.Defines.getBuiltInType
import org.antlr.v4.runtime.tree.{ErrorNode, ParseTree, TerminalNode}

import scala.jdk.CollectionConverters.*
import io.joern.rubysrc2cpg.parser.RubyParser.RescueClauseContext
import io.joern.rubysrc2cpg.parser.RubyParser.EnsureClauseContext
import io.joern.rubysrc2cpg.parser.RubyParser.ExceptionClassListContext
import org.antlr.v4.runtime.tree.RuleNode
import io.joern.rubysrc2cpg.parser.RubyParser.SplattingArgumentContext

/** Converts an ANTLR Ruby Parse Tree into the intermediate Ruby AST.
*/
Expand Down Expand Up @@ -555,22 +553,47 @@ class RubyNodeCreator extends RubyParserBaseVisitor[RubyNode] {
}
}

override def visitExceptionClassList(ctx: ExceptionClassListContext): RubyNode = {
/** Exception class lists are not translated yet; a placeholder covering the same text span is returned. */
override def visitExceptionClassList(ctx: RubyParser.ExceptionClassListContext): RubyNode = {
  // Requires implementing multiple rhs with splatting
  val span = ctx.toTextSpan
  Unknown()(span)
}

override def visitRescueClause(ctx: RescueClauseContext): RubyNode = {
/** Converts a `rescue` clause: visits the optional exception class list, the optional exception variable assignment
  * and the body, in that order, and bundles the results into a `RescueClause`.
  */
override def visitRescueClause(ctx: RubyParser.RescueClauseContext): RubyNode = {
  val classList         = Option(ctx.exceptionClassList).map(visit)
  val exceptionVariable = Option(ctx.exceptionVariableAssignment).map(visit)
  val body              = visit(ctx.thenClause)
  RescueClause(classList, exceptionVariable, body)(ctx.toTextSpan)
}

override def visitEnsureClause(ctx: EnsureClauseContext): RubyNode = {
/** Converts an `ensure` clause by visiting its compound statement and wrapping the result in an `EnsureClause`. */
override def visitEnsureClause(ctx: RubyParser.EnsureClauseContext): RubyNode = {
  val body = visit(ctx.compoundStatement())
  EnsureClause(body)(ctx.toTextSpan)
}

/** Converts a `case <expr> ... end` parse tree into a `CaseExpression` whose target is the visited expression. */
override def visitCaseWithExpression(ctx: RubyParser.CaseWithExpressionContext): RubyNode = {
  val target = Option(ctx.commandOrPrimaryValue()).map(visit)
  // A missing (null) clause list is treated the same as an empty one.
  val whens    = Option(ctx.whenClause().asScala).map(_.map(visit).toList).getOrElse(List())
  val fallback = Option(ctx.elseClause()).map(visit)
  CaseExpression(target, whens, fallback)(ctx.toTextSpan)
}

/** Converts a target-less `case ... end` parse tree into a `CaseExpression` without a target expression. */
override def visitCaseWithoutExpression(ctx: RubyParser.CaseWithoutExpressionContext): RubyNode = {
  // A missing (null) clause list is treated the same as an empty one.
  val whens    = Option(ctx.whenClause().asScala).map(_.map(visit).toList).getOrElse(List())
  val fallback = Option(ctx.elseClause()).map(visit)
  CaseExpression(None, whens, fallback)(ctx.toTextSpan)
}

/** Converts a `when` clause: the operator expressions of its argument become the match expressions, an optional
  * splatting argument is kept separately, and the `then` clause becomes the body.
  */
override def visitWhenClause(ctx: RubyParser.WhenClauseContext): RubyNode = {
  val arguments = ctx.whenArgument()
  // The expression list is absent when the `when` consists of a splat only.
  val matchExprs = Option(arguments.operatorExpressionList()).toList
    .flatMap(_.operatorExpression().asScala)
    .map(visit)
  val splat = Option(arguments.splattingArgument()).map(visit)
  val body  = visit(ctx.thenClause())
  WhenClause(matchExprs, splat, body)(ctx.toTextSpan)
}

/** Splatting arguments are not translated yet; a placeholder covering the same text span is returned. */
override def visitSplattingArgument(ctx: SplattingArgumentContext): RubyNode = {
  val span = ctx.toTextSpan
  Unknown()(span)
}

override def visitAssociationKey(ctx: RubyParser.AssociationKeyContext): RubyNode = {
if (Option(ctx.operatorExpression()).isDefined) {
visit(ctx.operatorExpression())
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
package io.joern.rubysrc2cpg.querying

import io.joern.rubysrc2cpg.testfixtures.RubyCode2CpgFixture
import io.shiftleft.semanticcpg.language.*
import io.shiftleft.codepropertygraph.generated.nodes.*
import io.shiftleft.codepropertygraph.generated.Operators

/** Checks that `case` expressions end up in the CPG as a chain of nested `if` control structures, both with and
  * without a target expression.
  */
class CaseTests extends RubyCode2CpgFixture {
"`case x ... end` should be represented with if-else chain and multiple match expressions should be or-ed together" in {
val caseCode = """
|case 0
| when 0
| 0
| when 1,2 then 1
| when 3, *[4,5] then 2
| when *[6] then 3
| else 4
|end
|""".stripMargin
val cpg = code(caseCode)

// The lowered `case` is expected to be exactly one block below the program method's block.
val block @ List(_) = cpg.method(":program").block.astChildren.isBlock.l

// The case target is assigned to the first fresh temporary.
val List(assign) = block.astChildren.assignment.l;
val List(lhs, rhs) = assign.argument.l

List(lhs).isCall.name.l shouldBe List("<tmp-0>")
List(rhs).isLiteral.code.l shouldBe List("0")

// Walk the chain: starting at the head `if`, repeatedly descend into the control structure found below the
// child of order 3, emitting every `if` on the way. One `if` is expected per `when`.
val headIf @ List(_) = block.astChildren.isControlStructure.l
val ifStmts @ List(_, _, _, _) = headIf.repeat(_.astChildren.order(3).astChildren.isControlStructure)(_.emit).l;
// For every `if`, flatten its (possibly nested) logical-or condition into the list of leaf conditions.
val conds: List[List[String]] = ifStmts.condition.map { cond =>
val orConds = List(cond)
.repeat(_.isCall.where(_.name(Operators.logicalOr)).argument)(
_.emit(_.whereNot(_.isCall.name(Operators.logicalOr)))
)
.l
orConds.map {
case u: Unknown => "unknown"
case mExpr =>
// Every known leaf must be a `===` call whose second argument is the temporary.
val call @ List(_) = List(mExpr).isCall.l
call.methodFullName.l shouldBe List("===")
val List(lhs, rhs) = call.argument.l
rhs.code shouldBe "<tmp-0>"
val List(code) = List(lhs).isCall.argument(1).code.l
code
}.l
}.l

conds shouldBe List(List("0"), List("1", "2"), List("3", "unknown"), List("unknown"))
// Results of the `when` bodies (children of order 2) followed by the result of the final `else` (order 3).
val matchResults = ifStmts.astChildren.order(2).astChildren ++ ifStmts.last.astChildren.order(3).astChildren
matchResults.code.l shouldBe List("0", "1", "2", "3", "4")

// It's not ideal, but we choose the smallest containing text span that we have easily accessible
// as we don't have a good way to immutably update RubyNode text spans.
ifStmts.code.l should contain only caseCode.trim
ifStmts.condition.map(_.code.trim).l shouldBe List("0", "when 1,2 then 1", "when 3, *[4,5] then 2", "*[6]")
}

"`case ... end` without expression" in {
val cpg = code("""
|case
| when false, true then 0
| when true then 1
| when false, *[false,false] then 2
| when *[false, true] then 3
|end
|""".stripMargin)

val block @ List(_) = cpg.method(":program").block.astChildren.isBlock.l

// Same chain walk as in the test above; without a target there is no temporary assignment to check.
val headIf @ List(_) = block.astChildren.isControlStructure.l
val ifStmts @ List(_, _, _, _) = headIf.repeat(_.astChildren.order(3).astChildren.isControlStructure)(_.emit).l;
val conds: List[List[String]] = ifStmts.condition.map { cond =>
val orConds = List(cond)
.repeat(_.isCall.where(_.name(Operators.logicalOr)).argument)(
_.emit(_.whereNot(_.isCall.name(Operators.logicalOr)))
)
.l
// Without a target the match expressions themselves are the leaf conditions.
orConds.map {
case u: Unknown => "unknown"
case c => c.code
}
}.l
conds shouldBe List(List("false", "true"), List("true"), List("false", "unknown"), List("unknown"))

val matchResults = ifStmts.astChildren.order(2).astChildren.l
matchResults.code.l shouldBe List("0", "1", "2", "3")

// There is no `else`, so the last `if` has no child of order 3.
ifStmts.last.astChildren.order(3).l shouldBe List()
}
}

0 comments on commit 21558c7

Please sign in to comment.