diff --git a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstCreator.scala b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstCreator.scala index c3f9852e185f..199116b715e9 100644 --- a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstCreator.scala +++ b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstCreator.scala @@ -106,7 +106,11 @@ class AstCreator( programCtx.compoundStatement() != null && programCtx.compoundStatement().statements() != null ) { - astForStatements(programCtx.compoundStatement().statements(), false, false) ++ blockMethods + val stmts = astForStatements(programCtx.compoundStatement().statements(), false, false) + val methodNodes = stmts.flatMap(_.nodes).collect { case x: NewMethod => x }.toSet + // `blockMethods` is largely unnecessary, but it is kept for AST functions that still populate it instead of + // attaching the method to the AST; entries whose method already occurs in `stmts` are dropped here + stmts ++ blockMethods.filterNot(_.root.collect { case x: NewMethod => x }.exists(methodNodes.contains)) } else { logger.error(s"File $filename has no compound statement. 
Needs to be examined") List[Ast](Ast()) @@ -139,23 +143,7 @@ class AstCreator( .filterNot(_.astParentType == NodeTypes.TYPE_DECL) .map { methodNode => // Create a methodRefNode and assign it to the identifier version of the method, which will help in type propagation to resolve calls - val methodRefNode = NewMethodRef() - .code("def " + methodNode.name + "(...)") - .methodFullName(methodNode.fullName) - .typeFullName(methodNode.fullName) - .lineNumber(lineColNum) - .columnNumber(lineColNum) - - val methodNameIdentifier = NewIdentifier() - .code(methodNode.name) - .name(methodNode.name) - .typeFullName(Defines.Any) - .lineNumber(lineColNum) - .columnNumber(lineColNum) - scope.addToScope(methodNode.name, methodNameIdentifier) - val methodRefAssignmentAst = - astForAssignment(methodNameIdentifier, methodRefNode, methodNode.lineNumber, methodNode.columnNumber) - methodRefAssignmentAst + methodRefAssignmentFromMethod(methodNode, Option(lineColNum), Option(lineColNum)) } .toList @@ -342,29 +330,28 @@ class AstCreator( lineEnd(ctx).head, columnEnd(ctx).head ) - val blockMethodNode = - blockMethodAsts.head.nodes.head - .asInstanceOf[NewMethod] - blockMethods.addOne(blockMethodAsts.head) - - val callNode = NewCall() - .name(blockName) - .methodFullName(blockMethodNode.fullName) - .typeFullName(Defines.Any) - .code(blockMethodNode.code) - .lineNumber(blockMethodNode.lineNumber) - .columnNumber(blockMethodNode.columnNumber) - .dispatchType(DispatchTypes.STATIC_DISPATCH) - - val methodRefNode = NewMethodRef() - .methodFullName(blockMethodNode.fullName) - .typeFullName(Defines.Any) - .code(blockMethodNode.code) - .lineNumber(blockMethodNode.lineNumber) - .columnNumber(blockMethodNode.columnNumber) + blockMethodAsts.foreach { ast => + ast.root match + case Some(_: NewMethod) => blockMethods.addOne(ast) + case _ => + } - Seq(callAst(callNode, argsAst ++ Seq(Ast(methodRefNode)), primaryAst.headOption)) + blockMethodAsts :+ blockMethodAsts + .flatMap(_.nodes) + .collectFirst { case 
methodRefNode: NewMethodRef => + val callNode = NewCall() + .name(blockName) + .methodFullName(methodRefNode.methodFullName) + .typeFullName(Defines.Any) + .code(methodRefNode.code) + .lineNumber(methodRefNode.lineNumber) + .columnNumber(methodRefNode.columnNumber) + .dispatchType(DispatchTypes.STATIC_DISPATCH) + // TODO: primaryAst.headOption is broken when primaryAst is an array + callAst(callNode, argsAst ++ Seq(Ast(methodRefNode.copy)), primaryAst.headOption) + } + .getOrElse(Ast()) } else { val callNode = methodNameAst.head.nodes .filter(node => node.isInstanceOf[NewCall]) @@ -422,15 +409,15 @@ class AstCreator( val baseAst = astForPrimaryContext(ctx.primary()) val blocksAst = if (ctx.block() != null) { - Seq(astForBlock(ctx.block())) + astForBlock(ctx.block()) } else { Seq() } val callNode = methodNameAst.head.nodes.filter(node => node.isInstanceOf[NewCall]).head.asInstanceOf[NewCall] callNode .code(text(ctx)) - .lineNumber(ctx.COLON2().getSymbol().getLine()) - .columnNumber(ctx.COLON2().getSymbol().getCharPositionInLine()) + .lineNumber(ctx.COLON2.lineNumber) + .columnNumber(ctx.COLON2.columnNumber) Seq(callAst(callNode, baseAst ++ blocksAst)) } @@ -534,8 +521,8 @@ class AstCreator( .asInstanceOf[NewCall] .name - val isYieldMethod = if (blockName.endsWith(YIELD_SUFFIX)) { - val lookupMethodName = blockName.take(blockName.length - YIELD_SUFFIX.length) + val isYieldMethod = if (blockName.endsWith(Defines.YIELD_SUFFIX)) { + val lookupMethodName = blockName.take(blockName.length - Defines.YIELD_SUFFIX.length) methodNamesWithYield.contains(lookupMethodName) } else { false @@ -555,7 +542,7 @@ class AstCreator( columnEnd(ctx).head ) } else { - val blockAst = Seq(astForBlock(ctx.block())) + val blockAst = astForBlock(ctx.block()) // this is expected to be a call node val callNode = methodIdAst.head.nodes.head.asInstanceOf[NewCall] Seq(callAst(callNode, blockAst)) @@ -569,18 +556,26 @@ class AstCreator( callNode.name(resolveAlias(callNode.name)) if (ctx.block() != 
null) { - val isYieldMethod = if (callNode.name.endsWith(YIELD_SUFFIX)) { - val lookupMethodName = callNode.name.take(callNode.name.length - YIELD_SUFFIX.length) + val isYieldMethod = if (callNode.name.endsWith(Defines.YIELD_SUFFIX)) { + val lookupMethodName = callNode.name.take(callNode.name.length - Defines.YIELD_SUFFIX.length) methodNamesWithYield.contains(lookupMethodName) } else { false } if (isYieldMethod) { val methAst = astForBlock(ctx.block(), Some(callNode.name)) - blockMethods.addOne(methAst) + methAst + .collectFirst { case x: Ast if x.root.isDefined && x.root.get.isInstanceOf[NewMethod] => x } + .foreach(blockMethods.addOne) Seq(callAst(callNode, parenAst)) + } else if (callNode.name == Defines.DEFINE_METHOD) { + parenAst.headOption + .flatMap(_.root) + .collect { case x: AstNodeNew => stripQuotes(x.code).stripPrefix(":") } + .map(methodName => astForBlock(ctx.block(), Option(methodName))) + .getOrElse(Seq.empty) } else { - val blockAst = Seq(astForBlock(ctx.block())) + val blockAst = astForBlock(ctx.block()) Seq(callAst(callNode, parenAst ++ blockAst)) } } else @@ -589,7 +584,7 @@ class AstCreator( def astForCallNode(ctx: ParserRuleContext, code: String, isYieldBlock: Boolean = false): Ast = { val name = if (isYieldBlock) { - s"${resolveAlias(text(ctx))}$YIELD_SUFFIX" + s"${resolveAlias(text(ctx))}${Defines.YIELD_SUFFIX}" } else { val calleeName = resolveAlias(text(ctx)) // Add the call name to the global builtIn callNames set @@ -664,27 +659,27 @@ class AstCreator( private def astForCommandWithDoBlockContext(ctx: CommandWithDoBlockContext): Seq[Ast] = ctx match { case ctx: ArgsAndDoBlockCommandWithDoBlockContext => val argsAsts = astForArguments(ctx.argumentsWithoutParentheses().arguments()) - val doBlockAst = Seq(astForDoBlock(ctx.doBlock())) + val doBlockAst = astForDoBlock(ctx.doBlock()) argsAsts ++ doBlockAst case ctx: RubyParser.ArgsAndDoBlockAndMethodIdCommandWithDoBlockContext => val methodIdAsts = 
astForMethodIdentifierContext(ctx.methodIdentifier(), text(ctx)) methodIdAsts.headOption.flatMap(_.root) match - case Some(methodIdRoot: NewCall) if methodIdRoot.name == "define_method" => + case Some(methodIdRoot: NewCall) if methodIdRoot.name == Defines.DEFINE_METHOD => ctx.argumentsWithoutParentheses.arguments.argument.asScala.headOption .map { methodArg => // TODO: methodArg will name the method, but this could be an identifier or even a string concatenation // which is not assumed below - val methodName = stripQuotes(methodArg.getText) - Seq(astForDoBlock(ctx.doBlock(), Option(methodName))) + val methodName = stripQuotes(methodArg.getText).stripPrefix(":") + astForDoBlock(ctx.doBlock(), Option(methodName)) } .getOrElse(Seq.empty) case _ => val argsAsts = astForArguments(ctx.argumentsWithoutParentheses().arguments()) - val doBlockAsts = Seq(astForDoBlock(ctx.doBlock())) + val doBlockAsts = astForDoBlock(ctx.doBlock()) methodIdAsts ++ argsAsts ++ doBlockAsts case ctx: RubyParser.PrimaryMethodArgsDoBlockCommandWithDoBlockContext => val argsAsts = astForArguments(ctx.argumentsWithoutParentheses().arguments()) - val doBlockAsts = Seq(astForDoBlock(ctx.doBlock())) + val doBlockAsts = astForDoBlock(ctx.doBlock()) val methodNameAsts = astForMethodNameContext(ctx.methodName()) val primaryAsts = astForPrimaryContext(ctx.primary()) primaryAsts ++ methodNameAsts ++ argsAsts ++ doBlockAsts diff --git a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForExpressionsCreator.scala b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForExpressionsCreator.scala index 44d2c8bc4e09..51f9917d7f20 100644 --- a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForExpressionsCreator.scala +++ b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForExpressionsCreator.scala @@ -4,7 +4,7 @@ import io.joern.rubysrc2cpg.parser.RubyParser.* import 
io.joern.rubysrc2cpg.passes.Defines import io.joern.rubysrc2cpg.passes.Defines.getBuiltInType import io.joern.x2cpg.{Ast, ValidationMode} -import io.shiftleft.codepropertygraph.generated.nodes.{AstNodeNew, NewCall, NewIdentifier} +import io.shiftleft.codepropertygraph.generated.nodes.{AstNodeNew, NewCall, NewIdentifier, NewMethod} import io.shiftleft.codepropertygraph.generated.{ControlStructureTypes, DispatchTypes, ModifierTypes, Operators} import org.antlr.v4.runtime.ParserRuleContext import org.slf4j.LoggerFactory @@ -196,12 +196,20 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) { .lineNumber(ctx.op.getLine) .columnNumber(ctx.op.getCharPositionInLine) if (leftAst.size == 1 && rightAst.size > 1) { - /* - * This is multiple RHS packed into a single LHS. That is, packing left hand side. - * This is as good as multiple RHS packed into an array and put into a single LHS - */ - val packedRHS = getPackedRHS(rightAst, wrapInBrackets = true) - Seq(callAst(opCallNode, leftAst ++ packedRHS)) + if (rightAst.headOption.flatMap(_.root).exists(_.isInstanceOf[NewMethod])) { + /* + * Here we expect to be assigned the result of some dynamically defined function's application to some variable + */ + val lastAst = rightAst.takeRight(1) + rightAst.filterNot(_ == lastAst.head) ++ Seq(callAst(opCallNode, leftAst ++ lastAst)) + } else { + /* + * This is multiple RHS packed into a single LHS. That is, packing left hand side. 
+ * This is as good as multiple RHS packed into an array and put into a single LHS + */ + val packedRHS = getPackedRHS(rightAst, wrapInBrackets = true) + Seq(callAst(opCallNode, leftAst ++ packedRHS)) + } } else { Seq(callAst(opCallNode, leftAst ++ rightAst)) } @@ -374,7 +382,8 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) { protected def astForYieldCall(ctx: ParserRuleContext, argumentsCtx: Option[ArgumentsContext]): Ast = { val args = argumentsCtx.map(astForArguments).getOrElse(Seq()) - val call = callNode(ctx, text(ctx), UNRESOLVED_YIELD, UNRESOLVED_YIELD, DispatchTypes.STATIC_DISPATCH) + val call = + callNode(ctx, text(ctx), Defines.UNRESOLVED_YIELD, Defines.UNRESOLVED_YIELD, DispatchTypes.STATIC_DISPATCH) callAst(call, args) } diff --git a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForFunctionsCreator.scala b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForFunctionsCreator.scala index 5e595a231378..137c87a20474 100644 --- a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForFunctionsCreator.scala +++ b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForFunctionsCreator.scala @@ -19,17 +19,6 @@ trait AstForFunctionsCreator(packageContext: PackageContext)(implicit withSchema private val logger = LoggerFactory.getLogger(getClass) - /* - *Fake methods created from yield blocks and their yield calls will have this suffix in their names - */ - protected val YIELD_SUFFIX = "_yield" - - /* - * This is used to mark call nodes created due to yield calls. This is set in their names at creation. - * The appropriate name wrt the names of their actual methods is set later in them. 
- */ - protected val UNRESOLVED_YIELD = "unresolved_yield" - /* * Stack of variable identifiers incorrectly identified as method identifiers * Each AST contains exactly one call or identifier node @@ -101,12 +90,12 @@ trait AstForFunctionsCreator(packageContext: PackageContext)(implicit withSchema // process yield calls. astBody - .flatMap(_.nodes.collect { case x: NewCall => x }.filter(_.name == UNRESOLVED_YIELD)) + .flatMap(_.nodes.collect { case x: NewCall => x }.filter(_.name == Defines.UNRESOLVED_YIELD)) .foreach { yieldCallNode => val name = newMethodNode.name val methodFullName = classStack.reverse :+ callNode.name mkString pathSep - yieldCallNode.name(name + YIELD_SUFFIX) - yieldCallNode.methodFullName(methodFullName + YIELD_SUFFIX) + yieldCallNode.name(name + Defines.YIELD_SUFFIX) + yieldCallNode.methodFullName(methodFullName + Defines.YIELD_SUFFIX) methodNamesWithYield.add(newMethodNode.name) /* * These are calls to the yield block of this method. @@ -348,6 +337,8 @@ trait AstForFunctionsCreator(packageContext: PackageContext)(implicit withSchema } } + /** Creates a method, methodRef, and type decl binding for this block method. 
+ */ protected def astForBlockFunction( ctxStmt: StatementsContext, ctxParam: Option[BlockParameterContext], @@ -363,17 +354,17 @@ trait AstForFunctionsCreator(packageContext: PackageContext)(implicit withSchema val methodFullName = classStack.reverse :+ blockMethodName mkString pathSep val newMethodNode = methodNode(ctxStmt, blockMethodName, text(ctxStmt), methodFullName, None, relativeFilename) .lineNumber(lineStart) - .lineNumberEnd(lineEnd) + .lineNumberEnd(lineEnd + 1) // this requires a +1 due to the `end` token .columnNumber(colStart) .columnNumberEnd(colEnd) scope.pushNewScope(newMethodNode) val astMethodParam = ctxParam.map(astForBlockParameterContext).getOrElse(Seq()) - val publicModifier = NewModifier().modifierType(ModifierTypes.PUBLIC) val paramSeq = astMethodParam.flatMap(_.root).map { /* In majority of cases, node will be an identifier */ case identifierNode: NewIdentifier => + scope.removeFromScope(identifierNode) val param = NewMethodParameterIn() .name(identifierNode.name) .code(identifierNode.code) @@ -389,19 +380,28 @@ trait AstForFunctionsCreator(packageContext: PackageContext)(implicit withSchema case _ => Ast() } - val paramNames = (astMethodParam ++ paramSeq) + val paramNames = paramSeq .flatMap(_.root) .collect { case x: NewMethodParameterIn => x.name case x: NewIdentifier => x.name } .toSet - val astBody = astForStatements(ctxStmt, true) - val locals = scope.createAndLinkLocalNodes(diffGraph, paramNames).map(Ast.apply) + + val astBody = astForStatements(ctxStmt, true) + val locals = scope.createAndLinkLocalNodes(diffGraph, paramNames).map(Ast.apply) + paramSeq.flatMap(_.root).collect { case x: NewMethodParameterIn => x }.foreach(scope.linkParamNode(diffGraph, _)) val methodRetNode = NewMethodReturn().typeFullName(Defines.Any) scope.popScope() + // Create a method ref & type binding for this node + val methodRefAssignmentAst = methodRefAssignmentFromMethod(newMethodNode) + val binding = NewBinding() + .name(blockMethodName) + 
+ .methodFullName(methodFullName) + val typeDecl = typeDeclFromMethod(newMethodNode) + Seq( methodAst( newMethodNode, @@ -409,8 +409,62 @@ trait AstForFunctionsCreator(packageContext: PackageContext)(implicit withSchema blockAst(blockNode(ctxStmt), locals ++ astBody.toList), methodRetNode, Seq(publicModifier) - ) + ), + methodRefAssignmentAst, + Ast(typeDecl).withBindsEdge(typeDecl, binding).withRefEdge(binding, newMethodNode) ) } + /** Resolves the position to attach to nodes derived from `method`: an explicitly given line/column takes + * precedence, otherwise the method's own (possibly undefined) line/column is used. + */ + private def methodPositionWithFallback( + method: NewMethod, + lineNum: Option[Integer] = None, + colNum: Option[Integer] = None + ): (Option[Integer], Option[Integer]) = + (lineNum.orElse(method.lineNumber), colNum.orElse(method.columnNumber)) + + /** Creates a method ref node assigned to an identifier of the same name from a method and adds the identifier to the + * scope. 
+ */ + protected def methodRefAssignmentFromMethod( + method: NewMethod, + lineNum: Option[Integer] = None, + colNum: Option[Integer] = None + ): Ast = { + val (lineNumber, columnNumber) = methodPositionWithFallback(method, lineNum, colNum) + val methodRefNode = NewMethodRef() + .code("def " + method.name + "(...)") + .methodFullName(method.fullName) + .typeFullName(method.fullName) + .lineNumber(lineNumber) + .columnNumber(columnNumber) + + val methodNameIdentifier = NewIdentifier() + .code(method.name) + .name(method.name) + .typeFullName(Defines.Any) + .lineNumber(lineNumber) + .columnNumber(columnNumber) + scope.addToScope(method.name, methodNameIdentifier) + val methodRefAssignmentAst = + astForAssignment(methodNameIdentifier, methodRefNode, lineNumber, columnNumber) + methodRefAssignmentAst + } + + protected def typeDeclFromMethod( + method: NewMethod, + lineNum: Option[Integer] = None, + colNum: Option[Integer] = None + ): NewTypeDecl = { + val (lineNumber, columnNumber) = methodPositionWithFallback(method, lineNum, colNum) + NewTypeDecl() + .code("def " + method.name + "(...)") + .name(method.name) + .fullName(method.fullName) + .lineNumber(lineNumber) + .columnNumber(columnNumber) + } + } diff --git a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForStatementsCreator.scala b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForStatementsCreator.scala index 432b4944f4af..c71d81af4d37 100644 --- a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForStatementsCreator.scala +++ b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForStatementsCreator.scala @@ -369,7 +369,7 @@ trait AstForStatementsCreator(filename: String)(implicit withSchemaValidation: V } } - protected def astForBlock(ctx: BlockContext, blockMethodName: Option[String] = None): Ast = ctx match + protected def astForBlock(ctx: BlockContext, blockMethodName: 
Option[String] = None): Seq[Ast] = ctx match case ctx: DoBlockBlockContext => astForDoBlock(ctx.doBlock(), blockMethodName) case ctx: BraceBlockBlockContext => astForBraceBlock(ctx.braceBlock(), blockMethodName) @@ -378,7 +378,7 @@ trait AstForStatementsCreator(filename: String)(implicit withSchemaValidation: V blockParamCtx: Option[BlockParameterContext], compoundStmtCtx: CompoundStatementContext, blockMethodName: Option[String] = None - ) = { + ): Seq[Ast] = { blockMethodName match { case Some(blockMethodName) => astForBlockFunction( @@ -389,20 +389,20 @@ trait AstForStatementsCreator(filename: String)(implicit withSchemaValidation: V lineEnd(compoundStmtCtx).head, column(compoundStmtCtx).head, columnEnd(compoundStmtCtx).head - ).head + ) case None => val blockNode_ = blockNode(ctx, text(ctx), Defines.Any) val blockBodyAst = astForCompoundStatement(compoundStmtCtx) val blockParamAst = blockParamCtx.flatMap(astForBlockParameterContext) - blockAst(blockNode_, blockBodyAst.toList ++ blockParamAst) + Seq(blockAst(blockNode_, blockBodyAst.toList ++ blockParamAst)) } } - protected def astForDoBlock(ctx: DoBlockContext, blockMethodName: Option[String] = None): Ast = { + protected def astForDoBlock(ctx: DoBlockContext, blockMethodName: Option[String] = None): Seq[Ast] = { astForBlockHelper(ctx, Option(ctx.blockParameter), ctx.bodyStatement().compoundStatement(), blockMethodName) } - private def astForBraceBlock(ctx: BraceBlockContext, blockMethodName: Option[String] = None): Ast = { + private def astForBraceBlock(ctx: BraceBlockContext, blockMethodName: Option[String] = None): Seq[Ast] = { astForBlockHelper(ctx, Option(ctx.blockParameter), ctx.bodyStatement().compoundStatement(), blockMethodName) } diff --git a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/RubyScope.scala b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/RubyScope.scala index 393a860cbc44..5871be927e8f 100644 --- 
a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/RubyScope.scala +++ b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/RubyScope.scala @@ -1,8 +1,9 @@ package io.joern.rubysrc2cpg.astcreation +import io.joern.x2cpg.Ast import io.joern.x2cpg.datastructures.Scope import io.shiftleft.codepropertygraph.generated.EdgeTypes -import io.shiftleft.codepropertygraph.generated.nodes.{DeclarationNew, NewIdentifier, NewLocal, NewNode} +import io.shiftleft.codepropertygraph.generated.nodes.* import overflowdb.BatchedUpdate import scala.collection.mutable @@ -30,6 +31,10 @@ class RubyScope extends Scope[String, NewIdentifier, NewNode] { scopeNode } + def removeFromScope(variable: NewIdentifier): Unit = { + stack.headOption.foreach(head => scopeToVarMap.removeIdentifierFromVarGroup(head.scopeNode, variable)) + } + override def popScope(): Option[NewNode] = { stack.headOption.map(_.scopeNode).foreach(scopeToVarMap.remove) super.popScope() @@ -42,9 +47,18 @@ class RubyScope extends Scope[String, NewIdentifier, NewNode] { def createAndLinkLocalNodes( diffGraph: BatchedUpdate.DiffGraphBuilder, paramNames: Set[String] = Set.empty - ): List[DeclarationNew] = stack.headOption match - case Some(top) => scopeToVarMap.buildVariableGroupings(top.scopeNode, paramNames ++ Set("this"), diffGraph) - case None => List.empty[DeclarationNew] + ): List[DeclarationNew] = { + stack.headOption match + case Some(top) => scopeToVarMap.buildVariableGroupings(top.scopeNode, paramNames ++ Set("this"), diffGraph) + case None => List.empty[DeclarationNew] + } + + /** Links the parameter node to the referenced identifiers in this scope. 
+ */ + def linkParamNode(diffGraph: BatchedUpdate.DiffGraphBuilder, param: NewMethodParameterIn): Unit = + stack.headOption match + case Some(top) => scopeToVarMap.buildParameterGrouping(top.scopeNode, param, diffGraph) + case None => () private implicit class IdentifierExt(node: NewIdentifier) { @@ -77,6 +91,18 @@ class RubyScope extends Scope[String, NewIdentifier, NewNode] { Some(Map(identifier.name -> identifier.toNewVarGroup)) } + /** Removes the identifier from the var group of the same name in the given scope, if such a group exists. + */ + def removeIdentifierFromVarGroup(key: ScopeNodeType, identifier: NewIdentifier): Unit = + scopeMap.updateWith(key) { + case Some(varMap: VarMap) => + Some(varMap.updatedWith(identifier.name) { + case Some(varGroup: VarGroup) => Some(varGroup.copy(ids = varGroup.ids.filterNot(_ == identifier))) + case None => None + }) + case None => None + } + /** Will persist the variable groupings that do not represent parameter nodes and link them with REF edges. * @return * the list of persisted local nodes. @@ -96,6 +122,23 @@ class RubyScope extends Scope[String, NewIdentifier, NewNode] { } .toList case None => List.empty[DeclarationNew] + + /** Will persist a REF edge between the given parameter and its corresponding identifiers. 
+ */ + def buildParameterGrouping( + key: ScopeNodeType, + param: NewMethodParameterIn, + diffGraph: BatchedUpdate.DiffGraphBuilder + ): Unit = { + scopeMap + .get(key) + .map(_.values) + .foreach(_.filter { case VarGroup(local, _) => local.name == param.name } + .foreach { case VarGroup(_, ids) => + ids.foreach(id => diffGraph.addEdge(id, param, EdgeTypes.REF)) + }) + } + } } diff --git a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/passes/Defines.scala b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/passes/Defines.scala index 42b32a2d4f69..80edb4858ffa 100644 --- a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/passes/Defines.scala +++ b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/passes/Defines.scala @@ -35,5 +35,21 @@ object Defines { // Constructor method val Initialize = "initialize" + /* + * Fake methods created from yield blocks, and the yield calls to them, carry this suffix in their names + */ + val YIELD_SUFFIX = "_yield" + + /* + * Marks call nodes created for yield calls: it is set as their name when they are created. + * The final name, derived from the name of the enclosing method, is assigned to them later. + */ + val UNRESOLVED_YIELD = "unresolved_yield" + + /* + * Ruby supports dynamic method declaration via the metaprogramming method `define_method`. 
+ */ + val DEFINE_METHOD = "define_method" + def getBuiltInType(typeInString: String) = s"${GlobalTypes.builtinPrefix}.$typeInString" } diff --git a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/passes/ast/MethodTwoTests.scala b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/passes/ast/MethodTwoTests.scala index 86bb049f5cf4..8f8605204d4a 100644 --- a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/passes/ast/MethodTwoTests.scala +++ b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/passes/ast/MethodTwoTests.scala @@ -1,9 +1,7 @@ package io.joern.rubysrc2cpg.passes.ast import io.joern.rubysrc2cpg.testfixtures.RubyCode2CpgFixture -import io.shiftleft.codepropertygraph.generated.{EvaluationStrategies, NodeTypes} import io.shiftleft.semanticcpg.language.* -import io.shiftleft.semanticcpg.language.types.structure.NamespaceTraversal class MethodTwoTests extends RubyCode2CpgFixture { @@ -14,88 +12,76 @@ class MethodTwoTests extends RubyCode2CpgFixture { |end |""".stripMargin) - // TODO: This test cases needs to be fixed. - "should contain exactly one method node with correct fields" ignore { + "should contain exactly one method node with correct fields" in { inside(cpg.method.name("foo").l) { case List(x) => x.name shouldBe "foo" x.isExternal shouldBe false - x.fullName shouldBe "Test0.rb::program:foo" - x.code should startWith("def foo(a, b)") + x.fullName shouldBe "Test0.rb::program.foo" + x.code should startWith("return \"\"") x.isExternal shouldBe false - x.order shouldBe 1 + x.order shouldBe 2 x.filename endsWith "Test0.rb" x.lineNumber shouldBe Option(2) x.lineNumberEnd shouldBe Option(4) } } - // TODO: This test cases needs to be fixed. - "should return correct number of lines" ignore { + "should return correct number of lines" in { cpg.method.name("foo").numberOfLines.l shouldBe List(3) } - // TODO: This test cases needs to be fixed. 
- "should allow traversing to parameters" ignore { + "should allow traversing to parameters" in { cpg.method.name("foo").parameter.name.toSetMutable shouldBe Set("a", "b") } - // TODO: This test cases needs to be fixed. - "should allow traversing to methodReturn" ignore { + "should allow traversing to methodReturn" in { cpg.method.name("foo").methodReturn.l.size shouldBe 1 cpg.method.name("foo").methodReturn.typeFullName.head shouldBe "ANY" } - // TODO: This test cases needs to be fixed. - "should allow traversing to method" ignore { - cpg.methodReturn.method.name.l shouldBe List("foo", ":program") + "should allow traversing to method" in { + cpg.methodReturn.method.isExternal(false).name.l shouldBe List("foo", ":program") } - // TODO: This test cases needs to be fixed. - "should allow traversing to file" ignore { + "should allow traversing to file" in { cpg.method.name("foo").file.name.l should not be empty } - // TODO: Need to be fixed - "test function method ref" ignore { + "test function method ref" in { cpg.methodRef("foo").referencedMethod.fullName.l should not be empty - cpg.methodRef("foo").referencedMethod.fullName.head shouldBe - "Test0.rb::program:foo" + cpg.methodRef("foo").referencedMethod.fullName.head shouldBe "Test0.rb::program.foo" } - // TODO: Need to be fixed. - "test existence of local variable in module function" ignore { + "test existence of local variable in module function" in { cpg.method.fullName("Test0.rb::program").local.name.l should contain("foo") } - // TODO: need to be fixed. 
- "test corresponding type, typeDecl and binding" ignore { - cpg.method.fullName("Test0.rb::program:foo").referencingBinding.bindingTypeDecl.l should not be empty + "test corresponding type, typeDecl and binding" in { + cpg.method.fullName("Test0.rb::program.foo").referencingBinding.bindingTypeDecl.l should not be empty val bindingTypeDecl = - cpg.method.fullName("Test0.rb::program:foo").referencingBinding.bindingTypeDecl.head + cpg.method.fullName("Test0.rb::program.foo").referencingBinding.bindingTypeDecl.head bindingTypeDecl.name shouldBe "foo" - bindingTypeDecl.fullName shouldBe "Test0.rb::program:foo" + bindingTypeDecl.fullName shouldBe "Test0.rb::program.foo" bindingTypeDecl.referencingType.name.head shouldBe "foo" - bindingTypeDecl.referencingType.fullName.head shouldBe "Test0.rb::program:foo" + bindingTypeDecl.referencingType.fullName.head shouldBe "Test0.rb::program.foo" } - // TODO: Need to be fixed - "test method parameter nodes" ignore { + "test method parameter nodes" in { cpg.method.name("foo").parameter.name.l.size shouldBe 2 - val parameter1 = cpg.method.fullName("Test0.rb::program:foo").parameter.order(1).head + val parameter1 = cpg.method.fullName("Test0.rb::program.foo").parameter.order(1).head parameter1.name shouldBe "a" parameter1.index shouldBe 1 parameter1.typeFullName shouldBe "ANY" - val parameter2 = cpg.method.fullName("Test0.rb::program:foo").parameter.order(2).head + val parameter2 = cpg.method.fullName("Test0.rb::program.foo").parameter.order(2).head parameter2.name shouldBe "b" parameter2.index shouldBe 2 parameter2.typeFullName shouldBe "ANY" } - // TODO: Need to be fixed - "should allow traversing from parameter to method" ignore { + "should allow traversing from parameter to method" in { cpg.parameter.name("a").method.name.l shouldBe List("foo") cpg.parameter.name("b").method.name.l shouldBe List("foo") } diff --git a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/passes/ast/SimpleAstCreationPassTest.scala 
b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/passes/ast/SimpleAstCreationPassTest.scala index 9e5ccfe1d5ef..97909ef2c446 100644 --- a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/passes/ast/SimpleAstCreationPassTest.scala +++ b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/passes/ast/SimpleAstCreationPassTest.scala @@ -733,7 +733,7 @@ class SimpleAstCreationPassTest extends RubyCode2CpgFixture { val cpg = code("object::foo do\nputs \"right here\"\nend") val List(callNode1) = cpg.call.name("foo").l - callNode1.code shouldBe "puts \"right here\"" + callNode1.code shouldBe "def foo1(...)" callNode1.lineNumber shouldBe Some(1) callNode1.columnNumber shouldBe Some(3) diff --git a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/ControlStructureTests.scala b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/ControlStructureTests.scala index 48e197ddcbd5..18c7cfd4625e 100644 --- a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/ControlStructureTests.scala +++ b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/ControlStructureTests.scala @@ -4,6 +4,7 @@ import io.joern.rubysrc2cpg.testfixtures.RubyCode2CpgFixture import io.shiftleft.codepropertygraph.generated.ControlStructureTypes import io.shiftleft.codepropertygraph.generated.nodes.{Block, ControlStructure} import io.shiftleft.semanticcpg.language.* + class ControlStructureTests extends RubyCode2CpgFixture { "CPG for code with doBlock iterating over a constant array" should { @@ -15,7 +16,7 @@ class ControlStructureTests extends RubyCode2CpgFixture { "recognise all identifier nodes" in { cpg.identifier.name("n").size shouldBe 1 - cpg.identifier.size shouldBe 2 // 1 identifier node is for `puts = typeDef(__builtin.puts)` + cpg.identifier.size shouldBe 3 // 1 identifier node is for `puts = typeDef(__builtin.puts)` and similarly for `each2` } 
"recognize all call nodes" in { @@ -56,8 +57,7 @@ class ControlStructureTests extends RubyCode2CpgFixture { "recognise all identifier nodes" in { cpg.identifier.name("n").size shouldBe 2 cpg.identifier.name("m").size shouldBe 1 - cpg.identifier.size shouldBe 5 - cpg.method.name("fakeName").dotAst.l + cpg.identifier.size shouldBe 6 // includes each2 = def each2(...) } "recognize all call nodes" in { diff --git a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/MiscTests.scala b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/MiscTests.scala index 76997217302a..d030f7f3a01a 100644 --- a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/MiscTests.scala +++ b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/MiscTests.scala @@ -54,7 +54,7 @@ class MiscTests extends RubyCode2CpgFixture { cpg.identifier.name("Formatter").size shouldBe 1 cpg.identifier.name("Logger").size shouldBe 1 cpg.identifier.name("log_formatter").size shouldBe 1 - cpg.identifier.size shouldBe 5 + cpg.identifier.size shouldBe 6 } }