Skip to content

Commit

Permalink
Fix inner class constructors (#4197)
Browse files Browse the repository at this point in the history
  • Loading branch information
johannescoetzee authored Feb 19, 2024
1 parent b4ff8b3 commit 6041d8c
Show file tree
Hide file tree
Showing 4 changed files with 130 additions and 96 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -387,7 +387,7 @@ private[declarations] trait AstForTypeDeclsCreator { this: AstCreator =>

private def addArgsToPartialInits(allInitCallNodes: Set[NewCall]): Unit = {
scope.enclosingTypeDecl.getInitsToComplete.foreach {
case PartialInit(typeFullName, callAst, receiverAst, args, capturedThis) =>
case PartialInit(typeFullName, callAst, receiverAst, args, outerClassAst) =>
callAst.root match {
case Some(initRoot: NewCall) if allInitCallNodes.contains(initRoot) =>
val usedCaptures = if (scope.enclosingTypeDecl.map(_.typeDecl.fullName).contains(typeFullName)) {
Expand All @@ -399,7 +399,7 @@ private[declarations] trait AstForTypeDeclsCreator { this: AstCreator =>
receiverAst.root.foreach(receiver => diffGraph.addEdge(initRoot, receiver, EdgeTypes.RECEIVER))

val capturesAsts =
usedCaptures.filterNot(capturedThis.isDefined && _.name == NameConstants.OuterClass).zipWithIndex.map {
usedCaptures.filterNot(outerClassAst.isDefined && _.name == NameConstants.OuterClass).zipWithIndex.map {
(usedCapture, index) =>
val identifier = NewIdentifier()
.name(usedCapture.name)
Expand All @@ -413,23 +413,7 @@ private[declarations] trait AstForTypeDeclsCreator { this: AstCreator =>
Ast(identifier)
}

val capturedThisIdentifier =
Option
.when(usedCaptures.exists(_.name == NameConstants.OuterClass))(capturedThis.map { thisNode =>
val identifier = NewIdentifier()
.name(thisNode.name)
.code(thisNode.code)
.typeFullName(thisNode.typeFullName)
.lineNumber(initRoot.lineNumber)
.columnNumber(initRoot.columnNumber)

diffGraph.addEdge(identifier, thisNode, EdgeTypes.REF)

Ast(identifier)
})
.flatten

(receiverAst :: args ++ capturedThisIdentifier.toList ++ capturesAsts)
(receiverAst :: args ++ outerClassAst.toList ++ capturesAsts)
.map { argAst =>
storeInDiffGraph(argAst)
argAst.root
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -244,27 +244,31 @@ trait AstForCallExpressionsCreator { this: AstCreator =>
line(expr)
)

val isInnerType = anonymousClassBody.isDefined || baseTypeFromScope.exists(_.isInstanceOf[ScopeInnerType])
val isInnerType = anonymousClassBody.isDefined || baseTypeFromScope.exists(
_.isInstanceOf[ScopeInnerType]
) || expr.getScope.isPresent

lazy val thisParameter = scope.lookupVariable(NameConstants.This).variableNode
val initReceiverAst =
assignmentTarget.root.collect { case root: AstNodeNew => assignmentTarget.subTreeCopy(root) }.getOrElse {
logger.warn(s"Assignment target ast with no root at $filename:${line(expr)}:${column(expr)}")
unknownAst(expr)
}

// TODO: This is wrong for chained constructors where the `captured this` is the object created by the previous
// constructor in the chain
val capturedThis = scope.lookupVariable(NameConstants.This) match {
case SimpleVariable(param: ScopeParameter) => Some(param.node)
case _ => None
}
val capturedOuterClassAst =
expr.getScope.toScala.flatMap(astsForExpression(_, ExpectedType.empty).headOption).orElse {
scope.lookupVariable(NameConstants.This) match {
case SimpleVariable(param: ScopeParameter) if !scope.isEnclosingScopeStatic =>
val outerClassIdentifier = identifierNode(expr, param.name, param.name, param.typeFullName)
Some(Ast(outerClassIdentifier).withRefEdge(outerClassIdentifier, param.node))
case _ => None
}
}

val initAst = if (isInnerType) {
val initCallAst = Ast(initCall)
scope.enclosingTypeDecl.foreach(
_.registerInitToComplete(
PartialInit(allocNode.typeFullName, initCallAst, initReceiverAst, argumentAsts.toList, capturedThis)
PartialInit(allocNode.typeFullName, initCallAst, initReceiverAst, argumentAsts.toList, capturedOuterClassAst)
)
)
initCallAst
Expand Down Expand Up @@ -353,59 +357,6 @@ trait AstForCallExpressionsCreator { this: AstCreator =>
.columnNumber(columnNumber)
}

private def blockAstForConstructorInvocation(
lineNumber: Option[Integer],
columnNumber: Option[Integer],
allocNode: NewCall,
initNode: NewCall,
args: Seq[Ast],
isInnerType: Boolean
): Ast = {
val blockNode = NewBlock()
.lineNumber(lineNumber)
.columnNumber(columnNumber)
.typeFullName(allocNode.typeFullName)

val tempName = "$obj" ++ tempConstCount.toString
tempConstCount += 1
val identifier = newIdentifierNode(tempName, allocNode.typeFullName)
val identifierAst = Ast(identifier)

val allocAst = Ast(allocNode)

val assignmentNode = newOperatorCallNode(Operators.assignment, PropertyDefaults.Code, Some(allocNode.typeFullName))

val assignmentAst = callAst(assignmentNode, List(identifierAst, allocAst))

val identifierWithDefaultOrder = identifier.copy.order(PropertyDefaults.Order)
val identifierForInit = identifierWithDefaultOrder.copy
val initCopyWithDefaultOrder = initNode.copy.order(PropertyDefaults.Order)

val returnAst = Ast(identifierWithDefaultOrder.copy)

val capturedThis = scope.lookupVariable(NameConstants.This) match {
case SimpleVariable(param: ScopeParameter) => Some(param.node)
case _ => None
}

val initAst = if (isInnerType) {
val initCallAst = Ast(initCopyWithDefaultOrder)
scope.enclosingTypeDecl.foreach(
_.registerInitToComplete(
PartialInit(allocNode.typeFullName, initCallAst, Ast(identifierForInit), args.toList, capturedThis)
)
)
initCallAst
} else {
callAst(initCopyWithDefaultOrder, args, Some(Ast(identifierForInit)))
}

Ast(blockNode)
.withChild(assignmentAst)
.withChild(initAst)
.withChild(returnAst)
}

private def getArgumentCodeString(args: NodeList[Expression]): String = {
args.asScala
.map {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ object JavaScopeElement {
callAst: Ast,
receiverAst: Ast,
argsAsts: List[Ast],
capturedThis: Option[NewMethodParameterIn]
outerClassAst: Option[Ast]
)

extension (typeDeclScope: Option[TypeDeclScope]) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,116 @@ import io.shiftleft.proto.cpg.Cpg.DispatchTypes
import io.shiftleft.semanticcpg.language._

class NewConstructorInvocationTests extends JavaSrcCode2CpgFixture {
"constructor init method call" should {
lazy val cpg = code("""
"inner class constructor invocations" when {
"the receiver is an object creation expression" should {
val cpg = code("""
|class Foo {
| Foo(long aaa) {
| class Bar {}
|
| public static void test() {
| Bar b = new Foo().new Bar();
| }
| static void method() {
| Foo foo = new Foo(1);
|}
|""".stripMargin)

"have the correct block lowering" in {
inside(cpg.call.methodFullName(".*Bar.*init.*").l) { case List(barInit) =>
barInit.methodFullName shouldBe "Foo$Bar.<init>:void()"

inside(barInit.argument.l) { case List(barReceiver: Identifier, outerClass: Block) =>
barReceiver.name shouldBe "b"
barReceiver.typeFullName shouldBe "Foo$Bar"

inside(outerClass.astChildren.l) {
case List(tmpLocal: Local, tmpAssign: Call, tmpInit: Call, tmpRet: Identifier) =>
tmpLocal.name shouldBe "$obj0"
tmpLocal.typeFullName shouldBe "Foo"

tmpAssign.code shouldBe "$obj0 = new Foo()"
tmpAssign.typeFullName shouldBe "Foo"
tmpAssign.methodFullName shouldBe Operators.assignment

tmpInit.code shouldBe "new Foo()"
inside(tmpInit.argument.l) { case List(tmpReceiver: Identifier) =>
tmpReceiver.name shouldBe "$obj0"
tmpReceiver.typeFullName shouldBe "Foo"
}
}
}
}
}
}

"the receiver is a variable" should {
val cpg = code("""
|class Foo {
| class Bar {}
|
| public static void test(Foo f) {
| Bar b = f.new Bar();
| }
|}
|""".stripMargin)

"have the correct structure" in {
inside(cpg.call.methodFullName(".*Bar.*init.*").l) { case List(barInit) =>
barInit.methodFullName shouldBe "Foo$Bar.<init>:void()"

inside(barInit.argument.l) { case List(barReceiver: Identifier, outerClass: Identifier) =>
barReceiver.name shouldBe "b"
barReceiver.typeFullName shouldBe "Foo$Bar"

outerClass.name shouldBe "f"
outerClass.typeFullName shouldBe "Foo"
outerClass.refsTo.l shouldBe cpg.method.name("test").parameter.name("f").l
}
}
}
}

"the receiver is a call" should {
val cpg = code("""
|class Foo {
| class Bar {}
|
| public static Foo foo() {
| return new Foo();
| }
|
| public static void test() {
| Bar b = foo().new Bar();
| }
|}
|""".stripMargin)

"have the correct structure" in {
inside(cpg.call.methodFullName(".*Bar.*init.*").l) { case List(barInit) =>
barInit.methodFullName shouldBe "Foo$Bar.<init>:void()"

inside(barInit.argument.l) { case List(barReceiver: Identifier, outerClass: Call) =>
barReceiver.name shouldBe "b"
barReceiver.typeFullName shouldBe "Foo$Bar"

outerClass.name shouldBe "foo"
outerClass.typeFullName shouldBe "Foo"
outerClass.methodFullName shouldBe "Foo.foo:Foo()"
}
}
}
}
}

"constructor init method call" should {
val cpg = code("""
|class Foo {
| Foo(long aaa) {
| }
| static void method() {
| Foo foo = new Foo(1);
| }
|}
|""".stripMargin)

"have correct methodFullName and signature" in {
val initCall = cpg.call.nameExact(io.joern.x2cpg.Defines.ConstructorMethodName).head
initCall.signature shouldBe "void(long)"
Expand All @@ -26,15 +125,15 @@ class NewConstructorInvocationTests extends JavaSrcCode2CpgFixture {
}

"a simple single argument constructor" should {
lazy val fooCpg = code("""
|class Foo {
| int x;
|
| public Foo(int x) {
| this.x = x;
| }
|}
|""".stripMargin)
val fooCpg = code("""
|class Foo {
| int x;
|
| public Foo(int x) {
| this.x = x;
| }
|}
|""".stripMargin)
"create the correct Ast for the constructor" in {
fooCpg.typeDecl.name("Foo").method.nameExact(io.joern.x2cpg.Defines.ConstructorMethodName).l match {
case List(cons: Method) =>
Expand Down

0 comments on commit 6041d8c

Please sign in to comment.