Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[rubysrc2cpg] General Do-Block Function Fixes #3676

Merged
merged 1 commit into from
Sep 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,11 @@ class AstCreator(
programCtx.compoundStatement() != null &&
programCtx.compoundStatement().statements() != null
) {
astForStatements(programCtx.compoundStatement().statements(), false, false) ++ blockMethods
val stmts = astForStatements(programCtx.compoundStatement().statements(), false, false)
val methodNodes = stmts.flatMap(_.nodes).collect { case x: NewMethod => x }.toSet
// Block methods is largely unnecessary, but will keep it for AST functions that still populate it instead of
// attaching it to AST
stmts ++ blockMethods.filterNot(_.root.collect { case x: NewMethod => x }.exists(methodNodes.contains))
} else {
logger.error(s"File $filename has no compound statement. Needs to be examined")
List[Ast](Ast())
Expand Down Expand Up @@ -139,23 +143,7 @@ class AstCreator(
.filterNot(_.astParentType == NodeTypes.TYPE_DECL)
.map { methodNode =>
// Create a methodRefNode and assign it to the identifier version of the method, which will help in type propagation to resolve calls
val methodRefNode = NewMethodRef()
.code("def " + methodNode.name + "(...)")
.methodFullName(methodNode.fullName)
.typeFullName(methodNode.fullName)
.lineNumber(lineColNum)
.columnNumber(lineColNum)

val methodNameIdentifier = NewIdentifier()
.code(methodNode.name)
.name(methodNode.name)
.typeFullName(Defines.Any)
.lineNumber(lineColNum)
.columnNumber(lineColNum)
scope.addToScope(methodNode.name, methodNameIdentifier)
val methodRefAssignmentAst =
astForAssignment(methodNameIdentifier, methodRefNode, methodNode.lineNumber, methodNode.columnNumber)
methodRefAssignmentAst
methodRefAssignmentFromMethod(methodNode, Option(lineColNum), Option(lineColNum))
}
.toList

Expand Down Expand Up @@ -342,29 +330,28 @@ class AstCreator(
lineEnd(ctx).head,
columnEnd(ctx).head
)
val blockMethodNode =
blockMethodAsts.head.nodes.head
.asInstanceOf[NewMethod]

blockMethods.addOne(blockMethodAsts.head)

val callNode = NewCall()
.name(blockName)
.methodFullName(blockMethodNode.fullName)
.typeFullName(Defines.Any)
.code(blockMethodNode.code)
.lineNumber(blockMethodNode.lineNumber)
.columnNumber(blockMethodNode.columnNumber)
.dispatchType(DispatchTypes.STATIC_DISPATCH)

val methodRefNode = NewMethodRef()
.methodFullName(blockMethodNode.fullName)
.typeFullName(Defines.Any)
.code(blockMethodNode.code)
.lineNumber(blockMethodNode.lineNumber)
.columnNumber(blockMethodNode.columnNumber)
blockMethodAsts.foreach { ast =>
ast.root match
case Some(_: NewMethod) => blockMethods.addOne(ast)
case _ =>
}

Seq(callAst(callNode, argsAst ++ Seq(Ast(methodRefNode)), primaryAst.headOption))
blockMethodAsts :+ blockMethodAsts
.flatMap(_.nodes)
.collectFirst { case methodRefNode: NewMethodRef =>
val callNode = NewCall()
.name(blockName)
.methodFullName(methodRefNode.methodFullName)
.typeFullName(Defines.Any)
.code(methodRefNode.code)
.lineNumber(methodRefNode.lineNumber)
.columnNumber(methodRefNode.columnNumber)
.dispatchType(DispatchTypes.STATIC_DISPATCH)
// TODO: primaryAst.headOption is broken when primaryAst is an array
callAst(callNode, argsAst ++ Seq(Ast(methodRefNode.copy)), primaryAst.headOption)
}
.getOrElse(Ast())
} else {
val callNode = methodNameAst.head.nodes
.filter(node => node.isInstanceOf[NewCall])
Expand Down Expand Up @@ -422,15 +409,15 @@ class AstCreator(
val baseAst = astForPrimaryContext(ctx.primary())

val blocksAst = if (ctx.block() != null) {
Seq(astForBlock(ctx.block()))
astForBlock(ctx.block())
} else {
Seq()
}
val callNode = methodNameAst.head.nodes.filter(node => node.isInstanceOf[NewCall]).head.asInstanceOf[NewCall]
callNode
.code(text(ctx))
.lineNumber(ctx.COLON2().getSymbol().getLine())
.columnNumber(ctx.COLON2().getSymbol().getCharPositionInLine())
.lineNumber(ctx.COLON2.lineNumber)
.columnNumber(ctx.COLON2.columnNumber)
Seq(callAst(callNode, baseAst ++ blocksAst))
}

Expand Down Expand Up @@ -534,8 +521,8 @@ class AstCreator(
.asInstanceOf[NewCall]
.name

val isYieldMethod = if (blockName.endsWith(YIELD_SUFFIX)) {
val lookupMethodName = blockName.take(blockName.length - YIELD_SUFFIX.length)
val isYieldMethod = if (blockName.endsWith(Defines.YIELD_SUFFIX)) {
val lookupMethodName = blockName.take(blockName.length - Defines.YIELD_SUFFIX.length)
methodNamesWithYield.contains(lookupMethodName)
} else {
false
Expand All @@ -555,7 +542,7 @@ class AstCreator(
columnEnd(ctx).head
)
} else {
val blockAst = Seq(astForBlock(ctx.block()))
val blockAst = astForBlock(ctx.block())
// this is expected to be a call node
val callNode = methodIdAst.head.nodes.head.asInstanceOf[NewCall]
Seq(callAst(callNode, blockAst))
Expand All @@ -569,18 +556,26 @@ class AstCreator(
callNode.name(resolveAlias(callNode.name))

if (ctx.block() != null) {
val isYieldMethod = if (callNode.name.endsWith(YIELD_SUFFIX)) {
val lookupMethodName = callNode.name.take(callNode.name.length - YIELD_SUFFIX.length)
val isYieldMethod = if (callNode.name.endsWith(Defines.YIELD_SUFFIX)) {
val lookupMethodName = callNode.name.take(callNode.name.length - Defines.YIELD_SUFFIX.length)
methodNamesWithYield.contains(lookupMethodName)
} else {
false
}
if (isYieldMethod) {
val methAst = astForBlock(ctx.block(), Some(callNode.name))
blockMethods.addOne(methAst)
methAst
.collectFirst { case x: Ast if x.root.isDefined && x.root.get.isInstanceOf[NewMethod] => x }
.foreach(blockMethods.addOne)
Seq(callAst(callNode, parenAst))
} else if (callNode.name == Defines.DEFINE_METHOD) {
parenAst.headOption
.flatMap(_.root)
.collect { case x: AstNodeNew => stripQuotes(x.code).stripPrefix(":") }
.map(methodName => astForBlock(ctx.block(), Option(methodName)))
.getOrElse(Seq.empty)
} else {
val blockAst = Seq(astForBlock(ctx.block()))
val blockAst = astForBlock(ctx.block())
Seq(callAst(callNode, parenAst ++ blockAst))
}
} else
Expand All @@ -589,7 +584,7 @@ class AstCreator(

def astForCallNode(ctx: ParserRuleContext, code: String, isYieldBlock: Boolean = false): Ast = {
val name = if (isYieldBlock) {
s"${resolveAlias(text(ctx))}$YIELD_SUFFIX"
s"${resolveAlias(text(ctx))}${Defines.YIELD_SUFFIX}"
} else {
val calleeName = resolveAlias(text(ctx))
// Add the call name to the global builtIn callNames set
Expand Down Expand Up @@ -664,27 +659,27 @@ class AstCreator(
private def astForCommandWithDoBlockContext(ctx: CommandWithDoBlockContext): Seq[Ast] = ctx match {
case ctx: ArgsAndDoBlockCommandWithDoBlockContext =>
val argsAsts = astForArguments(ctx.argumentsWithoutParentheses().arguments())
val doBlockAst = Seq(astForDoBlock(ctx.doBlock()))
val doBlockAst = astForDoBlock(ctx.doBlock())
argsAsts ++ doBlockAst
case ctx: RubyParser.ArgsAndDoBlockAndMethodIdCommandWithDoBlockContext =>
val methodIdAsts = astForMethodIdentifierContext(ctx.methodIdentifier(), text(ctx))
methodIdAsts.headOption.flatMap(_.root) match
case Some(methodIdRoot: NewCall) if methodIdRoot.name == "define_method" =>
case Some(methodIdRoot: NewCall) if methodIdRoot.name == Defines.DEFINE_METHOD =>
ctx.argumentsWithoutParentheses.arguments.argument.asScala.headOption
.map { methodArg =>
// TODO: methodArg will name the method, but this could be an identifier or even a string concatenation
// which is not assumed below
val methodName = stripQuotes(methodArg.getText)
Seq(astForDoBlock(ctx.doBlock(), Option(methodName)))
val methodName = stripQuotes(methodArg.getText).stripPrefix(":")
astForDoBlock(ctx.doBlock(), Option(methodName))
}
.getOrElse(Seq.empty)
case _ =>
val argsAsts = astForArguments(ctx.argumentsWithoutParentheses().arguments())
val doBlockAsts = Seq(astForDoBlock(ctx.doBlock()))
val doBlockAsts = astForDoBlock(ctx.doBlock())
methodIdAsts ++ argsAsts ++ doBlockAsts
case ctx: RubyParser.PrimaryMethodArgsDoBlockCommandWithDoBlockContext =>
val argsAsts = astForArguments(ctx.argumentsWithoutParentheses().arguments())
val doBlockAsts = Seq(astForDoBlock(ctx.doBlock()))
val doBlockAsts = astForDoBlock(ctx.doBlock())
val methodNameAsts = astForMethodNameContext(ctx.methodName())
val primaryAsts = astForPrimaryContext(ctx.primary())
primaryAsts ++ methodNameAsts ++ argsAsts ++ doBlockAsts
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import io.joern.rubysrc2cpg.parser.RubyParser.*
import io.joern.rubysrc2cpg.passes.Defines
import io.joern.rubysrc2cpg.passes.Defines.getBuiltInType
import io.joern.x2cpg.{Ast, ValidationMode}
import io.shiftleft.codepropertygraph.generated.nodes.{AstNodeNew, NewCall, NewIdentifier}
import io.shiftleft.codepropertygraph.generated.nodes.{AstNodeNew, NewCall, NewIdentifier, NewMethod}
import io.shiftleft.codepropertygraph.generated.{ControlStructureTypes, DispatchTypes, ModifierTypes, Operators}
import org.antlr.v4.runtime.ParserRuleContext
import org.slf4j.LoggerFactory
Expand Down Expand Up @@ -196,12 +196,20 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) {
.lineNumber(ctx.op.getLine)
.columnNumber(ctx.op.getCharPositionInLine)
if (leftAst.size == 1 && rightAst.size > 1) {
/*
* This is multiple RHS packed into a single LHS. That is, packing left hand side.
* This is as good as multiple RHS packed into an array and put into a single LHS
*/
val packedRHS = getPackedRHS(rightAst, wrapInBrackets = true)
Seq(callAst(opCallNode, leftAst ++ packedRHS))
if (rightAst.headOption.flatMap(_.root).exists(_.isInstanceOf[NewMethod])) {
/*
* Here we expect to be assigned the result of some dynamically defined function's application to some variable
*/
val lastAst = rightAst.takeRight(1)
rightAst.filterNot(_ == lastAst.head) ++ Seq(callAst(opCallNode, leftAst ++ lastAst))
} else {
/*
* This is multiple RHS packed into a single LHS. That is, packing left hand side.
* This is as good as multiple RHS packed into an array and put into a single LHS
*/
val packedRHS = getPackedRHS(rightAst, wrapInBrackets = true)
Seq(callAst(opCallNode, leftAst ++ packedRHS))
}
} else {
Seq(callAst(opCallNode, leftAst ++ rightAst))
}
Expand Down Expand Up @@ -374,7 +382,8 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) {

protected def astForYieldCall(ctx: ParserRuleContext, argumentsCtx: Option[ArgumentsContext]): Ast = {
val args = argumentsCtx.map(astForArguments).getOrElse(Seq())
val call = callNode(ctx, text(ctx), UNRESOLVED_YIELD, UNRESOLVED_YIELD, DispatchTypes.STATIC_DISPATCH)
val call =
callNode(ctx, text(ctx), Defines.UNRESOLVED_YIELD, Defines.UNRESOLVED_YIELD, DispatchTypes.STATIC_DISPATCH)
callAst(call, args)
}

Expand Down
Loading