From 9a197b43f691c06fecee66c0b83aac8948405a12 Mon Sep 17 00:00:00 2001 From: Johannes Coetzee Date: Thu, 29 Feb 2024 18:26:30 +0100 Subject: [PATCH] [javasrc] Fix JavaParser -> standard name conversion for nested classes (#4252) * Add unit tests showing some failing cases * Fix incorrect handling of nested type names --- .../noncaching/JdkJarTypeSolver.scala | 56 +++++-------- .../javasrc2cpg/querying/CallTests.scala | 82 +++++++++++++++++++ 2 files changed, 102 insertions(+), 36 deletions(-) diff --git a/joern-cli/frontends/javasrc2cpg/src/main/scala/io/joern/javasrc2cpg/typesolvers/noncaching/JdkJarTypeSolver.scala b/joern-cli/frontends/javasrc2cpg/src/main/scala/io/joern/javasrc2cpg/typesolvers/noncaching/JdkJarTypeSolver.scala index 4ac1bc441110..92672b72d02d 100644 --- a/joern-cli/frontends/javasrc2cpg/src/main/scala/io/joern/javasrc2cpg/typesolvers/noncaching/JdkJarTypeSolver.scala +++ b/joern-cli/frontends/javasrc2cpg/src/main/scala/io/joern/javasrc2cpg/typesolvers/noncaching/JdkJarTypeSolver.scala @@ -44,15 +44,13 @@ class JdkJarTypeSolver(classPool: NonCachingClassPool, knownPackagePrefixes: Set } private def lookupType(javaParserName: String): SymbolReference[ResolvedReferenceTypeDeclaration] = { - val name = convertJavaParserNameToStandard(javaParserName) - Try(classPool.get(name)) match { - case Success(ctClass) => + possibleStandardNamesForJavaParser(javaParserName).iterator + .map(name => Try(classPool.get(name))) + .collectFirst { case Success(ctClass) => val refType = ctClassToRefType(ctClass) refTypeToSymbolReference(refType) - - case Failure(e) => - SymbolReference.unsolved() - } + } + .getOrElse(SymbolReference.unsolved()) } override def solveType(name: String): ResolvedReferenceTypeDeclaration = { @@ -184,37 +182,23 @@ object JdkJarTypeSolver { packagePrefixForJarEntry(entryName.stripPrefix(JmodClassPrefix)) } - /** A name is assumed to contain at least one subclass (e.g. ...Foo$Bar) if the last name part starts with a digit, or - * if the last 2 name parts start with capital letters. This heuristic is based on the class name format in the JDK - * jars, where names with subclasses have one of the forms: - * - java.lang.ClassLoader$2 - * - java.lang.ClassLoader$NativeLibrary - * - java.lang.ClassLoader$NativeLibrary$Unloader - */ - private def namePartsContainSubclass(nameParts: Array[String]): Boolean = { - nameParts.takeRight(2) match { - case Array() => false - - case Array(singlePart) => false - - case Array(secondLast, last) => - last.head.isDigit || (secondLast.head.isUpper && last.head.isUpper) - } - } - - /** JavaParser replaces the `$` in nested class names with a `.`. This method converts the JavaParser names to the - * standard format by replacing the `.` between name parts that start with a capital letter or a digit with a `$` - * since the jdk classes follow the standard practice of capitalising the first letter in class names but not package - * names. + /** JavaParser replaces the `$` in nested class names with a `.`. This means that we cannot know what the standard + * type full name is for JavaParser names with multiple parts, so this method returns all possibilities, for example + * for a.b.Foo.Bar, it will return: + * - a.b.Foo.Bar + * - a.b.Foo$Bar + * - a.b$Foo$Bar + * - a$b$Foo$Bar */ - def convertJavaParserNameToStandard(className: String): String = { - className.split(".") match { - case nameParts if namePartsContainSubclass(nameParts) => - val (packagePrefix, classNames) = nameParts.partition(_.head.isLower) - s"${packagePrefix.mkString(".")}.${classNames.mkString("$")}" + def possibleStandardNamesForJavaParser(javaParserName: String): List[String] = { + val nameParts = javaParserName.split('.') + nameParts.indices.reverse.map { packageLength => + val packageName = nameParts.take(packageLength).mkString(".") + val className = nameParts.drop(packageLength).mkString("$") - case _ => className - } + val packagePrefix = if (packageLength > 0) s"$packageName." else "" + s"$packagePrefix$className" + }.toList } } diff --git a/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/CallTests.scala b/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/CallTests.scala index b81bbeae6f4b..79553ef95870 100644 --- a/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/CallTests.scala +++ b/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/CallTests.scala @@ -11,6 +11,88 @@ import overflowdb.traversal.jIteratortoTraversal import overflowdb.traversal.toNodeTraversal class NewCallTests extends JavaSrcCode2CpgFixture { + "calls to imported methods" when { + "they are static methods imported from java.lang.* should be resolved" in { + val cpg = code(""" + |class Test { + | public void test() { + | String.valueOf(true); + | } + |} + | + |""".stripMargin) + + cpg.call.name("valueOf").methodFullName.l shouldBe List("java.lang.String.valueOf:java.lang.String(boolean)") + } + + "they are instance methods imported from java.lang.* should be resolved" in { + val cpg = code(""" + |class Test { + | public void test(String s) { + | s.length(); + | } + |} + | + |""".stripMargin) + + cpg.call.name("length").methodFullName.l shouldBe List("java.lang.String.length:int()") + } + + "they are calls to instance methods from java imports should be resolved" in { + val cpg = code(""" + |import java.util.Base64; + | + |class Test { + | public void test(Base64.Decoder decoder, String src) { + | decoder.decode(src); + | } + |} + | + |""".stripMargin) + cpg.call.name("decode").methodFullName.l shouldBe List("java.util.Base64$Decoder.decode:byte[](java.lang.String)") + } + + "they are calls to static methods from java imports should be resolved" in { + val cpg = code(""" + |import java.util.Base64; + | + |class Foo { + | void test() { + | Base64.getDecoder(); + | } + |} + |""".stripMargin) + + cpg.call.name("getDecoder").methodFullName.l shouldBe List( + "java.util.Base64.getDecoder:java.util.Base64$Decoder()" + ) + } + } + + "calls to static methods in other files should be resolved" in { + val cpg = code(""" + |package foo; + | + |class Foo { + | public static String foo() { + | return "FOO"; + | } + |} + |""".stripMargin) + .moreCode(""" + |package bar; + | + |import foo.Foo; + | + |class Bar { + | void test() { + | Foo.foo(); + | } + |} + |""".stripMargin) + + cpg.call.name("foo").methodFullName.l shouldBe List("foo.Foo.foo:java.lang.String()") + } "calls with unresolved receivers should have the correct fullnames" in { val cpg = code("""