diff --git a/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/FieldAccessTests.scala b/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/FieldAccessTests.scala index 19365aada17a..d1225cd11519 100644 --- a/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/FieldAccessTests.scala +++ b/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/FieldAccessTests.scala @@ -2,7 +2,7 @@ package io.joern.javasrc2cpg.querying import io.joern.javasrc2cpg.testfixtures.JavaSrcCode2CpgFixture import io.shiftleft.codepropertygraph.generated.nodes.{Call, FieldIdentifier, Identifier} -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* class FieldAccessTests extends JavaSrcCode2CpgFixture { @@ -66,4 +66,20 @@ class FieldAccessTests extends JavaSrcCode2CpgFixture { identifier.typeFullName shouldBe "Foo" fieldIdentifier.canonicalName shouldBe "value" } + + "should link to the referencing static member" in { + val List(access: Call) = cpg.method(".*foo.*").call(".*fieldAccess").l + access.referencedMember.name.head shouldBe "MAX_VALUE" + } + + "should link to the referencing dynamic member on the RHS of assignments" in { + val List(access: Call) = cpg.method(".*bar.*").call(".*fieldAccess").l + access.referencedMember.name.head shouldBe "value" + } + + "should link to the referencing dynamic member on the LHS of assignments" in { + val List(access: Call) = cpg.method(".*baz.*").call(".*fieldAccess").l + access.referencedMember.name.head shouldBe "value" + } + } diff --git a/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/layers/TypeRelations.scala b/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/layers/TypeRelations.scala index 8a388434a22f..b06130b56145 100644 --- a/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/layers/TypeRelations.scala +++ b/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/layers/TypeRelations.scala @@ -1,16 +1,17 @@ package io.joern.x2cpg.layers +import io.joern.x2cpg.passes.typerelations.{AliasLinkerPass, FieldAccessLinkerPass, TypeHierarchyPass} import io.shiftleft.codepropertygraph.Cpg import io.shiftleft.passes.CpgPassBase import io.shiftleft.semanticcpg.layers.{LayerCreator, LayerCreatorContext, LayerCreatorOptions} -import io.joern.x2cpg.passes.typerelations.{AliasLinkerPass, TypeHierarchyPass} object TypeRelations { val overlayName: String = "typerel" val description: String = "Type relations layer (hierarchy and aliases)" def defaultOpts = new LayerCreatorOptions() - def passes(cpg: Cpg): Iterator[CpgPassBase] = Iterator(new TypeHierarchyPass(cpg), new AliasLinkerPass(cpg)) + def passes(cpg: Cpg): Iterator[CpgPassBase] = + Iterator(new TypeHierarchyPass(cpg), new AliasLinkerPass(cpg), new FieldAccessLinkerPass(cpg)) } class TypeRelations extends LayerCreator { diff --git a/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/passes/typerelations/FieldAccessLinkerPass.scala b/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/passes/typerelations/FieldAccessLinkerPass.scala new file mode 100644 index 000000000000..44773d4731be --- /dev/null +++ b/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/passes/typerelations/FieldAccessLinkerPass.scala @@ -0,0 +1,88 @@ +package io.joern.x2cpg.passes.typerelations + +import io.joern.x2cpg.passes.frontend.Dereference +import io.joern.x2cpg.utils.LinkingUtil +import io.shiftleft.codepropertygraph.generated.nodes.{Call, Member, StoredNode} +import io.shiftleft.codepropertygraph.generated.{Cpg, EdgeTypes, NodeTypes, PropertyNames} +import io.shiftleft.passes.CpgPass +import io.shiftleft.semanticcpg.language.* +import io.shiftleft.semanticcpg.language.operatorextension.OpNodes +import io.shiftleft.semanticcpg.utils.MemberAccess +import org.slf4j.LoggerFactory + +import scala.jdk.CollectionConverters.* + +/** Links field access calls to the field they are accessing to enable the `cpg.fieldAccess.referencedMember` step. + */ +class FieldAccessLinkerPass(cpg: Cpg) extends CpgPass(cpg) with LinkingUtil { + + private val logger = LoggerFactory.getLogger(getClass) + private val DOT = "." + + override def run(dstGraph: DiffGraphBuilder): Unit = { + linkToMultiple( + cpg, + srcLabels = List(NodeTypes.CALL), + dstNodeLabel = NodeTypes.MEMBER, + edgeType = EdgeTypes.REF, + dstNodeMap = typeDeclMemberToNode(cpg, _), + getDstFullNames = (call: Call) => dstMemberFullNames(call), + dstFullNameKey = PropertyNames.NAME, + dstGraph + ) + } + + private def dstMemberFullNames(call: Call): Seq[String] = { + if (MemberAccess.isFieldAccess(call.name)) { + val fieldAccess = call.asInstanceOf[OpNodes.FieldAccess] + fieldAccess.argumentOption(1) match + case Some(baseNode) => + fieldAccess.fieldIdentifier.canonicalName.headOption match + case Some(fieldName) => + baseNode.evalType.map(x => s"$x$DOT$fieldName").toSeq + case None => + logger.warn(s"Field access ${fieldAccess.code} has no field identifier") + Seq.empty + case None => + logger.warn(s"Field access ${fieldAccess.code} has no base node") + Seq.empty + } else { + Seq.empty + } + } + + private def typeDeclMemberToNode(cpg: Cpg, fieldFullName: String): Option[Member] = { + val (typeFullName, fieldName) = fieldFullName.splitAt(fieldFullName.lastIndexOf(DOT)) + typeDeclFullNameToNode(cpg, typeFullName).member.nameExact(fieldName.stripPrefix(DOT)).headOption + } + + // This is overridden to avoid the step that sets the `dstFullNameKey` property. + override def linkToMultiple[SRC_NODE_TYPE <: StoredNode]( + cpg: Cpg, + srcLabels: List[String], + dstNodeLabel: String, + edgeType: String, + dstNodeMap: String => Option[StoredNode], + getDstFullNames: SRC_NODE_TYPE => Iterable[String], + dstFullNameKey: String, + dstGraph: DiffGraphBuilder + ): Unit = { + val dereference = Dereference(cpg) + cpg.graph.nodes(srcLabels: _*).asScala.cast[SRC_NODE_TYPE].filterNot(_.outE(edgeType).hasNext).foreach { srcNode => + if (!srcNode.outE(edgeType).hasNext) { + getDstFullNames(srcNode).foreach { dstFullName => + val dereferenceDstFullName = dereference.dereferenceTypeFullName(dstFullName) + dstNodeMap(dereferenceDstFullName) match { + case Some(dstNode) => + dstGraph.addEdge(srcNode, dstNode, edgeType) + case None if dstNodeMap(dstFullName).isDefined => + dstGraph.addEdge(srcNode, dstNodeMap(dstFullName).get, edgeType) + case None => + logFailedDstLookup(edgeType, srcNode.label, srcNode.id.toString, dstNodeLabel, dereferenceDstFullName) + } + } + } + } + } + +}