Skip to content

Commit

Permalink
Add basic begin-rescue-else-ensure-end
Browse files Browse the repository at this point in the history
  • Loading branch information
badly-drawn-wizards committed Dec 20, 2023
1 parent 360b98e commit ec5dab7
Show file tree
Hide file tree
Showing 5 changed files with 96 additions and 33 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@ import io.joern.rubysrc2cpg.astcreation.RubyIntermediateAst.*
import io.joern.rubysrc2cpg.passes.Defines
import io.joern.rubysrc2cpg.passes.Defines.{RubyOperators, getBuiltInType}
import io.joern.x2cpg.{Ast, ValidationMode}
import io.shiftleft.codepropertygraph.generated.nodes.{NewBlock, NewLiteral}
import io.shiftleft.codepropertygraph.generated.{DispatchTypes, Operators}
import io.shiftleft.codepropertygraph.generated.nodes.{NewBlock, NewLiteral, NewControlStructure}
import io.shiftleft.codepropertygraph.generated.{DispatchTypes, Operators, ControlStructureTypes}
import io.shiftleft.semanticcpg.language.NodeOrdering.nodeList

trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) { this: AstCreator =>

Expand All @@ -27,12 +28,18 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) {
case node: HashLiteral => astForHashLiteral(node)
case node: Association => astForAssociation(node)
case node: IfExpression => astForIfExpression(node)
case node: RescueExpression => astForRescueExpression(node)
case _ => astForUnknown(node)

protected def astForStaticLiteral(node: StaticLiteral): Ast = {
Ast(literalNode(node, code(node), node.typeFullName))
}


// Helper for nil literals to put in empty clauses
protected def astForNilLiteral: Ast = Ast(NewLiteral().code("nil").typeFullName(getBuiltInType(Defines.NilClass)))
protected def astForNilBlock: Ast = blockAst(NewBlock(), List(astForNilLiteral))

protected def astForDynamicLiteral(node: DynamicLiteral): Ast = {
val fmtValueAsts = node.expressions.map {
case stmtList: StatementList if stmtList.size == 1 =>
Expand Down Expand Up @@ -230,8 +237,7 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) {
// We want to make sure there's always an «else» clause in a ternary operator.
// The default value is a `nil` literal.
val elseAsts_ = if (elseAsts.isEmpty) {
val nilLiteral = Ast(NewLiteral().code("nil").typeFullName(getBuiltInType(Defines.NilClass)))
List(blockAst(NewBlock(), List(nilLiteral)))
List(astForNilBlock)
} else {
elseAsts
}
Expand All @@ -242,6 +248,33 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) {
foldIfExpression(builder)(node)
}

protected def astForRescueExpression(node: RescueExpression): Ast = {
val tryAst = astForStatementList(node.body.asStatementList)
val rescueAsts = node.rescueClauses
.map {
case x: RescueClause =>
// TODO: add exception assignment
astForStatementList(x.thenClause.asStatementList)
case x => astForUnknown(x)
}
val elseAst = node.elseClause.map {
case x: ElseClause => astForStatementList(x.thenClause.asStatementList)
case x => astForUnknown(x)
}
val ensureAst = node.ensureClause.map {
case x: EnsureClause => astForStatementList(x.thenClause.asStatementList)
case x => astForUnknown(x)
}
tryCatchAst(
NewControlStructure()
.controlStructureType(ControlStructureTypes.TRY)
.code(code(node)),
tryAst,
rescueAsts ++ elseAst.toSeq,
ensureAst
)
}

protected def astForUnknown(node: RubyNode): Ast = {
val className = node.getClass.getSimpleName
val text = code(node)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,11 +55,7 @@ trait AstForStatementsCreator(implicit withSchemaValidation: ValidationMode) { t
builder(node, conditionAst, thenAst, elseAsts)
}

private def astForThenClause(node: RubyNode): Ast = {
node match
case stmtList: StatementList => astForStatementList(stmtList)
case _ => astForStatementList(StatementList(List(node))(node.span))
}
private def astForThenClause(node: RubyNode): Ast = astForStatementList(node.asStatementList)

private def astsForElseClauses(
elsIfClauses: List[RubyNode],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,20 +76,13 @@ object RubyIntermediateAst {
span: TextSpan
) extends RubyNode(span)

final case class RescueExpression(
body: RubyNode,
rescueClauses: List[RubyNode],
elseClause: Option[RubyNode],
ensureClause: Option[RubyNode]
)(span: TextSpan)
extends RubyNode(span)
final case class RescueExpression( body: RubyNode, rescueClauses: List[RubyNode], elseClause: Option[RubyNode], ensureClause: Option[RubyNode])(
span: TextSpan
) extends RubyNode(span)

final case class RescueClause(
exceptionClassList: Option[RubyNode],
assignment: Option[RubyNode],
thenClause: RubyNode
)(span: TextSpan)
extends RubyNode(span)
final case class RescueClause( exceptionClassList: Option[RubyNode], assignment: Option[RubyNode], thenClause: RubyNode)(
span: TextSpan
) extends RubyNode(span)

final case class EnsureClause(thenClause: RubyNode)(span: TextSpan) extends RubyNode(span)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,14 @@ import scala.jdk.CollectionConverters.*
import io.joern.rubysrc2cpg.parser.RubyParser.RescueClauseContext
import io.joern.rubysrc2cpg.parser.RubyParser.EnsureClauseContext
import io.joern.rubysrc2cpg.parser.RubyParser.ExceptionClassListContext
import org.antlr.v4.runtime.tree.RuleNode

/** Converts an ANTLR Ruby Parse Tree into the intermediate Ruby AST.
*/
class RubyNodeCreator extends RubyParserBaseVisitor[RubyNode] {

override def defaultResult(): RubyNode = Unknown()(TextSpan(None, None, None, None, ""))
override protected def shouldVisitNextChild(node: RuleNode, currentResult: RubyNode): Boolean = currentResult.isInstanceOf[Unknown]

override def visit(tree: ParseTree): RubyNode = {
Option(tree).map(super.visit).getOrElse(defaultResult())
Expand Down Expand Up @@ -540,14 +542,19 @@ class RubyNodeCreator extends RubyParserBaseVisitor[RubyNode] {
}

override def visitBodyStatement(ctx: RubyParser.BodyStatementContext): RubyNode = {
if (ctx.rescueClause().isEmpty && Option(ctx.elseClause()).isEmpty && Option(ctx.ensureClause()).isEmpty) {
val body = visit(ctx.compoundStatement())
val rescueClauses = Option(ctx.rescueClause.asScala).fold(List())(_.map(visit).toList)
val elseClause = Option(ctx.elseClause).map(visit)
val ensureClause = Option(ctx.ensureClause).map(visit)

if (rescueClauses.isEmpty && elseClause.isEmpty && ensureClause.isEmpty) {
visit(ctx.compoundStatement())
} else {
RescueExpression(
visit(ctx.compoundStatement()),
Option(ctx.rescueClause.asScala).fold(List())(_.map(visit).toList),
Option(ctx.elseClause).map(visit),
Option(ctx.ensureClause).map(visit)
body,
rescueClauses,
elseClause,
ensureClause
)(ctx.toTextSpan)
}
}
Expand All @@ -558,10 +565,13 @@ class RubyNodeCreator extends RubyParserBaseVisitor[RubyNode] {
}

override def visitRescueClause(ctx: RescueClauseContext): RubyNode = {
val exceptionClassList = Option(ctx.exceptionClassList).map(visit)
val elseClause = Option(ctx.exceptionVariableAssignment).map(visit)
val thenClause = visit(ctx.thenClause)
RescueClause(
Option(ctx.exceptionClassList).map(visit),
Option(ctx.exceptionVariableAssignment).map(visit),
visit(ctx.thenClause)
exceptionClassList,
elseClause,
thenClause
)(ctx.toTextSpan)
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
package io.joern.rubysrc2cpg.querying

import io.joern.rubysrc2cpg.testfixtures.RubyCode2CpgFixture
import io.shiftleft.codepropertygraph.generated.Operators
import io.shiftleft.codepropertygraph.generated.{ControlStructureTypes,Operators}
import io.shiftleft.codepropertygraph.generated.nodes.Call
import io.shiftleft.semanticcpg.language.*

Expand Down Expand Up @@ -254,4 +252,37 @@ class ControlStructureTests extends RubyCode2CpgFixture {
putsHi.lineNumber shouldBe Some(2)
}

"`begin ... rescue ... end is represented by a `TRY` CONTROL_STRUCTURE node" in {
val cpg = code("""
|begin
| 1
|rescue
| 2
|rescue
| 3
|else
| 4
|ensure
| 5
|end
|""".stripMargin)

val List(rescueNode) = cpg.tryBlock.l
rescueNode.controlStructureType shouldBe ControlStructureTypes.TRY
val List(body, rescueBody1, rescueBody2, elseBody, ensureBody) = rescueNode.astChildren.l
body.ast.isLiteral.code.l shouldBe List("1")
body.order shouldBe 1

rescueBody1.ast.isLiteral.code.l shouldBe List("2")
rescueBody1.order shouldBe 2

rescueBody2.ast.isLiteral.code.l shouldBe List("3")
rescueBody2.order shouldBe 2

elseBody.ast.isLiteral.code.l shouldBe List("4")
elseBody.order shouldBe 2

ensureBody.ast.isLiteral.code.l shouldBe List("5")
ensureBody.order shouldBe 3
}
}

0 comments on commit ec5dab7

Please sign in to comment.