From 8880950e5d316ae55378c9368caf24571af52e87 Mon Sep 17 00:00:00 2001 From: David Baker Effendi Date: Tue, 19 Mar 2024 15:33:26 +0200 Subject: [PATCH] [c#] Safe Download Dependency Handling & Cleanups * Fixed non-exhaustive matching on string interpolation handling * Wrapped the package version check in a try-catch * Ignore looking to nuget for packages that are internal to a potential c# monolith * Fixed test config loading in fixture --- .../joern/csharpsrc2cpg/CSharpSrc2Cpg.scala | 7 +++-- .../AstForExpressionsCreator.scala | 20 ++++++++----- .../csharpsrc2cpg/passes/DependencyPass.scala | 11 ++++++- .../utils/DependencyDownloader.scala | 30 +++++++++++++------ .../querying/ast/DependencyTests.scala | 1 + .../testfixtures/CSharpCode2CpgFixture.scala | 6 ++-- 6 files changed, 54 insertions(+), 21 deletions(-) diff --git a/joern-cli/frontends/csharpsrc2cpg/src/main/scala/io/joern/csharpsrc2cpg/CSharpSrc2Cpg.scala b/joern-cli/frontends/csharpsrc2cpg/src/main/scala/io/joern/csharpsrc2cpg/CSharpSrc2Cpg.scala index 4f466b6030bf..b8176e79a0a8 100644 --- a/joern-cli/frontends/csharpsrc2cpg/src/main/scala/io/joern/csharpsrc2cpg/CSharpSrc2Cpg.scala +++ b/joern-cli/frontends/csharpsrc2cpg/src/main/scala/io/joern/csharpsrc2cpg/CSharpSrc2Cpg.scala @@ -19,6 +19,7 @@ import io.shiftleft.passes.CpgPassBase import org.slf4j.LoggerFactory import java.nio.file.Paths +import scala.collection.mutable import scala.concurrent.ExecutionContext.Implicits.global import scala.concurrent.duration.Duration import scala.concurrent.{Await, Future} @@ -45,10 +46,12 @@ class CSharpSrc2Cpg extends X2CpgFrontend[Config] { val hash = HashUtil.sha256(astCreators.map(_.parserResult).map(x => Paths.get(x.fullPath))) new MetaDataPass(cpg, Languages.CSHARPSRC, config.inputPath, Option(hash)).createAndApply() - new DependencyPass(cpg, buildFiles(config)).createAndApply() + + val packageIds = mutable.HashSet.empty[String] + new DependencyPass(cpg, buildFiles(config), packageIds.add).createAndApply() // If "download dependencies" is enabled, then fetch dependencies and resolve their symbols for additional types val programSummary = if (config.downloadDependencies) { - DependencyDownloader(cpg, config, internalProgramSummary).download() + DependencyDownloader(cpg, config, internalProgramSummary, packageIds.toSet).download() } else { internalProgramSummary } diff --git a/joern-cli/frontends/csharpsrc2cpg/src/main/scala/io/joern/csharpsrc2cpg/astcreation/AstForExpressionsCreator.scala b/joern-cli/frontends/csharpsrc2cpg/src/main/scala/io/joern/csharpsrc2cpg/astcreation/AstForExpressionsCreator.scala index 0915fb5fb893..cdc2e60dc46e 100644 --- a/joern-cli/frontends/csharpsrc2cpg/src/main/scala/io/joern/csharpsrc2cpg/astcreation/AstForExpressionsCreator.scala +++ b/joern-cli/frontends/csharpsrc2cpg/src/main/scala/io/joern/csharpsrc2cpg/astcreation/AstForExpressionsCreator.scala @@ -445,9 +445,11 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) { .arr .map(createDotNetNodeInfo) .flatMap { expr => - expr.node match + expr.node match { case InterpolatedStringText => astForInterpolatedStringText(expr) case Interpolation => astForInterpolation(expr) + case _ => Nil + } } .toSeq @@ -456,13 +458,17 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) { .json(ParserKeys.Contents) .arr .map(createDotNetNodeInfo) - .map { node => - node.node match + .flatMap { node => + node.node match { case InterpolatedStringText => - node - .json(ParserKeys.TextToken)(ParserKeys.Value) - .str // Accessing node.json directly because DotNetNodeInfo contains stripped code, and does not contain braces - case Interpolation => node.json(ParserKeys.MetaData)(ParserKeys.Code).str + Try( + node + .json(ParserKeys.TextToken)(ParserKeys.Value) + .str + ).toOption // Accessing node.json directly because DotNetNodeInfo contains stripped code, and does not contain braces + case Interpolation => Try(node.json(ParserKeys.MetaData)(ParserKeys.Code).str).toOption + case _ => None + } } .mkString("") diff --git a/joern-cli/frontends/csharpsrc2cpg/src/main/scala/io/joern/csharpsrc2cpg/passes/DependencyPass.scala b/joern-cli/frontends/csharpsrc2cpg/src/main/scala/io/joern/csharpsrc2cpg/passes/DependencyPass.scala index 3458281d9d94..26d49dea677e 100644 --- a/joern-cli/frontends/csharpsrc2cpg/src/main/scala/io/joern/csharpsrc2cpg/passes/DependencyPass.scala +++ b/joern-cli/frontends/csharpsrc2cpg/src/main/scala/io/joern/csharpsrc2cpg/passes/DependencyPass.scala @@ -9,7 +9,8 @@ import org.slf4j.LoggerFactory import scala.util.{Failure, Try} -class DependencyPass(cpg: Cpg, buildFiles: List[String]) extends ForkJoinParallelCpgPass[File](cpg) { +class DependencyPass(cpg: Cpg, buildFiles: List[String], registerPackageId: String => _) + extends ForkJoinParallelCpgPass[File](cpg) { private val logger = LoggerFactory.getLogger(getClass) @@ -18,6 +19,14 @@ class DependencyPass(cpg: Cpg, buildFiles: List[String]) extends ForkJoinParalle override def runOnPart(builder: DiffGraphBuilder, part: File): Unit = { SecureXmlParsing.parseXml(part.contentAsString) match { case Some(xml) if xml.label == "Project" => + // Find packageId (useful for monoliths) + xml.child + .collect { case x if x.label == "PropertyGroup" => x.child } + .flatten + .collect { + case packageId if packageId.label == "PackageId" => registerPackageId(packageId.text) + } + // Register dependencies xml.child .collect { case x if x.label == "ItemGroup" => x.child } .flatten diff --git a/joern-cli/frontends/csharpsrc2cpg/src/main/scala/io/joern/csharpsrc2cpg/utils/DependencyDownloader.scala b/joern-cli/frontends/csharpsrc2cpg/src/main/scala/io/joern/csharpsrc2cpg/utils/DependencyDownloader.scala index b267fd677762..51a735304cd3 100644 --- a/joern-cli/frontends/csharpsrc2cpg/src/main/scala/io/joern/csharpsrc2cpg/utils/DependencyDownloader.scala +++ b/joern-cli/frontends/csharpsrc2cpg/src/main/scala/io/joern/csharpsrc2cpg/utils/DependencyDownloader.scala @@ -22,7 +22,12 @@ import scala.util.{Failure, Success, Try, Using} * @see * NuGet API */ -class DependencyDownloader(cpg: Cpg, config: Config, internalProgramSummary: CSharpProgramSummary) { +class DependencyDownloader( + cpg: Cpg, + config: Config, + internalProgramSummary: CSharpProgramSummary, + internalPackages: Set[String] = Set.empty +) { private val logger = LoggerFactory.getLogger(getClass) @@ -49,8 +54,8 @@ class DependencyDownloader(cpg: Cpg, config: Config, internalProgramSummary: CSh * true if the dependency is already in the given summary, false if otherwise. */ private def isAlreadySummarized(dependency: Dependency): Boolean = { - // TODO: Implement - false + // TODO: Check internalSummaries too + internalPackages.contains(dependency.name) } private case class NuGetPackageVersions(versions: List[String]) derives ReadWriter @@ -67,12 +72,17 @@ class DependencyDownloader(cpg: Cpg, config: Config, internalProgramSummary: CSh */ private def downloadDependency(targetDir: File, dependency: Dependency): Unit = { - def getVersion(packageName: String): Option[String] = { + def getVersion(packageName: String): Option[String] = Try { Using.resource(URI(s"https://$NUGET_BASE_API_V3/${packageName.toLowerCase}/index.json").toURL.openStream()) { is => Try(read[NuGetPackageVersions](ujson.Readable.fromByteArray(is.readAllBytes()))).toOption .flatMap(_.versions.lastOption) } + } match { + case Failure(_) => + logger.error(s"Unable to resolve `index.json` for `$packageName`, skipping...`") + None + case Success(x) => x } def createUrl(packageType: String, version: String): URL = { @@ -168,12 +178,14 @@ class DependencyDownloader(cpg: Cpg, config: Config, internalProgramSummary: CSh // Move and merge files val libDir = targetDir / "lib" - // Sometimes these dependencies will include DLLs for multiple version of dotnet, we only want one - libDir.listRecursively.filterNot(_.isDirectory).distinctBy(_.name).foreach { f => - f.copyTo(targetDir / f.name) + if (libDir.isDirectory) { + // Sometimes these dependencies will include DLLs for multiple version of dotnet, we only want one + libDir.listRecursively.filterNot(_.isDirectory).distinctBy(_.name).foreach { f => + f.copyTo(targetDir / f.name) + } + // Clean-up lib dir + libDir.delete(swallowIOExceptions = true) } - // Clean-up lib dir - libDir.delete(swallowIOExceptions = true) } /** Given a directory of all the summaries, will produce a summary thereof. diff --git a/joern-cli/frontends/csharpsrc2cpg/src/test/scala/io/joern/csharpsrc2cpg/querying/ast/DependencyTests.scala b/joern-cli/frontends/csharpsrc2cpg/src/test/scala/io/joern/csharpsrc2cpg/querying/ast/DependencyTests.scala index 90a468de2b55..c07396f9dedc 100644 --- a/joern-cli/frontends/csharpsrc2cpg/src/test/scala/io/joern/csharpsrc2cpg/querying/ast/DependencyTests.scala +++ b/joern-cli/frontends/csharpsrc2cpg/src/test/scala/io/joern/csharpsrc2cpg/querying/ast/DependencyTests.scala @@ -1,6 +1,7 @@ package io.joern.csharpsrc2cpg.querying.ast import io.joern.csharpsrc2cpg.testfixtures.CSharpCode2CpgFixture +import io.joern.csharpsrc2cpg.Config import io.shiftleft.semanticcpg.language.* class DependencyTests extends CSharpCode2CpgFixture { diff --git a/joern-cli/frontends/csharpsrc2cpg/src/test/scala/io/joern/csharpsrc2cpg/testfixtures/CSharpCode2CpgFixture.scala b/joern-cli/frontends/csharpsrc2cpg/src/test/scala/io/joern/csharpsrc2cpg/testfixtures/CSharpCode2CpgFixture.scala index b49e3d62609d..9a11bcf53932 100644 --- a/joern-cli/frontends/csharpsrc2cpg/src/test/scala/io/joern/csharpsrc2cpg/testfixtures/CSharpCode2CpgFixture.scala +++ b/joern-cli/frontends/csharpsrc2cpg/src/test/scala/io/joern/csharpsrc2cpg/testfixtures/CSharpCode2CpgFixture.scala @@ -63,7 +63,9 @@ class DefaultTestCpgWithCSharp extends DefaultTestCpg with CSharpFrontend with S } override def applyPostProcessingPasses(): Unit = { - CSharpSrc2Cpg.postProcessingPasses(this, config).foreach(_.createAndApply()) + CSharpSrc2Cpg + .postProcessingPasses(this, getConfig().map(_.asInstanceOf[Config]).getOrElse(defaultConfig)) + .foreach(_.createAndApply()) super.applyPostProcessingPasses() } @@ -73,7 +75,7 @@ trait CSharpFrontend extends LanguageFrontend { override val fileSuffix: String = ".cs" - implicit val config: Config = + implicit lazy val defaultConfig: Config = getConfig() .map(_.asInstanceOf[Config]) .getOrElse(Config().withSchemaValidation(ValidationMode.Enabled))