From ec5dab738904389a40414430eee9809e98e93818 Mon Sep 17 00:00:00 2001
From: Reuben Steenekamp
Date: Wed, 20 Dec 2023 19:13:19 +0200
Subject: [PATCH] Add basic begin-rescue-else-ensure-end

---
 .../AstForExpressionsCreator.scala            | 44 +++++++++++++++++--
 .../astcreation/AstForStatementsCreator.scala |  6 +--
 .../astcreation/RubyIntermediateAst.scala     | 19 +++-----
 .../rubysrc2cpg/parser/RubyNodeCreator.scala  | 30 +++++++++----
 .../querying/ControlStructureTests.scala      | 35 ++++++++++++++-
 5 files changed, 102 insertions(+), 32 deletions(-)

diff --git a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForExpressionsCreator.scala b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForExpressionsCreator.scala
index 770cca94aca0..9b45eb2b437a 100644
--- a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForExpressionsCreator.scala
+++ b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForExpressionsCreator.scala
@@ -4,8 +4,8 @@ import io.joern.rubysrc2cpg.astcreation.RubyIntermediateAst.*
 import io.joern.rubysrc2cpg.passes.Defines
 import io.joern.rubysrc2cpg.passes.Defines.{RubyOperators, getBuiltInType}
 import io.joern.x2cpg.{Ast, ValidationMode}
