Skip to content

Commit

Permalink
[Ruby] do-end Blocks (with implicit &Proc and yield handling) #3928 (#…
Browse files Browse the repository at this point in the history
…4359)

* Ruby yield expressions

* Scalafmt

* Create FreshNameGenerator class

* Add tests and singleton methods

* scalafmt

* Address some PR comments

* Attempt 1 at return flow

* scalafmt

* Add test for yield argument

* scalafmt

* Make return for yield
  • Loading branch information
badly-drawn-wizards authored Mar 20, 2024
1 parent 9001c04 commit 4a12a27
Show file tree
Hide file tree
Showing 13 changed files with 222 additions and 37 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ class AstCreator(
with AstForExpressionsCreator
with AstForFunctionsCreator
with AstForTypesCreator
with FreshVariableCreator
with AstSummaryVisitor
with AstNodeBuilder[RubyNode, AstCreator] {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,12 @@ import io.joern.rubysrc2cpg.passes.Defines.{RubyOperators, getBuiltInType}
import io.joern.x2cpg.{Ast, ValidationMode, Defines as XDefines}
import io.shiftleft.codepropertygraph.generated.nodes.*
import io.shiftleft.codepropertygraph.generated.{ControlStructureTypes, DispatchTypes, Operators, PropertyNames}
import io.joern.rubysrc2cpg.utils.FreshNameGenerator

trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) { this: AstCreator =>

val tmpGen = FreshNameGenerator(i => s"<tmp-$i>")

protected def astForExpression(node: RubyNode): Ast = node match
case node: StaticLiteral => astForStaticLiteral(node)
case node: HereDocNode => astForHereDoc(node)
Expand All @@ -26,6 +29,7 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) {
case node: SimpleCall => astForSimpleCall(node)
case node: RequireCall => astForRequireCall(node)
case node: IncludeCall => astForIncludeCall(node)
case node: YieldExpr => astForYield(node)
case node: RangeExpression => astForRange(node)
case node: ArrayLiteral => astForArrayLiteral(node)
case node: HashLiteral => astForHashLiteral(node)
Expand Down Expand Up @@ -197,7 +201,7 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) {
val block = blockNode(node)
scope.pushNewScope(BlockScope(block))

val tmp = SimpleIdentifier(Option(className))(node.span.spanStart(freshVariableName))
val tmp = SimpleIdentifier(Option(className))(node.span.spanStart(tmpGen.fresh))
def tmpIdentifier = {
val tmpAst = astForSimpleIdentifier(tmp)
tmpAst.root.collect { case x: NewIdentifier => x.typeFullName(receiverTypeFullName) }
Expand Down Expand Up @@ -365,6 +369,27 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) {
astForSimpleCall(node.asSimpleCall)
}

protected def astForYield(node: YieldExpr): Ast = {
scope.useProcParam match {
case Some(param) =>
val call = astForExpression(
SimpleCall(SimpleIdentifier()(node.span.spanStart(param)), node.arguments)(node.span)
)
val ret = returnAst(returnNode(node, code(node)))
val cond = astForExpression(
SimpleCall(SimpleIdentifier()(node.span.spanStart(tmpGen.fresh)), List())(node.span.spanStart("<nondet>"))
)
callAst(
callNode(node, code(node), Operators.conditional, Operators.conditional, DispatchTypes.STATIC_DISPATCH),
List(cond, call, ret)
)
case None =>
logger.warn(s"Yield expression outside of method scope: ${code(node)} ($relativeFileName), skipping")
astForUnknown(node)

}
}

protected def astForRange(node: RangeExpression): Ast = {
val lbAst = astForExpression(node.lowerBound)
val ubAst = astForExpression(node.upperBound)
Expand Down Expand Up @@ -398,7 +423,7 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) {
}

protected def astForHashLiteral(node: HashLiteral): Ast = {
val tmp = freshVariableName
val tmp = tmpGen.fresh

def tmpAst(tmpNode: Option[RubyNode] = None) = astForSimpleIdentifier(
SimpleIdentifier()(tmpNode.map(_.span).getOrElse(node.span).spanStart(tmp))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,12 @@ import io.joern.x2cpg.utils.NodeBuilders.{newClosureBindingNode, newLocalNode, n
import io.joern.x2cpg.{Ast, AstEdge, ValidationMode, Defines as XDefines}
import io.shiftleft.codepropertygraph.generated.nodes.*
import io.shiftleft.codepropertygraph.generated.{EdgeTypes, EvaluationStrategies, ModifierTypes, NodeTypes}
import io.joern.rubysrc2cpg.utils.FreshNameGenerator

trait AstForFunctionsCreator(implicit withSchemaValidation: ValidationMode) { this: AstCreator =>

val procParamGen = FreshNameGenerator(i => Left(s"<proc-param-$i>"))

/** Creates method declaration related structures.
* @param node
* the node to create the AST structure from.
Expand Down Expand Up @@ -40,7 +43,7 @@ trait AstForFunctionsCreator(implicit withSchemaValidation: ValidationMode) { th
)

if (methodName == XDefines.ConstructorMethodName) scope.pushNewScope(ConstructorScope(fullName))
else scope.pushNewScope(MethodScope(fullName))
else scope.pushNewScope(MethodScope(fullName, procParamGen.fresh))

val parameterAsts = astForParameters(node.parameters)

Expand Down Expand Up @@ -76,12 +79,19 @@ trait AstForFunctionsCreator(implicit withSchemaValidation: ValidationMode) { th
astForMethodBody(node.body, optionalStatementList)
}

val anonProcParam = scope.anonProcParam.map { param =>
val paramNode = ProcParameter(param)(node.span.spanStart(s"&$param"))
val nextIndex =
parameterAsts.lastOption.flatMap(_.root).map { case m: NewMethodParameterIn => m.index + 1 }.getOrElse(0)
astForParameter(paramNode, nextIndex)
}

scope.popScope()

val modifiers =
ModifierTypes.VIRTUAL :: (if isClosure then ModifierTypes.LAMBDA :: Nil else Nil) map newModifierNode

methodAst(method, parameterAsts, stmtBlockAst, methodReturn, modifiers) :: refs
methodAst(method, parameterAsts ++ anonProcParam, stmtBlockAst, methodReturn, modifiers) :: refs
}

private def transformAsClosureBody(refs: List[Ast], baseStmtBlockAst: Ast) = {
Expand Down Expand Up @@ -141,6 +151,19 @@ trait AstForFunctionsCreator(implicit withSchemaValidation: ValidationMode) { th
)
scope.addToScope(node.name, parameterIn)
Ast(parameterIn)
case node: ProcParameter =>
val parameterIn = parameterInNode(
node = node,
name = node.name,
code = code(node),
index = index,
isVariadic = false,
evaluationStrategy = EvaluationStrategies.BY_REFERENCE,
typeFullName = None
)
scope.addToScope(node.name, parameterIn)
scope.setProcParam(node.name)
Ast(parameterIn)
case node: CollectionParameter =>
val typeFullName = node match {
case ArrayParameter(_) => prefixAsBuiltin("Array")
Expand Down Expand Up @@ -252,7 +275,7 @@ trait AstForFunctionsCreator(implicit withSchemaValidation: ValidationMode) { th
astParentFullName = scope.surroundingScopeFullName
)

scope.pushNewScope(MethodScope(fullName))
scope.pushNewScope(MethodScope(fullName, procParamGen.fresh))

val thisParameterAst = Ast(
newThisParameterNode(
Expand All @@ -268,8 +291,20 @@ trait AstForFunctionsCreator(implicit withSchemaValidation: ValidationMode) { th

val stmtBlockAst = astForMethodBody(node.body, optionalStatementList)

val anonProcParam = scope.anonProcParam.map { param =>
val paramNode = ProcParameter(param)(node.span.spanStart(s"&$param"))
val nextIndex =
parameterAsts.lastOption.flatMap(_.root).map { case m: NewMethodParameterIn => m.index + 1 }.getOrElse(1)
astForParameter(paramNode, nextIndex)
}

scope.popScope()
methodAst(method, thisParameterAst +: parameterAsts, stmtBlockAst, methodReturnNode(node, Defines.Any))
methodAst(
method,
(thisParameterAst +: parameterAsts) ++ anonProcParam,
stmtBlockAst,
methodReturnNode(node, Defines.Any)
)

case targetNode =>
logger.warn(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ trait AstForStatementsCreator(implicit withSchemaValidation: ValidationMode) { t
}
def generatedNode: StatementList = node.expression
.map { e =>
val tmp = SimpleIdentifier(None)(e.span.spanStart(freshVariableName))
val tmp = SimpleIdentifier(None)(e.span.spanStart(tmpGen.fresh))
StatementList(
List(SingleAssignment(tmp, "=", e)(e.span)) ++
goCase(Some(tmp))
Expand Down Expand Up @@ -252,7 +252,7 @@ trait AstForStatementsCreator(implicit withSchemaValidation: ValidationMode) { t
case node: MemberCallWithBlock => returnAstForRubyCall(node)
case node: SimpleCallWithBlock => returnAstForRubyCall(node)
case _: (LiteralExpr | BinaryExpression | UnaryExpression | SimpleIdentifier | IndexAccess | Association |
RubyCall) =>
YieldExpr | RubyCall) =>
astForReturnStatement(ReturnExpression(List(node))(node.span)) :: Nil
case node: SingleAssignment =>
astForSingleAssignment(node) :: List(astForReturnStatement(ReturnExpression(List(node.lhs))(node.span)))
Expand All @@ -265,6 +265,7 @@ trait AstForStatementsCreator(implicit withSchemaValidation: ValidationMode) { t
case ret: ReturnExpression => astForReturnStatement(ret) :: Nil
case node: MethodDeclaration =>
(astForMethodDeclaration(node) :+ astForReturnMethodDeclarationSymbolName(node)).toList

case node =>
logger.warn(
s"Implicit return here not supported yet: ${node.text} (${node.getClass.getSimpleName}), only generating statement"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ trait AstForTypesCreator(implicit withSchemaValidation: ValidationMode) { this:
astParentType = scope.surroundingAstLabel,
astParentFullName = scope.surroundingScopeFullName
)
scope.pushNewScope(MethodScope(fullName))
scope.pushNewScope(MethodScope(fullName, procParamGen.fresh))
val block_ = blockNode(node)
scope.pushNewScope(BlockScope(block_))
// TODO: Should it be `return this.@abc`?
Expand Down Expand Up @@ -155,7 +155,7 @@ trait AstForTypesCreator(implicit withSchemaValidation: ValidationMode) { this:
astParentType = scope.surroundingAstLabel,
astParentFullName = scope.surroundingScopeFullName
)
scope.pushNewScope(MethodScope(fullName))
scope.pushNewScope(MethodScope(fullName, procParamGen.fresh))
val parameter = parameterInNode(node, "x", "x", 1, false, EvaluationStrategies.BY_REFERENCE)
val methodBody = {
val block_ = blockNode(node)
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -110,9 +110,7 @@ object RubyIntermediateAst {

final case class HashParameter(name: String)(span: TextSpan) extends RubyNode(span) with CollectionParameter

final case class ProcParameter(target: RubyNode)(span: TextSpan) extends RubyNode(span) with MethodParameter {
def name: String = target.text
}
final case class ProcParameter(name: String)(span: TextSpan) extends RubyNode(span) with MethodParameter

final case class SingleAssignment(lhs: RubyNode, op: String, rhs: RubyNode)(span: TextSpan) extends RubyNode(span)

Expand Down Expand Up @@ -303,6 +301,8 @@ object RubyIntermediateAst {
*/
final case class ProcOrLambdaExpr(block: Block)(span: TextSpan) extends RubyNode(span)

final case class YieldExpr(arguments: List[RubyNode])(span: TextSpan) extends RubyNode(span)

/** Represents a call with a block argument.
*/
sealed trait RubyCallWithBlock[C <: RubyCall] extends RubyCall {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import better.files.File
import io.joern.rubysrc2cpg.astcreation.GlobalTypes
import io.joern.rubysrc2cpg.astcreation.GlobalTypes.builtinPrefix
import io.joern.x2cpg.Defines
import io.joern.rubysrc2cpg.passes.Defines as RDefines
import io.joern.x2cpg.datastructures.*
import io.shiftleft.codepropertygraph.generated.NodeTypes
import io.shiftleft.codepropertygraph.generated.nodes.{DeclarationNew, NewLocal, NewMethodParameterIn}
Expand Down Expand Up @@ -117,6 +118,46 @@ class RubyScope(summary: RubyProgramSummary, projectRoot: Option[String])
case ScopeElement(x: MethodLikeScope, _) => x.fullName
}

/** Locates a position in the stack matching a partial function, modifies it and emits a result
* @param pf
* Tests ScopeElements of the stack. If they match, return the new value and the result to emi
* @return
* the emitted result if the position was found and modifies
*/
def updateSurrounding[T](
pf: PartialFunction[
ScopeElement[String, DeclarationNew, TypedScopeElement],
(ScopeElement[String, DeclarationNew, TypedScopeElement], T)
]
): Option[T] = {
stack.zipWithIndex
.collectFirst { case (pf(elem, res), i) =>
(elem, res, i)
}
.map { case (elem, res, i) =>
stack = stack.updated(i, elem)
res
}
}

/** Get the name of the implicit or explict proc param and mark the method scope as using the proc param
*/
def useProcParam: Option[String] = updateSurrounding {
case ScopeElement(MethodScope(fullName, param, _), variables) =>
(ScopeElement(MethodScope(fullName, param, true), variables), param.fold(x => x, x => x))
}

/** Get the name of the implicit or explict proc param */
def anonProcParam: Option[String] = stack.collectFirst { case ScopeElement(MethodScope(_, Left(param), true), _) =>
param
}

/** Set the name of explict proc param */
def setProcParam(param: String): Unit = updateSurrounding {
case ScopeElement(MethodScope(fullName, _, _), variables) =>
(ScopeElement(MethodScope(fullName, Right(param)), variables), ())
}

def surroundingTypeFullName: Option[String] = stack.collectFirst { case ScopeElement(x: TypeLikeScope, _) =>
x.fullName
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,8 @@ trait MethodLikeScope extends TypedScopeElement {
def fullName: String
}

case class MethodScope(fullName: String) extends MethodLikeScope
case class MethodScope(fullName: String, procParam: Either[String, String], hasYield: Boolean = false)
extends MethodLikeScope

case class ConstructorScope(fullName: String) extends MethodLikeScope

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,26 +2,21 @@ package io.joern.rubysrc2cpg.parser

import io.joern.rubysrc2cpg.astcreation.RubyIntermediateAst.*
import io.joern.rubysrc2cpg.parser.AntlrContextHelpers.*
import io.joern.rubysrc2cpg.parser.RubyParser.RangeOperatorContext
import io.joern.rubysrc2cpg.passes.Defines
import io.joern.rubysrc2cpg.passes.Defines.getBuiltInType
import org.antlr.v4.runtime.tree.{ParseTree, RuleNode}
import io.joern.x2cpg.Defines as XDefines;

import scala.jdk.CollectionConverters.*
import io.joern.rubysrc2cpg.utils.FreshNameGenerator

/** Converts an ANTLR Ruby Parse Tree into the intermediate Ruby AST.
*/
class RubyNodeCreator extends RubyParserBaseVisitor[RubyNode] {

private var classCounter: Int = 0

private def tmpClassTemplate(id: Int): String = s"<anon-class-$id>"

private val classNameGen = FreshNameGenerator(id => s"<anon-class-$id>")
protected def freshClassName(span: TextSpan): SimpleIdentifier = {
val name = tmpClassTemplate(classCounter)
classCounter += 1
SimpleIdentifier(None)(span.spanStart(name))
SimpleIdentifier(None)(span.spanStart(classNameGen.fresh))
}

private def defaultTextSpan(code: String = ""): TextSpan = TextSpan(None, None, None, None, code)
Expand Down Expand Up @@ -523,6 +518,18 @@ class RubyNodeCreator extends RubyParserBaseVisitor[RubyNode] {
}
}

override def visitYieldExpression(ctx: RubyParser.YieldExpressionContext): RubyNode = {
val arguments = Option(ctx.argumentWithParentheses()).iterator.flatMap(_.arguments).map(visit).toList
YieldExpr(arguments)(ctx.toTextSpan)
}

override def visitYieldMethodInvocationWithoutParentheses(
ctx: RubyParser.YieldMethodInvocationWithoutParenthesesContext
): RubyNode = {
val arguments = ctx.primaryValueList().primaryValue().asScala.map(visit).toList
YieldExpr(arguments)(ctx.toTextSpan)
}

override def visitConstantIdentifierVariable(ctx: RubyParser.ConstantIdentifierVariableContext): RubyNode = {
SimpleIdentifier()(ctx.toTextSpan)
}
Expand Down Expand Up @@ -870,7 +877,9 @@ class RubyNodeCreator extends RubyParserBaseVisitor[RubyNode] {
}

override def visitProcParameter(ctx: RubyParser.ProcParameterContext): RubyNode = {
ProcParameter(visit(ctx.procParameterName()))(ctx.toTextSpan)
ProcParameter(
Option(ctx.procParameterName).map(_.LOCAL_VARIABLE_IDENTIFIER()).map(_.getText()).getOrElse(ctx.getText())
)(ctx.toTextSpan)
}

override def visitHashParameter(ctx: RubyParser.HashParameterContext): RubyNode = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ object Defines {

val Resolver: String = "<dependency-resolver>"

val AnonymousProcParameter = "<anonymous-proc-param>"

def getBuiltInType(typeInString: String) = s"${GlobalTypes.builtinPrefix}.$typeInString"

object RubyOperators {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
package io.joern.rubysrc2cpg.utils

class FreshNameGenerator[T](template: Int => T) {
private var counter: Int = 0
def fresh: T = {
val name = template(counter)
counter += 1
name
}
}
Loading

0 comments on commit 4a12a27

Please sign in to comment.