diff --git a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForExpressionsCreator.scala b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForExpressionsCreator.scala index 730c218333eb..e8ef0f35fac2 100644 --- a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForExpressionsCreator.scala +++ b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForExpressionsCreator.scala @@ -266,16 +266,17 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) { } protected def astForObjectInstantiation(node: RubyNode & ObjectInstantiation): Ast = { - val className = node.target.text - val callName = "new" - val methodName = Defines.Initialize /* We short-cut the call edge from `new` call to `initialize` method, however we keep the modelling of the receiver as referring to the singleton class. */ - val (receiverTypeFullName, fullName) = scope.tryResolveTypeReference(className) match { - case Some(typeMetaData) => s"${typeMetaData.name}" -> s"${typeMetaData.name}.$methodName" - case None => XDefines.Any -> XDefines.DynamicCallUnknownFullName + val (receiverTypeFullName, fullName) = node.target match { + case x: (SimpleIdentifier | MemberAccess) => + scope.tryResolveTypeReference(x.text) match { + case Some(typeMetaData) => s"${typeMetaData.name}" -> s"${typeMetaData.name}.${Defines.Initialize}" + case None => XDefines.Any -> XDefines.DynamicCallUnknownFullName + } + case _ => XDefines.Any -> XDefines.DynamicCallUnknownFullName } /* Similarly to some other frontends, we lower the constructor into two operations, e.g., @@ -287,7 +288,7 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) { val tmpName = tmpGen.fresh val tmpTypeHint = receiverTypeFullName.stripSuffix("") - val tmp = SimpleIdentifier(Option(className))(node.span.spanStart(tmpName)) + val tmp = SimpleIdentifier(None)(node.span.spanStart(tmpName)) val tmpLocal = NewLocal().name(tmpName).code(tmpName).dynamicTypeHintFullName(Seq(tmpTypeHint)) scope.addToScope(tmpName, tmpLocal) @@ -298,12 +299,11 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) { } // Assign tmp to - val receiverAst = Ast(identifierNode(node, className, className, receiverTypeFullName)) - val allocCall = callNode(node, code(node), Operators.alloc, Operators.alloc, DispatchTypes.STATIC_DISPATCH) - val allocAst = callAst(allocCall, Seq.empty, Option(receiverAst)) + val allocCall = callNode(node, code(node), Operators.alloc, Operators.alloc, DispatchTypes.STATIC_DISPATCH) + val allocAst = callAst(allocCall, Seq.empty) val assignmentCall = callNode( node, - s"${tmp.text} = ${code(node)}", + s"${tmp.text} = ${code(node.target)}.${Defines.Initialize}", Operators.assignment, Operators.assignment, DispatchTypes.STATIC_DISPATCH @@ -318,8 +318,11 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) { x.arguments.map(astForMethodCallArgument) :+ typeRef } - val constructorCall = callNode(node, code(node), callName, fullName, DispatchTypes.DYNAMIC_DISPATCH) - val constructorCallAst = callAst(constructorCall, argumentAsts, Option(tmpIdentifier)) + val constructorCall = + callNode(node, code(node), Defines.Initialize, Defines.Any, DispatchTypes.DYNAMIC_DISPATCH) + if fullName != XDefines.DynamicCallUnknownFullName then constructorCall.dynamicTypeHintFullName(Seq(fullName)) + val constructorRecv = astForExpression(MemberAccess(node.target, ".", Defines.Initialize)(node.span)) + val constructorCallAst = callAst(constructorCall, argumentAsts, Option(tmpIdentifier), Option(constructorRecv)) val retIdentifierAst = tmpIdentifier scope.popScope() @@ -864,8 +867,9 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) { protected def astForFieldAccess(node: MemberAccess, stripLeadingAt: Boolean = false): Ast = { val (memberName, memberCode) = node.target match { - case _ if stripLeadingAt => node.memberName -> node.memberName.stripPrefix("@") - case _: TypeIdentifier => node.memberName -> node.memberName + case _ if node.memberName == Defines.Initialize => Defines.Initialize -> Defines.Initialize + case _ if stripLeadingAt => node.memberName -> node.memberName.stripPrefix("@") + case _: TypeIdentifier => node.memberName -> node.memberName case _ if !node.memberName.startsWith("@") && node.memberName.headOption.exists(_.isLower) => s"@${node.memberName}" -> node.memberName case _ => node.memberName -> node.memberName diff --git a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/parser/RubyNodeCreator.scala b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/parser/RubyNodeCreator.scala index 304724bab700..5e79d660cca2 100644 --- a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/parser/RubyNodeCreator.scala +++ b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/parser/RubyNodeCreator.scala @@ -692,10 +692,22 @@ class RubyNodeCreator extends RubyParserBaseVisitor[RubyNode] { } override def visitMemberAccessCommand(ctx: RubyParser.MemberAccessCommandContext): RubyNode = { - val args = ctx.commandArgument.arguments.map(visit) - val methodName = visit(ctx.methodName()) - val base = visit(ctx.primary()) - MemberCall(base, ".", methodName.text, args)(ctx.toTextSpan) + val args = ctx.commandArgument.arguments.map(visit) + val base = visit(ctx.primary()) + + if (ctx.methodName().getText == "new") { + base match { + case SingleAssignment(lhs, op, rhs) => + // fixme: Parser packaging arguments from a parenthesis-less object instantiation is odd + val assignSpan = base.span.spanStart(s"${base.span.text}.new") + val rhsSpan = rhs.span.spanStart(s"${rhs.span.text}.new") + SingleAssignment(lhs, op, SimpleObjectInstantiation(rhs, args)(rhsSpan))(assignSpan) + case _ => SimpleObjectInstantiation(base, args)(ctx.toTextSpan) + } + } else { + val methodName = visit(ctx.methodName()) + MemberCall(base, ".", methodName.text, args)(ctx.toTextSpan) + } } override def visitConstantIdentifierVariable(ctx: RubyParser.ConstantIdentifierVariableContext): RubyNode = { diff --git a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/ArrayTests.scala b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/ArrayTests.scala index 2d51a8b6aa45..3054d6c6db70 100644 --- a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/ArrayTests.scala +++ b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/ArrayTests.scala @@ -7,7 +7,6 @@ import io.shiftleft.codepropertygraph.generated.nodes.{Call, Literal} import io.shiftleft.semanticcpg.language.* import io.joern.rubysrc2cpg.passes.Defines import io.joern.x2cpg.Defines as XDefines -import io.shiftleft.codepropertygraph.generated.nodes.Literal class ArrayTests extends RubyCode2CpgFixture { diff --git a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/CallTests.scala b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/CallTests.scala index 1a09fbdb3169..d189f3c54fb7 100644 --- a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/CallTests.scala +++ b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/CallTests.scala @@ -6,7 +6,7 @@ import io.joern.rubysrc2cpg.passes.GlobalTypes.kernelPrefix import io.joern.rubysrc2cpg.testfixtures.RubyCode2CpgFixture import io.joern.x2cpg.Defines import io.shiftleft.codepropertygraph.generated.nodes.* -import io.shiftleft.codepropertygraph.generated.{DispatchTypes, Operators} +import io.shiftleft.codepropertygraph.generated.{DispatchTypes, NodeTypes, Operators} import io.shiftleft.semanticcpg.language.* class CallTests extends RubyCode2CpgFixture(withPostProcessing = true) { @@ -155,13 +155,15 @@ class CallTests extends RubyCode2CpgFixture(withPostProcessing = true) { "a simple object instantiation" should { val cpg = code("""class A + | def initialize(a, b) + | end |end | - |a = A.new + |a = A.new 1, 2 |""".stripMargin) - "create an assignment from `a` to an invocation block" in { - inside(cpg.method.isModule.assignment.where(_.target.isIdentifier.name("a")).l) { + "create an assignment from `a` to an alloc lowering invocation block" in { + inside(cpg.method.isModule.assignment.and(_.target.isIdentifier.name("a"), _.source.isBlock).l) { case assignment :: Nil => assignment.code shouldBe "a = A.new" inside(assignment.argument.l) { @@ -174,7 +176,7 @@ class CallTests extends RubyCode2CpgFixture(withPostProcessing = true) { } } - "create an assignment from a temp variable to the call" in { + "create an assignment from a temp variable to the alloc call" in { inside(cpg.method.isModule.assignment.where(_.target.isIdentifier.name("")).l) { case assignment :: Nil => inside(assignment.argument.l) { @@ -184,6 +186,7 @@ class CallTests extends RubyCode2CpgFixture(withPostProcessing = true) { alloc.name shouldBe Operators.alloc alloc.methodFullName shouldBe Operators.alloc alloc.code shouldBe "A.new" + alloc.argument.size shouldBe 0 case xs => fail(s"Expected one identifier and one call argument, got [${xs.code.mkString(",")}]") } case xs => fail(s"Expected a single assignment, got [${xs.code.mkString(",")}]") @@ -191,15 +194,68 @@ class CallTests extends RubyCode2CpgFixture(withPostProcessing = true) { } "create a call to the object's constructor, with the temp variable receiver" in { - inside(cpg.call.nameExact("new").l) { + inside(cpg.call.nameExact(RubyDefines.Initialize).l) { case constructor :: Nil => inside(constructor.argument.l) { - case (a: Identifier) :: Nil => + case (a: Identifier) :: (one: Literal) :: (two: Literal) :: Nil => a.name shouldBe "" a.typeFullName shouldBe s"Test0.rb:$Main.A" a.argumentIndex shouldBe 0 + + one.code shouldBe "1" + two.code shouldBe "2" case xs => fail(s"Expected one identifier and one call argument, got [${xs.code.mkString(",")}]") } + + val recv = constructor.receiver.head.asInstanceOf[Call] + recv.methodFullName shouldBe Operators.fieldAccess + recv.name shouldBe Operators.fieldAccess + recv.code shouldBe s"A.${RubyDefines.Initialize}" + + recv.argument(1).label shouldBe NodeTypes.CALL + recv.argument(1).code shouldBe "self.A" + recv.argument(2).label shouldBe NodeTypes.FIELD_IDENTIFIER + recv.argument(2).code shouldBe RubyDefines.Initialize + case xs => fail(s"Expected a single alloc, got [${xs.code.mkString(",")}]") + } + } + } + + "an object instantiation from some expression" should { + val cpg = code("""def foo + | params[:type].constantize.new(path) + |end + |""".stripMargin) + + "create a call node on the receiver end of the constructor lowering" in { + inside(cpg.call.nameExact(RubyDefines.Initialize).l) { + case constructor :: Nil => + inside(constructor.argument.l) { + case (a: Identifier) :: (selfPath: Call) :: Nil => + a.name shouldBe "" + a.typeFullName shouldBe Defines.Any + a.argumentIndex shouldBe 0 + + selfPath.code shouldBe "self.path" + case xs => fail(s"Expected one identifier and one call argument, got [${xs.code.mkString(",")}]") + } + + val recv = constructor.receiver.head.asInstanceOf[Call] + recv.methodFullName shouldBe Operators.fieldAccess + recv.name shouldBe Operators.fieldAccess + recv.code shouldBe s"params[:type].constantize.${RubyDefines.Initialize}" + + inside(recv.argument.l) { case (constantize: Call) :: (initialize: FieldIdentifier) :: Nil => + constantize.code shouldBe "params[:type].constantize" + inside(constantize.argument.l) { case (indexAccess: Call) :: (const: FieldIdentifier) :: Nil => + indexAccess.name shouldBe Operators.indexAccess + indexAccess.code shouldBe "params[:type]" + + const.canonicalName shouldBe "constantize" + } + + initialize.canonicalName shouldBe RubyDefines.Initialize + } case xs => fail(s"Expected a single alloc, got [${xs.code.mkString(",")}]") } } diff --git a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/DependencyTests.scala b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/DependencyTests.scala index 8b0148d6ae5c..b1c1635dc73d 100644 --- a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/DependencyTests.scala +++ b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/DependencyTests.scala @@ -95,9 +95,9 @@ class DownloadDependencyTest extends RubyCode2CpgFixture(downloadDependencies = case (v: Identifier) :: (block: Block) :: Nil => v.dynamicTypeHintFullName should contain("dummy_logger.Main_module.Main_outer_class") - inside(block.astChildren.isCall.nameExact("new").headOption) { + inside(block.astChildren.isCall.nameExact(RubyDefines.Initialize).headOption) { case Some(constructorCall) => - constructorCall.methodFullName shouldBe s"dummy_logger.Main_module.Main_outer_class.${RubyDefines.Initialize}" + constructorCall.methodFullName shouldBe Defines.Any case None => fail(s"Expected constructor call, did not find one") } case xs => fail(s"Expected two arguments under the constructor assignment, got [${xs.code.mkString(", ")}]") @@ -109,9 +109,9 @@ class DownloadDependencyTest extends RubyCode2CpgFixture(downloadDependencies = case (g: Identifier) :: (block: Block) :: Nil => g.dynamicTypeHintFullName should contain("dummy_logger.Help") - inside(block.astChildren.isCall.name("new").headOption) { + inside(block.astChildren.isCall.name(RubyDefines.Initialize).headOption) { case Some(constructorCall) => - constructorCall.methodFullName shouldBe s"dummy_logger.Help.${RubyDefines.Initialize}" + constructorCall.methodFullName shouldBe Defines.Any case None => fail(s"Expected constructor call, did not find one") } case xs => fail(s"Expected two arguments under the constructor assignment, got [${xs.code.mkString(", ")}]") diff --git a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/DoBlockTests.scala b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/DoBlockTests.scala index 65e9a7cc7f5c..7f7e68893fae 100644 --- a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/DoBlockTests.scala +++ b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/DoBlockTests.scala @@ -1,7 +1,7 @@ package io.joern.rubysrc2cpg.querying import io.joern.rubysrc2cpg.passes.GlobalTypes.builtinPrefix -import io.joern.rubysrc2cpg.passes.Defines.Main +import io.joern.rubysrc2cpg.passes.Defines.{Main, Initialize} import io.joern.rubysrc2cpg.testfixtures.RubyCode2CpgFixture import io.joern.x2cpg.Defines import io.shiftleft.codepropertygraph.generated.nodes.* @@ -267,10 +267,11 @@ class DoBlockTests extends RubyCode2CpgFixture { inside(constrBlock.astChildren.l) { case (tmpLocal: Local) :: (tmpAssign: Call) :: (newCall: Call) :: (_: Identifier) :: Nil => tmpLocal.name shouldBe "" - tmpAssign.code shouldBe " = Array.new(x) { |i| i += 1 }" + tmpAssign.code shouldBe s" = Array.$Initialize" - newCall.name shouldBe "new" - newCall.methodFullName shouldBe s"$builtinPrefix.Array.initialize" + newCall.name shouldBe Initialize + newCall.methodFullName shouldBe Defines.Any + newCall.dynamicTypeHintFullName should contain(s"$builtinPrefix.Array.$Initialize") inside(newCall.argument.l) { case (_: Identifier) :: (x: Identifier) :: (closure: TypeRef) :: Nil => diff --git a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/ImportTests.scala b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/ImportTests.scala index a7c200eae3a0..05d92827a7f9 100644 --- a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/ImportTests.scala +++ b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/ImportTests.scala @@ -1,13 +1,12 @@ package io.joern.rubysrc2cpg.querying import io.joern.rubysrc2cpg.passes.Defines -import io.joern.rubysrc2cpg.passes.Defines.Main +import io.joern.rubysrc2cpg.passes.Defines.{Initialize, Main} import io.joern.rubysrc2cpg.passes.GlobalTypes.{builtinPrefix, kernelPrefix} import io.joern.rubysrc2cpg.testfixtures.RubyCode2CpgFixture +import io.shiftleft.codepropertygraph.generated.DispatchTypes import io.shiftleft.codepropertygraph.generated.nodes.Literal -import io.shiftleft.codepropertygraph.generated.{DispatchTypes, NodeTypes} import io.shiftleft.semanticcpg.language.* -import io.shiftleft.semanticcpg.language.operatorextension.OpNodes.FieldAccess import org.scalatest.Inspectors class ImportTests extends RubyCode2CpgFixture(withPostProcessing = true) with Inspectors { @@ -62,7 +61,13 @@ class ImportTests extends RubyCode2CpgFixture(withPostProcessing = true) with In ) val List(newCall) = - cpg.method.isModule.filename("t1.rb").ast.isCall.methodFullName(".*\\.initialize").methodFullName.l + cpg.method.isModule + .filename("t1.rb") + .ast + .isCall + .dynamicTypeHintFullName + .filter(x => x.startsWith(path) && x.endsWith(Initialize)) + .l newCall should startWith(s"$path.rb:") } } @@ -285,12 +290,13 @@ class ImportTests extends RubyCode2CpgFixture(withPostProcessing = true) with In "resolve calls to builtin functions" in { inside(cpg.call.methodFullName("(pp|csv).*").l) { - case csvParseCall :: csvTableInitCall :: ppCall :: Nil => + case csvParseCall :: ppCall :: Nil => csvParseCall.methodFullName shouldBe "csv.CSV.parse" ppCall.methodFullName shouldBe "pp.PP.pp" - csvTableInitCall.methodFullName shouldBe "csv.CSV.Table.initialize" case xs => fail(s"Expected three calls, got [${xs.code.mkString(",")}] instead") } + + cpg.call(Initialize).dynamicTypeHintFullName.toSet should contain("csv.CSV.Table.initialize") } }