-import io.shiftleft.codepropertygraph.generated.nodes.{NewBlock, NewLiteral}
-import io.shiftleft.codepropertygraph.generated.{DispatchTypes, Operators}
+import io.shiftleft.codepropertygraph.generated.nodes.{NewBlock, NewLiteral, NewControlStructure}
+import io.shiftleft.codepropertygraph.generated.{DispatchTypes, Operators, ControlStructureTypes}
 
 trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) { this: AstCreator =>
 
@@ -27,12 +27,18 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) {
     case node: HashLiteral => astForHashLiteral(node)
     case node: Association => astForAssociation(node)
     case node: IfExpression => astForIfExpression(node)
+    case node: RescueExpression => astForRescueExpression(node)
     case _ => astForUnknown(node)
 
   protected def astForStaticLiteral(node: StaticLiteral): Ast = {
     Ast(literalNode(node, code(node), node.typeFullName))
   }
 
+
+  // Helper for nil literals to put in empty clauses
+  protected def astForNilLiteral: Ast = Ast(NewLiteral().code("nil").typeFullName(getBuiltInType(Defines.NilClass)))
+  protected def astForNilBlock: Ast = blockAst(NewBlock(), List(astForNilLiteral))
+
   protected def astForDynamicLiteral(node: DynamicLiteral): Ast = {
     val fmtValueAsts = node.expressions.map {
       case stmtList: StatementList if stmtList.size == 1 =>
@@ -230,8 +236,7 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) {
       // We want to make sure there's always an «else» clause in a ternary operator.
       // The default value is a `nil` literal.
       val elseAsts_ = if (elseAsts.isEmpty) {
-        val nilLiteral = Ast(NewLiteral().code("nil").typeFullName(getBuiltInType(Defines.NilClass)))
-        List(blockAst(NewBlock(), List(nilLiteral)))
+        List(astForNilBlock)
       } else {
         elseAsts
       }
@@ -242,6 +247,37 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) {
     foldIfExpression(builder)(node)
   }
 
+  /** Lowers `begin ... rescue ... else ... ensure ... end` to a `TRY` control structure. The body becomes the try
+    * block; each `rescue` body, followed by the `else` body if present, is passed to `tryCatchAst` in the catch
+    * position; the `ensure` body is passed in the finally position. Exception classes/variables are not lowered yet.
+    */
+  protected def astForRescueExpression(node: RescueExpression): Ast = {
+    val tryAst = astForStatementList(node.body.asStatementList)
+    val rescueAsts = node.rescueClauses
+      .map {
+        case x: RescueClause =>
+          // TODO: add exception assignment
+          astForStatementList(x.thenClause.asStatementList)
+        case x => astForUnknown(x)
+      }
+    val elseAst = node.elseClause.map {
+      case x: ElseClause => astForStatementList(x.thenClause.asStatementList)
+      case x => astForUnknown(x)
+    }
+    val ensureAst = node.ensureClause.map {
+      case x: EnsureClause => astForStatementList(x.thenClause.asStatementList)
+      case x => astForUnknown(x)
+    }
+    tryCatchAst(
+      NewControlStructure()
+        .controlStructureType(ControlStructureTypes.TRY)
+        .code(code(node)),
+      tryAst,
+      rescueAsts ++ elseAst.toSeq,
+      ensureAst
+    )
+  }
+
   protected def astForUnknown(node: RubyNode): Ast = {
     val className = node.getClass.getSimpleName
     val text = code(node)
diff --git a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForStatementsCreator.scala b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForStatementsCreator.scala
index 07a2b094d80b..22dc7d363b0f 100644
--- a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForStatementsCreator.scala
+++ b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForStatementsCreator.scala
@@ -55,11 +55,7 @@ trait AstForStatementsCreator(implicit withSchemaValidation: ValidationMode) { t
     builder(node, conditionAst, thenAst, elseAsts)
   }
 
-  private def astForThenClause(node: RubyNode): Ast = {
-    node match
-      case stmtList: StatementList => astForStatementList(stmtList)
-      case _ => astForStatementList(StatementList(List(node))(node.span))
-  }
+  private def astForThenClause(node: RubyNode): Ast = astForStatementList(node.asStatementList)
 
   private def astsForElseClauses(
     elsIfClauses: List[RubyNode],
diff --git a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/RubyIntermediateAst.scala b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/RubyIntermediateAst.scala
index 96be1b2238d6..9a449cfcb2fd 100644
--- a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/RubyIntermediateAst.scala
+++ b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/RubyIntermediateAst.scala
@@ -76,20 +76,13 @@ object RubyIntermediateAst {
     span: TextSpan
   ) extends RubyNode(span)
 
-  final case class RescueExpression(
-    body: RubyNode,
-    rescueClauses: List[RubyNode],
-    elseClause: Option[RubyNode],
-    ensureClause: Option[RubyNode]
-  )(span: TextSpan)
-      extends RubyNode(span)
+  final case class RescueExpression( body: RubyNode, rescueClauses: List[RubyNode], elseClause: Option[RubyNode], ensureClause: Option[RubyNode])(
+    span: TextSpan
+  ) extends RubyNode(span)
 
-  final case class RescueClause(
-    exceptionClassList: Option[RubyNode],
-    assignment: Option[RubyNode],
-    thenClause: RubyNode
-  )(span: TextSpan)
-      extends RubyNode(span)
+  final case class RescueClause( exceptionClassList: Option[RubyNode], assignment: Option[RubyNode], thenClause: RubyNode)(
+    span: TextSpan
+  ) extends RubyNode(span)
 
   final case class EnsureClause(thenClause: RubyNode)(span: TextSpan) extends RubyNode(span)
 
diff --git a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/parser/RubyNodeCreator.scala b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/parser/RubyNodeCreator.scala
index f5c5e3660b91..3562349a62f0 100644
--- a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/parser/RubyNodeCreator.scala
+++ b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/parser/RubyNodeCreator.scala
@@ -10,12 +10,16 @@ import scala.jdk.CollectionConverters.*
 import io.joern.rubysrc2cpg.parser.RubyParser.RescueClauseContext
 import io.joern.rubysrc2cpg.parser.RubyParser.EnsureClauseContext
 import io.joern.rubysrc2cpg.parser.RubyParser.ExceptionClassListContext
+import org.antlr.v4.runtime.tree.RuleNode
 
 /** Converts an ANTLR Ruby Parse Tree into the intermediate Ruby AST.
   */
 class RubyNodeCreator extends RubyParserBaseVisitor[RubyNode] {
 
   override def defaultResult(): RubyNode = Unknown()(TextSpan(None, None, None, None, ""))
+  // Visit further children only while no child has produced a non-Unknown node yet.
+  override protected def shouldVisitNextChild(node: RuleNode, currentResult: RubyNode): Boolean =
+    currentResult.isInstanceOf[Unknown]
 
   override def visit(tree: ParseTree): RubyNode = {
     Option(tree).map(super.visit).getOrElse(defaultResult())
@@ -540,14 +544,19 @@ class RubyNodeCreator extends RubyParserBaseVisitor[RubyNode] {
   }
 
   override def visitBodyStatement(ctx: RubyParser.BodyStatementContext): RubyNode = {
-    if (ctx.rescueClause().isEmpty && Option(ctx.elseClause()).isEmpty && Option(ctx.ensureClause()).isEmpty) {
-      visit(ctx.compoundStatement())
+    val body = visit(ctx.compoundStatement())
+    val rescueClauses = Option(ctx.rescueClause.asScala).fold(List())(_.map(visit).toList)
+    val elseClause = Option(ctx.elseClause).map(visit)
+    val ensureClause = Option(ctx.ensureClause).map(visit)
+
+    if (rescueClauses.isEmpty && elseClause.isEmpty && ensureClause.isEmpty) {
+      body
     } else {
       RescueExpression(
-        visit(ctx.compoundStatement()),
-        Option(ctx.rescueClause.asScala).fold(List())(_.map(visit).toList),
-        Option(ctx.elseClause).map(visit),
-        Option(ctx.ensureClause).map(visit)
+        body,
+        rescueClauses,
+        elseClause,
+        ensureClause
       )(ctx.toTextSpan)
     }
   }
@@ -558,10 +567,13 @@ class RubyNodeCreator extends RubyParserBaseVisitor[RubyNode] {
   }
 
   override def visitRescueClause(ctx: RescueClauseContext): RubyNode = {
+    val exceptionClassList = Option(ctx.exceptionClassList).map(visit)
+    val assignment = Option(ctx.exceptionVariableAssignment).map(visit)
+    val thenClause = visit(ctx.thenClause)
     RescueClause(
-      Option(ctx.exceptionClassList).map(visit),
-      Option(ctx.exceptionVariableAssignment).map(visit),
-      visit(ctx.thenClause)
+      exceptionClassList,
+      assignment,
+      thenClause
     )(ctx.toTextSpan)
   }
 
diff --git a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/ControlStructureTests.scala b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/ControlStructureTests.scala
index 4662a60927fa..092e95a4cef9 100644
--- a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/ControlStructureTests.scala
+++ b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/ControlStructureTests.scala
@@ -1,7 +1,7 @@
 package io.joern.rubysrc2cpg.querying
 
 import io.joern.rubysrc2cpg.testfixtures.RubyCode2CpgFixture
-import io.shiftleft.codepropertygraph.generated.Operators
+import io.shiftleft.codepropertygraph.generated.{ControlStructureTypes, Operators}
 import io.shiftleft.codepropertygraph.generated.nodes.Call
 import io.shiftleft.semanticcpg.language.*
 
@@ -254,4 +254,37 @@ class ControlStructureTests extends RubyCode2CpgFixture {
     putsHi.lineNumber shouldBe Some(2)
   }
 
+  "`begin ... rescue ... end` is represented by a `TRY` CONTROL_STRUCTURE node" in {
+    val cpg = code("""
+        |begin
+        | 1
+        |rescue
+        | 2
+        |rescue
+        | 3
+        |else
+        | 4
+        |ensure
+        | 5
+        |end
+        |""".stripMargin)
+
+    val List(rescueNode) = cpg.tryBlock.l
+    rescueNode.controlStructureType shouldBe ControlStructureTypes.TRY
+    val List(body, rescueBody1, rescueBody2, elseBody, ensureBody) = rescueNode.astChildren.l
+    body.ast.isLiteral.code.l shouldBe List("1")
+    body.order shouldBe 1
+
+    rescueBody1.ast.isLiteral.code.l shouldBe List("2")
+    rescueBody1.order shouldBe 2
+
+    rescueBody2.ast.isLiteral.code.l shouldBe List("3")
+    rescueBody2.order shouldBe 2
+
+    elseBody.ast.isLiteral.code.l shouldBe List("4")
+    elseBody.order shouldBe 2
+
+    ensureBody.ast.isLiteral.code.l shouldBe List("5")
+    ensureBody.order shouldBe 3
+  }
 }