Skip to content

Commit

Permalink
Revert "[rubysrc2cpg] General Do-Block Function Fixes (joernio#3676)"
Browse files Browse the repository at this point in the history
This reverts commit 4ef5cdc.
  • Loading branch information
khemrajrathore committed Oct 10, 2023
1 parent 22bdb74 commit a02648b
Show file tree
Hide file tree
Showing 10 changed files with 135 additions and 245 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -106,11 +106,7 @@ class AstCreator(
programCtx.compoundStatement() != null &&
programCtx.compoundStatement().statements() != null
) {
val stmts = astForStatements(programCtx.compoundStatement().statements(), false, false)
val methodNodes = stmts.flatMap(_.nodes).collect { case x: NewMethod => x }.toSet
// Block methods is largely unnecessary, but will keep it for AST functions that still populate it instead of
// attaching it to AST
stmts ++ blockMethods.filterNot(_.root.collect { case x: NewMethod => x }.exists(methodNodes.contains))
astForStatements(programCtx.compoundStatement().statements(), false, false) ++ blockMethods
} else {
logger.error(s"File $filename has no compound statement. Needs to be examined")
List[Ast](Ast())
Expand Down Expand Up @@ -143,7 +139,23 @@ class AstCreator(
.filterNot(_.astParentType == NodeTypes.TYPE_DECL)
.map { methodNode =>
// Create a methodRefNode and assign it to the identifier version of the method, which will help in type propagation to resolve calls
methodRefAssignmentFromMethod(methodNode, Option(lineColNum), Option(lineColNum))
val methodRefNode = NewMethodRef()
.code("def " + methodNode.name + "(...)")
.methodFullName(methodNode.fullName)
.typeFullName(methodNode.fullName)
.lineNumber(lineColNum)
.columnNumber(lineColNum)

val methodNameIdentifier = NewIdentifier()
.code(methodNode.name)
.name(methodNode.name)
.typeFullName(Defines.Any)
.lineNumber(lineColNum)
.columnNumber(lineColNum)
scope.addToScope(methodNode.name, methodNameIdentifier)
val methodRefAssignmentAst =
astForAssignment(methodNameIdentifier, methodRefNode, methodNode.lineNumber, methodNode.columnNumber)
methodRefAssignmentAst
}
.toList

Expand Down Expand Up @@ -330,28 +342,29 @@ class AstCreator(
lineEnd(ctx).head,
columnEnd(ctx).head
)
val blockMethodNode =
blockMethodAsts.head.nodes.head
.asInstanceOf[NewMethod]

blockMethodAsts.foreach { ast =>
ast.root match
case Some(_: NewMethod) => blockMethods.addOne(ast)
case _ =>
}
blockMethods.addOne(blockMethodAsts.head)

blockMethodAsts :+ blockMethodAsts
.flatMap(_.nodes)
.collectFirst { case methodRefNode: NewMethodRef =>
val callNode = NewCall()
.name(blockName)
.methodFullName(methodRefNode.methodFullName)
.typeFullName(Defines.Any)
.code(methodRefNode.code)
.lineNumber(methodRefNode.lineNumber)
.columnNumber(methodRefNode.columnNumber)
.dispatchType(DispatchTypes.STATIC_DISPATCH)
// TODO: primaryAst.headOption is broken when primaryAst is an array
callAst(callNode, argsAst ++ Seq(Ast(methodRefNode.copy)), primaryAst.headOption)
}
.getOrElse(Ast())
val callNode = NewCall()
.name(blockName)
.methodFullName(blockMethodNode.fullName)
.typeFullName(Defines.Any)
.code(blockMethodNode.code)
.lineNumber(blockMethodNode.lineNumber)
.columnNumber(blockMethodNode.columnNumber)
.dispatchType(DispatchTypes.STATIC_DISPATCH)

val methodRefNode = NewMethodRef()
.methodFullName(blockMethodNode.fullName)
.typeFullName(Defines.Any)
.code(blockMethodNode.code)
.lineNumber(blockMethodNode.lineNumber)
.columnNumber(blockMethodNode.columnNumber)

Seq(callAst(callNode, argsAst ++ Seq(Ast(methodRefNode)), primaryAst.headOption))
} else {
val callNode = methodNameAst.head.nodes
.filter(node => node.isInstanceOf[NewCall])
Expand Down Expand Up @@ -409,15 +422,15 @@ class AstCreator(
val baseAst = astForPrimaryContext(ctx.primary())

val blocksAst = if (ctx.block() != null) {
astForBlock(ctx.block())
Seq(astForBlock(ctx.block()))
} else {
Seq()
}
val callNode = methodNameAst.head.nodes.filter(node => node.isInstanceOf[NewCall]).head.asInstanceOf[NewCall]
callNode
.code(text(ctx))
.lineNumber(ctx.COLON2.lineNumber)
.columnNumber(ctx.COLON2.columnNumber)
.lineNumber(ctx.COLON2().getSymbol().getLine())
.columnNumber(ctx.COLON2().getSymbol().getCharPositionInLine())
Seq(callAst(callNode, baseAst ++ blocksAst))
}

Expand Down Expand Up @@ -521,8 +534,8 @@ class AstCreator(
.asInstanceOf[NewCall]
.name

val isYieldMethod = if (blockName.endsWith(Defines.YIELD_SUFFIX)) {
val lookupMethodName = blockName.take(blockName.length - Defines.YIELD_SUFFIX.length)
val isYieldMethod = if (blockName.endsWith(YIELD_SUFFIX)) {
val lookupMethodName = blockName.take(blockName.length - YIELD_SUFFIX.length)
methodNamesWithYield.contains(lookupMethodName)
} else {
false
Expand All @@ -542,7 +555,7 @@ class AstCreator(
columnEnd(ctx).head
)
} else {
val blockAst = astForBlock(ctx.block())
val blockAst = Seq(astForBlock(ctx.block()))
// this is expected to be a call node
val callNode = methodIdAst.head.nodes.head.asInstanceOf[NewCall]
Seq(callAst(callNode, blockAst))
Expand All @@ -556,26 +569,18 @@ class AstCreator(
callNode.name(resolveAlias(callNode.name))

if (ctx.block() != null) {
val isYieldMethod = if (callNode.name.endsWith(Defines.YIELD_SUFFIX)) {
val lookupMethodName = callNode.name.take(callNode.name.length - Defines.YIELD_SUFFIX.length)
val isYieldMethod = if (callNode.name.endsWith(YIELD_SUFFIX)) {
val lookupMethodName = callNode.name.take(callNode.name.length - YIELD_SUFFIX.length)
methodNamesWithYield.contains(lookupMethodName)
} else {
false
}
if (isYieldMethod) {
val methAst = astForBlock(ctx.block(), Some(callNode.name))
methAst
.collectFirst { case x: Ast if x.root.isDefined && x.root.get.isInstanceOf[NewMethod] => x }
.foreach(blockMethods.addOne)
blockMethods.addOne(methAst)
Seq(callAst(callNode, parenAst))
} else if (callNode.name == Defines.DEFINE_METHOD) {
parenAst.headOption
.flatMap(_.root)
.collect { case x: AstNodeNew => stripQuotes(x.code).stripPrefix(":") }
.map(methodName => astForBlock(ctx.block(), Option(methodName)))
.getOrElse(Seq.empty)
} else {
val blockAst = astForBlock(ctx.block())
val blockAst = Seq(astForBlock(ctx.block()))
Seq(callAst(callNode, parenAst ++ blockAst))
}
} else
Expand All @@ -584,7 +589,7 @@ class AstCreator(

def astForCallNode(ctx: ParserRuleContext, code: String, isYieldBlock: Boolean = false): Ast = {
val name = if (isYieldBlock) {
s"${resolveAlias(text(ctx))}${Defines.YIELD_SUFFIX}"
s"${resolveAlias(text(ctx))}$YIELD_SUFFIX"
} else {
val calleeName = resolveAlias(text(ctx))
// Add the call name to the global builtIn callNames set
Expand Down Expand Up @@ -659,27 +664,27 @@ class AstCreator(
private def astForCommandWithDoBlockContext(ctx: CommandWithDoBlockContext): Seq[Ast] = ctx match {
case ctx: ArgsAndDoBlockCommandWithDoBlockContext =>
val argsAsts = astForArguments(ctx.argumentsWithoutParentheses().arguments())
val doBlockAst = astForDoBlock(ctx.doBlock())
val doBlockAst = Seq(astForDoBlock(ctx.doBlock()))
argsAsts ++ doBlockAst
case ctx: RubyParser.ArgsAndDoBlockAndMethodIdCommandWithDoBlockContext =>
val methodIdAsts = astForMethodIdentifierContext(ctx.methodIdentifier(), text(ctx))
methodIdAsts.headOption.flatMap(_.root) match
case Some(methodIdRoot: NewCall) if methodIdRoot.name == Defines.DEFINE_METHOD =>
case Some(methodIdRoot: NewCall) if methodIdRoot.name == "define_method" =>
ctx.argumentsWithoutParentheses.arguments.argument.asScala.headOption
.map { methodArg =>
// TODO: methodArg will name the method, but this could be an identifier or even a string concatenation
// which is not assumed below
val methodName = stripQuotes(methodArg.getText).stripPrefix(":")
astForDoBlock(ctx.doBlock(), Option(methodName))
val methodName = stripQuotes(methodArg.getText)
Seq(astForDoBlock(ctx.doBlock(), Option(methodName)))
}
.getOrElse(Seq.empty)
case _ =>
val argsAsts = astForArguments(ctx.argumentsWithoutParentheses().arguments())
val doBlockAsts = astForDoBlock(ctx.doBlock())
val doBlockAsts = Seq(astForDoBlock(ctx.doBlock()))
methodIdAsts ++ argsAsts ++ doBlockAsts
case ctx: RubyParser.PrimaryMethodArgsDoBlockCommandWithDoBlockContext =>
val argsAsts = astForArguments(ctx.argumentsWithoutParentheses().arguments())
val doBlockAsts = astForDoBlock(ctx.doBlock())
val doBlockAsts = Seq(astForDoBlock(ctx.doBlock()))
val methodNameAsts = astForMethodNameContext(ctx.methodName())
val primaryAsts = astForPrimaryContext(ctx.primary())
primaryAsts ++ methodNameAsts ++ argsAsts ++ doBlockAsts
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import io.joern.rubysrc2cpg.parser.RubyParser.*
import io.joern.rubysrc2cpg.passes.Defines
import io.joern.rubysrc2cpg.passes.Defines.getBuiltInType
import io.joern.x2cpg.{Ast, ValidationMode}
import io.shiftleft.codepropertygraph.generated.nodes.{AstNodeNew, NewCall, NewIdentifier, NewMethod}
import io.shiftleft.codepropertygraph.generated.nodes.{AstNodeNew, NewCall, NewIdentifier}
import io.shiftleft.codepropertygraph.generated.{ControlStructureTypes, DispatchTypes, ModifierTypes, Operators}
import org.antlr.v4.runtime.ParserRuleContext
import org.slf4j.LoggerFactory
Expand Down Expand Up @@ -196,20 +196,12 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) {
.lineNumber(ctx.op.getLine)
.columnNumber(ctx.op.getCharPositionInLine)
if (leftAst.size == 1 && rightAst.size > 1) {
if (rightAst.headOption.flatMap(_.root).exists(_.isInstanceOf[NewMethod])) {
/*
* Here we expect to be assigned the result of some dynamically defined function's application to some variable
*/
val lastAst = rightAst.takeRight(1)
rightAst.filterNot(_ == lastAst.head) ++ Seq(callAst(opCallNode, leftAst ++ lastAst))
} else {
/*
* This is multiple RHS packed into a single LHS. That is, packing left hand side.
* This is as good as multiple RHS packed into an array and put into a single LHS
*/
val packedRHS = getPackedRHS(rightAst, wrapInBrackets = true)
Seq(callAst(opCallNode, leftAst ++ packedRHS))
}
/*
* This is multiple RHS packed into a single LHS. That is, packing left hand side.
* This is as good as multiple RHS packed into an array and put into a single LHS
*/
val packedRHS = getPackedRHS(rightAst, wrapInBrackets = true)
Seq(callAst(opCallNode, leftAst ++ packedRHS))
} else {
Seq(callAst(opCallNode, leftAst ++ rightAst))
}
Expand Down Expand Up @@ -382,8 +374,7 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) {

protected def astForYieldCall(ctx: ParserRuleContext, argumentsCtx: Option[ArgumentsContext]): Ast = {
val args = argumentsCtx.map(astForArguments).getOrElse(Seq())
val call =
callNode(ctx, text(ctx), Defines.UNRESOLVED_YIELD, Defines.UNRESOLVED_YIELD, DispatchTypes.STATIC_DISPATCH)
val call = callNode(ctx, text(ctx), UNRESOLVED_YIELD, UNRESOLVED_YIELD, DispatchTypes.STATIC_DISPATCH)
callAst(call, args)
}

Expand Down
Loading

0 comments on commit a02648b

Please sign in to comment.