diff --git a/joern-cli/frontends/javasrc2cpg/src/main/scala/io/joern/javasrc2cpg/astcreation/declarations/AstForTypeDeclsCreator.scala b/joern-cli/frontends/javasrc2cpg/src/main/scala/io/joern/javasrc2cpg/astcreation/declarations/AstForTypeDeclsCreator.scala index 9b905ae08e96..36bdac3f9859 100644 --- a/joern-cli/frontends/javasrc2cpg/src/main/scala/io/joern/javasrc2cpg/astcreation/declarations/AstForTypeDeclsCreator.scala +++ b/joern-cli/frontends/javasrc2cpg/src/main/scala/io/joern/javasrc2cpg/astcreation/declarations/AstForTypeDeclsCreator.scala @@ -387,7 +387,7 @@ private[declarations] trait AstForTypeDeclsCreator { this: AstCreator => private def addArgsToPartialInits(allInitCallNodes: Set[NewCall]): Unit = { scope.enclosingTypeDecl.getInitsToComplete.foreach { - case PartialInit(typeFullName, callAst, receiverAst, args, capturedThis) => + case PartialInit(typeFullName, callAst, receiverAst, args, outerClassAst) => callAst.root match { case Some(initRoot: NewCall) if allInitCallNodes.contains(initRoot) => val usedCaptures = if (scope.enclosingTypeDecl.map(_.typeDecl.fullName).contains(typeFullName)) { @@ -399,7 +399,7 @@ private[declarations] trait AstForTypeDeclsCreator { this: AstCreator => receiverAst.root.foreach(receiver => diffGraph.addEdge(initRoot, receiver, EdgeTypes.RECEIVER)) val capturesAsts = - usedCaptures.filterNot(capturedThis.isDefined && _.name == NameConstants.OuterClass).zipWithIndex.map { + usedCaptures.filterNot(outerClassAst.isDefined && _.name == NameConstants.OuterClass).zipWithIndex.map { (usedCapture, index) => val identifier = NewIdentifier() .name(usedCapture.name) @@ -413,23 +413,7 @@ private[declarations] trait AstForTypeDeclsCreator { this: AstCreator => Ast(identifier) } - val capturedThisIdentifier = - Option - .when(usedCaptures.exists(_.name == NameConstants.OuterClass))(capturedThis.map { thisNode => - val identifier = NewIdentifier() - .name(thisNode.name) - .code(thisNode.code) - .typeFullName(thisNode.typeFullName) - .lineNumber(initRoot.lineNumber) - .columnNumber(initRoot.columnNumber) - - diffGraph.addEdge(identifier, thisNode, EdgeTypes.REF) - - Ast(identifier) - }) - .flatten - - (receiverAst :: args ++ capturedThisIdentifier.toList ++ capturesAsts) + (receiverAst :: args ++ outerClassAst.toList ++ capturesAsts) .map { argAst => storeInDiffGraph(argAst) argAst.root diff --git a/joern-cli/frontends/javasrc2cpg/src/main/scala/io/joern/javasrc2cpg/astcreation/expressions/AstForCallExpressionsCreator.scala b/joern-cli/frontends/javasrc2cpg/src/main/scala/io/joern/javasrc2cpg/astcreation/expressions/AstForCallExpressionsCreator.scala index eff4af5a4ad6..9abe8c4e209a 100644 --- a/joern-cli/frontends/javasrc2cpg/src/main/scala/io/joern/javasrc2cpg/astcreation/expressions/AstForCallExpressionsCreator.scala +++ b/joern-cli/frontends/javasrc2cpg/src/main/scala/io/joern/javasrc2cpg/astcreation/expressions/AstForCallExpressionsCreator.scala @@ -244,27 +244,31 @@ trait AstForCallExpressionsCreator { this: AstCreator => line(expr) ) - val isInnerType = anonymousClassBody.isDefined || baseTypeFromScope.exists(_.isInstanceOf[ScopeInnerType]) + val isInnerType = anonymousClassBody.isDefined || baseTypeFromScope.exists( + _.isInstanceOf[ScopeInnerType] + ) || expr.getScope.isPresent - lazy val thisParameter = scope.lookupVariable(NameConstants.This).variableNode val initReceiverAst = assignmentTarget.root.collect { case root: AstNodeNew => assignmentTarget.subTreeCopy(root) }.getOrElse { logger.warn(s"Assignment target ast with no root at $filename:${line(expr)}:${column(expr)}") unknownAst(expr) } - // TODO: This is wrong for chained constructors where the `captured this` is the object created by the previous - // constructor in the chain - val capturedThis = scope.lookupVariable(NameConstants.This) match { - case SimpleVariable(param: ScopeParameter) => Some(param.node) - case _ => None - } + val capturedOuterClassAst = + expr.getScope.toScala.flatMap(astsForExpression(_, ExpectedType.empty).headOption).orElse { + scope.lookupVariable(NameConstants.This) match { + case SimpleVariable(param: ScopeParameter) if !scope.isEnclosingScopeStatic => + val outerClassIdentifier = identifierNode(expr, param.name, param.name, param.typeFullName) + Some(Ast(outerClassIdentifier).withRefEdge(outerClassIdentifier, param.node)) + case _ => None + } + } val initAst = if (isInnerType) { val initCallAst = Ast(initCall) scope.enclosingTypeDecl.foreach( _.registerInitToComplete( - PartialInit(allocNode.typeFullName, initCallAst, initReceiverAst, argumentAsts.toList, capturedThis) + PartialInit(allocNode.typeFullName, initCallAst, initReceiverAst, argumentAsts.toList, capturedOuterClassAst) ) ) initCallAst @@ -353,59 +357,6 @@ trait AstForCallExpressionsCreator { this: AstCreator => .columnNumber(columnNumber) } - private def blockAstForConstructorInvocation( - lineNumber: Option[Integer], - columnNumber: Option[Integer], - allocNode: NewCall, - initNode: NewCall, - args: Seq[Ast], - isInnerType: Boolean - ): Ast = { - val blockNode = NewBlock() - .lineNumber(lineNumber) - .columnNumber(columnNumber) - .typeFullName(allocNode.typeFullName) - - val tempName = "$obj" ++ tempConstCount.toString - tempConstCount += 1 - val identifier = newIdentifierNode(tempName, allocNode.typeFullName) - val identifierAst = Ast(identifier) - - val allocAst = Ast(allocNode) - - val assignmentNode = newOperatorCallNode(Operators.assignment, PropertyDefaults.Code, Some(allocNode.typeFullName)) - - val assignmentAst = callAst(assignmentNode, List(identifierAst, allocAst)) - - val identifierWithDefaultOrder = identifier.copy.order(PropertyDefaults.Order) - val identifierForInit = identifierWithDefaultOrder.copy - val initCopyWithDefaultOrder = initNode.copy.order(PropertyDefaults.Order) - - val returnAst = Ast(identifierWithDefaultOrder.copy) - - val capturedThis = scope.lookupVariable(NameConstants.This) match { - case SimpleVariable(param: ScopeParameter) => Some(param.node) - case _ => None - } - - val initAst = if (isInnerType) { - val initCallAst = Ast(initCopyWithDefaultOrder) - scope.enclosingTypeDecl.foreach( - _.registerInitToComplete( - PartialInit(allocNode.typeFullName, initCallAst, Ast(identifierForInit), args.toList, capturedThis) - ) - ) - initCallAst - } else { - callAst(initCopyWithDefaultOrder, args, Some(Ast(identifierForInit))) - } - - Ast(blockNode) - .withChild(assignmentAst) - .withChild(initAst) - .withChild(returnAst) - } - private def getArgumentCodeString(args: NodeList[Expression]): String = { args.asScala .map { diff --git a/joern-cli/frontends/javasrc2cpg/src/main/scala/io/joern/javasrc2cpg/scope/JavaScopeElement.scala b/joern-cli/frontends/javasrc2cpg/src/main/scala/io/joern/javasrc2cpg/scope/JavaScopeElement.scala index fc838d36c044..04f5df7ea105 100644 --- a/joern-cli/frontends/javasrc2cpg/src/main/scala/io/joern/javasrc2cpg/scope/JavaScopeElement.scala +++ b/joern-cli/frontends/javasrc2cpg/src/main/scala/io/joern/javasrc2cpg/scope/JavaScopeElement.scala @@ -161,7 +161,7 @@ object JavaScopeElement { callAst: Ast, receiverAst: Ast, argsAsts: List[Ast], - capturedThis: Option[NewMethodParameterIn] + outerClassAst: Option[Ast] ) extension (typeDeclScope: Option[TypeDeclScope]) { diff --git a/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/ConstructorInvocationTests.scala b/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/ConstructorInvocationTests.scala index c0bbdfdba3c6..1a4294e88a55 100644 --- a/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/ConstructorInvocationTests.scala +++ b/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/ConstructorInvocationTests.scala @@ -7,17 +7,116 @@ import io.shiftleft.proto.cpg.Cpg.DispatchTypes import io.shiftleft.semanticcpg.language._ class NewConstructorInvocationTests extends JavaSrcCode2CpgFixture { - "constructor init method call" should { - lazy val cpg = code(""" + "inner class constructor invocations" when { + "the receiver is an object creation expression" should { + val cpg = code(""" |class Foo { - | Foo(long aaa) { + | class Bar {} + | + | public static void test() { + | Bar b = new Foo().new Bar(); | } - | static void method() { - | Foo foo = new Foo(1); + |} + |""".stripMargin) + + "have the correct block lowering" in { + inside(cpg.call.methodFullName(".*Bar.*init.*").l) { case List(barInit) => + barInit.methodFullName shouldBe "Foo$Bar.:void()" + + inside(barInit.argument.l) { case List(barReceiver: Identifier, outerClass: Block) => + barReceiver.name shouldBe "b" + barReceiver.typeFullName shouldBe "Foo$Bar" + + inside(outerClass.astChildren.l) { + case List(tmpLocal: Local, tmpAssign: Call, tmpInit: Call, tmpRet: Identifier) => + tmpLocal.name shouldBe "$obj0" + tmpLocal.typeFullName shouldBe "Foo" + + tmpAssign.code shouldBe "$obj0 = new Foo()" + tmpAssign.typeFullName shouldBe "Foo" + tmpAssign.methodFullName shouldBe Operators.assignment + + tmpInit.code shouldBe "new Foo()" + inside(tmpInit.argument.l) { case List(tmpReceiver: Identifier) => + tmpReceiver.name shouldBe "$obj0" + tmpReceiver.typeFullName shouldBe "Foo" + } + } + } + } + } + } + + "the receiver is a variable" should { + val cpg = code(""" + |class Foo { + | class Bar {} + | + | public static void test(Foo f) { + | Bar b = f.new Bar(); | } |} |""".stripMargin) + "have the correct structure" in { + inside(cpg.call.methodFullName(".*Bar.*init.*").l) { case List(barInit) => + barInit.methodFullName shouldBe "Foo$Bar.:void()" + + inside(barInit.argument.l) { case List(barReceiver: Identifier, outerClass: Identifier) => + barReceiver.name shouldBe "b" + barReceiver.typeFullName shouldBe "Foo$Bar" + + outerClass.name shouldBe "f" + outerClass.typeFullName shouldBe "Foo" + outerClass.refsTo.l shouldBe cpg.method.name("test").parameter.name("f").l + } + } + } + } + + "the receiver is a call" should { + val cpg = code(""" + |class Foo { + | class Bar {} + | + | public static Foo foo() { + | return new Foo(); + | } + | + | public static void test() { + | Bar b = foo().new Bar(); + | } + |} + |""".stripMargin) + + "have the correct structure" in { + inside(cpg.call.methodFullName(".*Bar.*init.*").l) { case List(barInit) => + barInit.methodFullName shouldBe "Foo$Bar.:void()" + + inside(barInit.argument.l) { case List(barReceiver: Identifier, outerClass: Call) => + barReceiver.name shouldBe "b" + barReceiver.typeFullName shouldBe "Foo$Bar" + + outerClass.name shouldBe "foo" + outerClass.typeFullName shouldBe "Foo" + outerClass.methodFullName shouldBe "Foo.foo:Foo()" + } + } + } + } + } + + "constructor init method call" should { + val cpg = code(""" + |class Foo { + | Foo(long aaa) { + | } + | static void method() { + | Foo foo = new Foo(1); + | } + |} + |""".stripMargin) + "have correct methodFullName and signature" in { val initCall = cpg.call.nameExact(io.joern.x2cpg.Defines.ConstructorMethodName).head initCall.signature shouldBe "void(long)" @@ -26,15 +125,15 @@ class NewConstructorInvocationTests extends JavaSrcCode2CpgFixture { } "a simple single argument constructor" should { - lazy val fooCpg = code(""" - |class Foo { - | int x; - | - | public Foo(int x) { - | this.x = x; - | } - |} - |""".stripMargin) + val fooCpg = code(""" + |class Foo { + | int x; + | + | public Foo(int x) { + | this.x = x; + | } + |} + |""".stripMargin) "create the correct Ast for the constructor" in { fooCpg.typeDecl.name("Foo").method.nameExact(io.joern.x2cpg.Defines.ConstructorMethodName).l match { case List(cons: Method) =>