Skip to content

Commit

Permalink
Add tests and fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
badly-drawn-wizards committed Jan 11, 2024
1 parent 74301e5 commit 58efe62
Show file tree
Hide file tree
Showing 4 changed files with 77 additions and 27 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -111,17 +111,22 @@ trait AstForStatementsCreator(implicit withSchemaValidation: ValidationMode) { t
// There may be a splat as the last match expression, which is currently parsed as unknown
// A single match expression is compared using `.===` to case target expression if it is present
// otherwise it is treated as a conditional.
val conditions = whenClause.matchExpressions.map {
case u: Unknown => u
case mExpr => expr.map(e => MemberCall(e, ".", "===", List(mExpr))(mExpr.span)).getOrElse(mExpr)
}
// There is always at least one match expression
val conditions = whenClause.matchExpressions.map { mExpr =>
expr.map(e => MemberCall(mExpr, ".", "===", List(e))(mExpr.span)).getOrElse(mExpr)
} ++ (whenClause.matchSplatExpression.iterator.flatMap {
case u: Unknown => List(u)
case e =>
logger.warn("Splatting not implemented for `when` in ruby `case`")
List(Unknown()(e.span))
})
// There is always at least one match expression or a splat
// a splat will become an unknown in condition at the end
val condition = conditions.init.foldRight(conditions.last) { (cond, condAcc) =>
BinaryExpression(cond, "||", condAcc)(whenClause.span)
}
val conditional = IfExpression(
condition,
whenClause.thenClause.map(_.asStatementList).getOrElse(StatementList(List())(whenClause.span)),
whenClause.thenClause.asStatementList,
List(),
restClause.map { els => ElseClause(els.asStatementList)(els.span) }
)(node.span)
Expand All @@ -133,7 +138,7 @@ trait AstForStatementsCreator(implicit withSchemaValidation: ValidationMode) { t
.map { e =>
val tmp = SimpleIdentifier(None)(e.span.spanStart(freshName))
StatementList(
List(SingleAssignment(e, "=", tmp)(e.span)) ++
List(SingleAssignment(tmp, "=", e)(e.span)) ++
goCase(Some(tmp))
)(node.span)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ object RubyIntermediateAst {
def asStatementList = node match
case stmtList: StatementList => stmtList
case _ => StatementList(List(node))(node.span)

}

final case class Unknown()(span: TextSpan) extends RubyNode(span)
Expand Down Expand Up @@ -126,7 +127,11 @@ object RubyIntermediateAst {
)(span: TextSpan)
extends RubyNode(span)

final case class WhenClause(matchExpressions: List[RubyNode], thenClause: Option[RubyNode])(span: TextSpan)
final case class WhenClause(
matchExpressions: List[RubyNode],
matchSplatExpression: Option[RubyNode],
thenClause: RubyNode
)(span: TextSpan)
extends RubyNode(span)

final case class ReturnExpression(expressions: List[RubyNode])(span: TextSpan) extends RubyNode(span)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import org.antlr.v4.runtime.tree.{ErrorNode, ParseTree, TerminalNode}

import scala.jdk.CollectionConverters.*
import org.antlr.v4.runtime.tree.RuleNode
import io.joern.rubysrc2cpg.parser.RubyParser.SplattingArgumentContext

/** Converts an ANTLR Ruby Parse Tree into the intermediate Ruby AST.
*/
Expand Down Expand Up @@ -583,16 +584,16 @@ class RubyNodeCreator extends RubyParserBaseVisitor[RubyNode] {
}

override def visitWhenClause(ctx: RubyParser.WhenClauseContext): RubyNode = {
val matchArgs = Option(ctx.whenArgument()).iterator
.flatMap(arg =>
Option(arg.operatorExpressionList()).iterator.flatMap(_.operatorExpression().asScala).map(visit) ++
Option(arg.splattingArgument()).map(visit).iterator
)
.toList
val thenClause = Option(ctx.thenClause()).map(visit)
WhenClause(matchArgs, thenClause)(ctx.toTextSpan)
val whenArgs = ctx.whenArgument()
val matchArgs =
Option(whenArgs.operatorExpressionList()).iterator.flatMap(_.operatorExpression().asScala).map(visit).toList
val matchSplatArg = Option(whenArgs.splattingArgument()).map(visit)
val thenClause = visit(ctx.thenClause())
WhenClause(matchArgs, matchSplatArg, thenClause)(ctx.toTextSpan)
}

override def visitSplattingArgument(ctx: SplattingArgumentContext): RubyNode = Unknown()(ctx.toTextSpan)

override def visitAssociationKey(ctx: RubyParser.AssociationKeyContext): RubyNode = {
if (Option(ctx.operatorExpression()).isDefined) {
visit(ctx.operatorExpression())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,36 +2,75 @@ package io.joern.rubysrc2cpg.querying

import io.joern.rubysrc2cpg.testfixtures.RubyCode2CpgFixture
import io.shiftleft.semanticcpg.language.*
import io.shiftleft.codepropertygraph.generated.nodes.*

class CaseTests extends RubyCode2CpgFixture {
"`case x ... end` should be represented with" in {
"`case x ... end` should be represented with if-else chain" in {
val cpg = code("""
|case 0
| when 0
| 0
| when 1 then 1
| when 1,2 then 1
| when 3, *[4,5] then 2
| when *[6] then 3
| else 4
|end
|""".stripMargin)

cpg.method(":program").dotAst.foreach(println)
val block@List(_) = cpg.method(":program").block.astChildren.isBlock.l

val List(assign) = block.astChildren.assignment.l;
val List(lhs, rhs) = assign.argument.l

List(lhs).isCall.name.l shouldBe List("<tmp-0>")
List(rhs).isLiteral.code.l shouldBe List("0")

val headIf@List(_) = block.astChildren.isControlStructure.l
val ifStmts@List(_, _, _, _) = headIf.repeat(_.astChildren.order(3).astChildren.isControlStructure)(_.emit).l;
val conds: List[List[String]] = ifStmts.condition.map { cond =>
val orConds = List(cond).repeat(_.isCall.where(_.name("<operator>.logicalOr")).argument)(_.emit(_.whereNot(_.isCall.name("<operator>.logicalOr")))).l
orConds.map {
case u: Unknown => "unknown"
case mExpr =>
val call@List(_) = List(mExpr).isCall.l
call.methodFullName.l shouldBe List("===")
val List(lhs, rhs) = call.argument.l
rhs.code shouldBe "<tmp-0>"
val List(code) = List(lhs).isCall.argument(1).code.l
code
}.l
}.l
conds shouldBe List(List("0"), List("1", "2"), List("3", "unknown"), List("unknown"))
val matchResults = ifStmts.astChildren.order(2).astChildren ++ ifStmts.last.astChildren.order(3).astChildren
matchResults.code.l shouldBe List("0", "1", "2", "3", "4")
}

"`case ... end` without expression" in {
val cpg = code("""
|case
| when false, true then 0
| when true then 1
| when false, *[false,false] then 2
| when *[false, true] then 3
|end
|""".stripMargin)

}
val block@List(_) = cpg.method(":program").block.astChildren.isBlock.l

"case with else" in {
val cpg = code("""
|case 0
| when 1 then 1
| else 0
|end
|""".stripMargin)
val headIf@List(_) = block.astChildren.isControlStructure.l
val ifStmts@List(_, _, _, _) = headIf.repeat(_.astChildren.order(3).astChildren.isControlStructure)(_.emit).l;
val conds: List[List[String]] = ifStmts.condition.map { cond =>
val orConds = List(cond).repeat(_.isCall.where(_.name("<operator>.logicalOr")).argument)(_.emit(_.whereNot(_.isCall.name("<operator>.logicalOr")))).l
orConds.map {
case u: Unknown => "unknown"
case c => c.code
}
}.l
conds shouldBe List(List("false", "true"), List("true"), List("false", "unknown"), List("unknown"))

val matchResults = ifStmts.astChildren.order(2).astChildren.l
matchResults.code.l shouldBe List("0", "1", "2", "3")

ifStmts.last.astChildren.order(3).l shouldBe List()
}
}

0 comments on commit 58efe62

Please sign in to comment.