Skip to content

Commit

Permalink
[kotlin2cpg] Refactorings and clean-ups (#4534)
Browse files Browse the repository at this point in the history
Mostly for readability (extracted code into smaller units, map -> flatten into flatMap etc.)
  • Loading branch information
max-leuthaeuser authored May 6, 2024
1 parent bd4ed84 commit 0438342
Show file tree
Hide file tree
Showing 30 changed files with 459 additions and 405 deletions.
3 changes: 0 additions & 3 deletions joern-cli/frontends/kotlin2cpg/.scalafmt.conf

This file was deleted.

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -2,29 +2,39 @@ package io.joern.kotlin2cpg.ast

import io.joern.kotlin2cpg.Constants
import io.joern.kotlin2cpg.KtFileWithMeta
import io.joern.kotlin2cpg.ast.Nodes.{namespaceBlockNode, operatorCallNode}
import io.joern.kotlin2cpg.types.{TypeConstants, TypeInfoProvider, TypeRenderer}
import io.shiftleft.codepropertygraph.generated.nodes.*
import io.shiftleft.codepropertygraph.generated.*
import io.shiftleft.passes.IntervalKeyPool
import io.joern.x2cpg.{Ast, AstCreatorBase, AstNodeBuilder, Defines, ValidationMode}
import io.joern.kotlin2cpg.ast.Nodes.namespaceBlockNode
import io.joern.kotlin2cpg.ast.Nodes.operatorCallNode
import io.joern.kotlin2cpg.datastructures.Scope
import io.joern.kotlin2cpg.types.TypeConstants
import io.joern.kotlin2cpg.types.TypeInfoProvider
import io.joern.kotlin2cpg.types.TypeRenderer
import io.joern.x2cpg.Ast
import io.joern.x2cpg.AstCreatorBase
import io.joern.x2cpg.AstNodeBuilder
import io.joern.x2cpg.Defines
import io.joern.x2cpg.ValidationMode
import io.joern.x2cpg.datastructures.Global
import io.joern.x2cpg.datastructures.Stack.*
import io.joern.kotlin2cpg.datastructures.Scope
import io.joern.x2cpg.utils.NodeBuilders.newMethodReturnNode
import io.shiftleft.codepropertygraph.generated.*
import io.shiftleft.codepropertygraph.generated.nodes.*
import io.shiftleft.passes.IntervalKeyPool
import io.shiftleft.semanticcpg.language.*
import io.shiftleft.semanticcpg.language.types.structure.NamespaceTraversal
import org.jetbrains.kotlin.com.intellij.psi.PsiElement
import org.jetbrains.kotlin.descriptors.{DescriptorVisibilities, DescriptorVisibility}
import org.jetbrains.kotlin.descriptors.DescriptorVisibilities
import org.jetbrains.kotlin.descriptors.DescriptorVisibility
import org.jetbrains.kotlin.lexer.KtToken
import org.jetbrains.kotlin.lexer.KtTokens
import org.jetbrains.kotlin.psi.*
import org.jetbrains.kotlin.lexer.{KtToken, KtTokens}
import org.slf4j.{Logger, LoggerFactory}
import org.slf4j.Logger
import org.slf4j.LoggerFactory
import overflowdb.BatchedUpdate.DiffGraphBuilder

import java.io.PrintWriter
import java.io.StringWriter
import scala.annotation.tailrec
import scala.collection.mutable
import io.shiftleft.semanticcpg.language.*

import java.io.{PrintWriter, StringWriter}
import scala.jdk.CollectionConverters.*

case class BindingInfo(node: NewBinding, edgeMeta: Seq[(NewNode, NewNode, String)])
Expand Down Expand Up @@ -52,7 +62,7 @@ class AstCreator(fileWithMeta: KtFileWithMeta, xTypeInfoProvider: TypeInfoProvid
protected val relativizedPath: String = fileWithMeta.relativizedPath

protected val scope: Scope[String, DeclarationNew, NewNode] = new Scope()
protected val debugScope = mutable.Stack.empty[KtDeclaration]
protected val debugScope: mutable.Stack[KtDeclaration] = mutable.Stack.empty[KtDeclaration]

def createAst(): DiffGraphBuilder = {
implicit val typeInfoProvider: TypeInfoProvider = xTypeInfoProvider
Expand Down Expand Up @@ -130,7 +140,7 @@ class AstCreator(fileWithMeta: KtFileWithMeta, xTypeInfoProvider: TypeInfoProvid
else node.importedEntity.getOrElse("")
}

protected def storeInDiffGraph(ast: Ast): Unit = {
private def storeInDiffGraph(ast: Ast): Unit = {
Ast.storeInDiffGraph(ast, diffGraph)

for {
Expand Down Expand Up @@ -273,7 +283,7 @@ class AstCreator(fileWithMeta: KtFileWithMeta, xTypeInfoProvider: TypeInfoProvid
}
}

def astForFile(fileWithMeta: KtFileWithMeta)(implicit typeInfoProvider: TypeInfoProvider): Ast = {
private def astForFile(fileWithMeta: KtFileWithMeta)(implicit typeInfoProvider: TypeInfoProvider): Ast = {
val ktFile = fileWithMeta.f

val importDirectives = ktFile.getImportList.getImports.asScala
Expand Down Expand Up @@ -344,7 +354,7 @@ class AstCreator(fileWithMeta: KtFileWithMeta, xTypeInfoProvider: TypeInfoProvid
case p: KtProperty => astsForProperty(p)
case unhandled =>
logger.error(
s"Unknown declaration type encountered in this file `${relativizedPath}` with text `${unhandled.getText}` and class `${unhandled.getClass}`!"
s"Unknown declaration type encountered in this file `$relativizedPath` with text `${unhandled.getText}` and class `${unhandled.getClass}`!"
)
Seq()
}
Expand All @@ -355,7 +365,7 @@ class AstCreator(fileWithMeta: KtFileWithMeta, xTypeInfoProvider: TypeInfoProvid
val printWriter = new PrintWriter(stringWriter)
exception.printStackTrace(printWriter)
logger.warn(
s"Caught exception while processing decl in this file `${relativizedPath}`:\n$declText\n${stringWriter.toString}"
s"Caught exception while processing decl in this file `$relativizedPath`:\n$declText\n${stringWriter.toString}"
)
Seq()
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,23 @@ import io.joern.kotlin2cpg.Constants
import io.joern.kotlin2cpg.ast.Nodes.operatorCallNode
import io.joern.kotlin2cpg.psi.PsiUtils
import io.joern.kotlin2cpg.psi.PsiUtils.nonUnderscoreDestructuringEntries
import io.joern.kotlin2cpg.types.{AnonymousObjectContext, TypeConstants, TypeInfoProvider}
import io.joern.kotlin2cpg.types.AnonymousObjectContext
import io.joern.kotlin2cpg.types.TypeConstants
import io.joern.kotlin2cpg.types.TypeInfoProvider
import io.joern.x2cpg.Ast
import io.joern.x2cpg.Defines
import io.joern.x2cpg.ValidationMode
import io.joern.x2cpg.utils.NodeBuilders
import io.joern.x2cpg.utils.NodeBuilders.{newBindingNode, newIdentifierNode, newMethodReturnNode}
import io.joern.x2cpg.{Ast, AstNodeBuilder, Defines, ValidationMode}
import io.shiftleft.codepropertygraph.generated.nodes.{NewBlock, NewCall, NewMethod, NewTypeDecl}
import io.shiftleft.codepropertygraph.generated.{DispatchTypes, EdgeTypes, Operators}
import io.joern.x2cpg.utils.NodeBuilders.newBindingNode
import io.joern.x2cpg.utils.NodeBuilders.newIdentifierNode
import io.joern.x2cpg.utils.NodeBuilders.newMethodReturnNode
import io.shiftleft.codepropertygraph.generated.DispatchTypes
import io.shiftleft.codepropertygraph.generated.EdgeTypes
import io.shiftleft.codepropertygraph.generated.Operators
import io.shiftleft.codepropertygraph.generated.nodes.NewBlock
import io.shiftleft.codepropertygraph.generated.nodes.NewCall
import io.shiftleft.codepropertygraph.generated.nodes.NewMethod
import io.shiftleft.codepropertygraph.generated.nodes.NewTypeDecl
import io.shiftleft.semanticcpg.language.*
import org.jetbrains.kotlin.psi.*

Expand Down Expand Up @@ -150,8 +161,7 @@ trait AstForDeclarationsCreator(implicit withSchemaValidation: ValidationMode) {
classDeclarations.toSeq
.collectAll[KtClassOrObject]
.filterNot(typeInfoProvider.isCompanionObject)
.map(astsForDeclaration(_))
.flatten
.flatMap(astsForDeclaration(_))

val classFunctions = Option(ktClass.getBody)
.map(_.getFunctions.asScala.collect { case f: KtNamedFunction => f })
Expand Down Expand Up @@ -228,7 +238,7 @@ trait AstForDeclarationsCreator(implicit withSchemaValidation: ValidationMode) {
}
if (typedInit.isEmpty) {
logger.warn(
s"Unhandled case for destructuring declaration: `${expr.getText}`; type: `${expr.getInitializer.getClass}` in this file `${relativizedPath}`."
s"Unhandled case for destructuring declaration: `${expr.getText}`; type: `${expr.getInitializer.getClass}` in this file `$relativizedPath`."
)
return Seq()
}
Expand Down Expand Up @@ -266,19 +276,22 @@ trait AstForDeclarationsCreator(implicit withSchemaValidation: ValidationMode) {
)
val assignmentNode = operatorCallNode(Operators.assignment, s"$tmpName = ${Constants.alloc}", None)
callAst(assignmentNode, List(assignmentLhsAst, Ast(assignmentRhsNode)))
} else if (expr.getInitializer.isInstanceOf[KtArrayAccessExpression]) {
astForArrayAccess(expr.getInitializer.asInstanceOf[KtArrayAccessExpression], None, None)
} else if (expr.getInitializer.isInstanceOf[KtPostfixExpression]) {
astForPostfixExpression(expr.getInitializer.asInstanceOf[KtPostfixExpression], None, None)
} else if (expr.getInitializer.isInstanceOf[KtWhenExpression]) {
astForWhenAsExpression(expr.getInitializer.asInstanceOf[KtWhenExpression], None, None)
} else if (expr.getInitializer.isInstanceOf[KtIfExpression]) {
astForIfAsExpression(expr.getInitializer.asInstanceOf[KtIfExpression], None, None)
} else {
val assignmentNode = operatorCallNode(Operators.assignment, s"$tmpName = ${rhsCall.getText}", None)
val assignmentRhsAst =
astsForExpression(rhsCall, None).headOption.getOrElse(Ast(unknownNode(rhsCall, Constants.empty)))
callAst(assignmentNode, List(assignmentLhsAst, assignmentRhsAst))
expr.getInitializer match {
case accessExpression: KtArrayAccessExpression =>
astForArrayAccess(accessExpression, None, None)
case expression: KtPostfixExpression =>
astForPostfixExpression(expression, None, None)
case expression: KtWhenExpression =>
astForWhenAsExpression(expression, None, None)
case expression: KtIfExpression =>
astForIfAsExpression(expression, None, None)
case _ =>
val assignmentNode = operatorCallNode(Operators.assignment, s"$tmpName = ${rhsCall.getText}", None)
val assignmentRhsAst =
astsForExpression(rhsCall, None).headOption.getOrElse(Ast(unknownNode(rhsCall, Constants.empty)))
callAst(assignmentNode, List(assignmentLhsAst, assignmentRhsAst))
}
}
val tmpAssignmentPrologue = rhsCall match {
case call: KtCallExpression if isCtor =>
Expand Down Expand Up @@ -328,7 +341,7 @@ trait AstForDeclarationsCreator(implicit withSchemaValidation: ValidationMode) {
)(implicit typeInfoProvider: TypeInfoProvider): Seq[Ast] = {
val typedInit = Option(expr.getInitializer).collect { case e: KtNameReferenceExpression => e }
if (typedInit.isEmpty) {
logger.warn(s"Unhandled case for destructuring declaration: `${expr.getText}` in this file `${relativizedPath}`.")
logger.warn(s"Unhandled case for destructuring declaration: `${expr.getText}` in this file `$relativizedPath`.")
return Seq()
}
val destructuringRHS = typedInit.get
Expand Down Expand Up @@ -624,7 +637,7 @@ trait AstForDeclarationsCreator(implicit withSchemaValidation: ValidationMode) {
}
}

def astForMember(decl: KtDeclaration)(implicit typeInfoProvider: TypeInfoProvider): Ast = {
private def astForMember(decl: KtDeclaration)(implicit typeInfoProvider: TypeInfoProvider): Ast = {
val name = Option(decl.getName).getOrElse(TypeConstants.any)
val explicitTypeName = decl.getOriginalElement match {
case p: KtProperty if p.getTypeReference != null => p.getTypeReference.getText
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,17 @@ package io.joern.kotlin2cpg.ast

import io.joern.kotlin2cpg.Constants
import io.joern.kotlin2cpg.ast.Nodes.operatorCallNode
import io.joern.kotlin2cpg.types.{CallKinds, TypeConstants, TypeInfoProvider}
import io.joern.x2cpg.{Ast, Defines, ValidationMode}
import io.joern.kotlin2cpg.types.CallKinds
import io.joern.kotlin2cpg.types.TypeConstants
import io.joern.kotlin2cpg.types.TypeInfoProvider
import io.joern.x2cpg.Ast
import io.joern.x2cpg.Defines
import io.joern.x2cpg.ValidationMode
import io.shiftleft.codepropertygraph.generated.DispatchTypes
import io.shiftleft.codepropertygraph.generated.Operators
import io.shiftleft.codepropertygraph.generated.nodes.NewMethodRef
import io.shiftleft.codepropertygraph.generated.{DispatchTypes, Operators}
import org.jetbrains.kotlin.lexer.{KtToken, KtTokens}
import org.jetbrains.kotlin.lexer.KtToken
import org.jetbrains.kotlin.lexer.KtTokens
import org.jetbrains.kotlin.psi.*

import scala.jdk.CollectionConverters.*
Expand Down Expand Up @@ -63,7 +69,7 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) {
}
case _ =>
logger.warn(
s"Unhandled operator token type `${opRef.getOperationSignTokenType}` for expression `${expr.getText}` in this file `${relativizedPath}`."
s"Unhandled operator token type `${opRef.getOperationSignTokenType}` for expression `${expr.getText}` in this file `$relativizedPath`."
)
Some(Constants.unknownOperator)
}
Expand Down Expand Up @@ -501,7 +507,7 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) {
(asts.dropRight(1), asts.lastOption.getOrElse(Ast(unknownNode(arg.getArgumentExpression, Constants.empty))))
}
val astsForTrails = argAstsWithTrail.map(_._2)
val astsForNonTrails = argAstsWithTrail.map(_._1).flatten
val astsForNonTrails = argAstsWithTrail.flatMap(_._1)

val (fullName, signature) = typeInfoProvider.fullNameWithSignature(expr, (TypeConstants.any, TypeConstants.any))
registerType(typeInfoProvider.expressionType(expr, TypeConstants.any))
Expand Down Expand Up @@ -580,9 +586,9 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) {
val typeFullName = registerType(typeInfoProvider.expressionType(expression, TypeConstants.any))
val identifier = identifierNode(arrayExpr, arrayExpr.getText, arrayExpr.getText, typeFullName)
val identifierAst = astWithRefEdgeMaybe(arrayExpr.getText, identifier)
val astsForIndexExpr = expression.getIndexExpressions.asScala.zipWithIndex.map { case (expr, idx) =>
val astsForIndexExpr = expression.getIndexExpressions.asScala.zipWithIndex.flatMap { case (expr, idx) =>
astsForExpression(expr, Option(idx + 1))
}.flatten
}
val callNode =
operatorCallNode(
Operators.indexAccess,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,20 @@ package io.joern.kotlin2cpg.ast

import io.joern.kotlin2cpg.Constants
import io.joern.kotlin2cpg.ast.Nodes.modifierNode
import io.joern.kotlin2cpg.types.{TypeConstants, TypeInfoProvider}
import io.joern.kotlin2cpg.types.TypeConstants
import io.joern.kotlin2cpg.types.TypeInfoProvider
import io.joern.x2cpg.Ast
import io.joern.x2cpg.ValidationMode
import io.joern.x2cpg.datastructures.Stack.StackWrapper
import io.joern.x2cpg.utils.NodeBuilders
import io.joern.x2cpg.utils.NodeBuilders.{newBindingNode, newClosureBindingNode, newMethodReturnNode, newModifierNode}
import io.joern.x2cpg.{Ast, AstNodeBuilder, ValidationMode}
import io.joern.x2cpg.utils.NodeBuilders.newBindingNode
import io.joern.x2cpg.utils.NodeBuilders.newClosureBindingNode
import io.joern.x2cpg.utils.NodeBuilders.newMethodReturnNode
import io.joern.x2cpg.utils.NodeBuilders.newModifierNode
import io.shiftleft.codepropertygraph.generated.EdgeTypes
import io.shiftleft.codepropertygraph.generated.EvaluationStrategies
import io.shiftleft.codepropertygraph.generated.ModifierTypes
import io.shiftleft.codepropertygraph.generated.nodes.*
import io.shiftleft.codepropertygraph.generated.{EdgeTypes, EvaluationStrategies, ModifierTypes}
import io.shiftleft.semanticcpg.language.*
import org.jetbrains.kotlin.descriptors.DescriptorVisibilities
import org.jetbrains.kotlin.psi.*

Expand Down
Original file line number Diff line number Diff line change
@@ -1,36 +1,36 @@
package io.joern.kotlin2cpg.ast

import io.joern.kotlin2cpg.Constants
import io.joern.kotlin2cpg.ast.Nodes.{namespaceBlockNode, operatorCallNode}
import io.joern.kotlin2cpg.types.{TypeConstants, TypeInfoProvider}
import io.joern.x2cpg.{Ast, Defines, ValidationMode}
import io.shiftleft.codepropertygraph.generated.{DispatchTypes, Operators}
import io.shiftleft.codepropertygraph.generated.nodes.{
NewAnnotation,
NewAnnotationLiteral,
NewImport,
NewLocal,
NewMember,
NewMethodParameterIn
}
import org.jetbrains.kotlin.psi.{
KtAnnotationEntry,
KtClassLiteralExpression,
KtConstantExpression,
KtImportDirective,
KtNameReferenceExpression,
KtStringTemplateExpression,
KtSuperExpression,
KtThisExpression,
KtTypeAlias,
KtTypeReference
}
import io.joern.kotlin2cpg.ast.Nodes.namespaceBlockNode
import io.joern.kotlin2cpg.ast.Nodes.operatorCallNode
import io.joern.kotlin2cpg.types.TypeConstants
import io.joern.kotlin2cpg.types.TypeInfoProvider
import io.joern.x2cpg.Ast
import io.joern.x2cpg.Defines
import io.joern.x2cpg.ValidationMode
import io.shiftleft.codepropertygraph.generated.DispatchTypes
import io.shiftleft.codepropertygraph.generated.Operators
import io.shiftleft.codepropertygraph.generated.nodes.NewAnnotation
import io.shiftleft.codepropertygraph.generated.nodes.NewAnnotationLiteral
import io.shiftleft.codepropertygraph.generated.nodes.NewImport
import io.shiftleft.codepropertygraph.generated.nodes.NewLocal
import io.shiftleft.codepropertygraph.generated.nodes.NewMember
import io.shiftleft.codepropertygraph.generated.nodes.NewMethodParameterIn
import io.shiftleft.semanticcpg.language.*
import io.shiftleft.semanticcpg.language.types.structure.NamespaceTraversal

import scala.jdk.CollectionConverters.*
import org.jetbrains.kotlin.psi.KtAnnotationEntry
import org.jetbrains.kotlin.psi.KtClassLiteralExpression
import org.jetbrains.kotlin.psi.KtConstantExpression
import org.jetbrains.kotlin.psi.KtImportDirective
import org.jetbrains.kotlin.psi.KtNameReferenceExpression
import org.jetbrains.kotlin.psi.KtStringTemplateExpression
import org.jetbrains.kotlin.psi.KtSuperExpression
import org.jetbrains.kotlin.psi.KtThisExpression
import org.jetbrains.kotlin.psi.KtTypeAlias
import org.jetbrains.kotlin.psi.KtTypeReference

import scala.annotation.unused
import scala.jdk.CollectionConverters.*
import scala.util.Try

trait AstForPrimitivesCreator(implicit withSchemaValidation: ValidationMode) {
Expand All @@ -45,8 +45,7 @@ trait AstForPrimitivesCreator(implicit withSchemaValidation: ValidationMode) {
val typeFullName = registerType(typeInfoProvider.expressionType(expr, TypeConstants.any))
val node = literalNode(expr, expr.getText, typeFullName)
val annotationAsts = annotations.map(astForAnnotationEntry)
Ast(withArgumentName(withArgumentIndex(node, argIdx), argName))
.withChildren(annotationAsts)
Ast(withArgumentName(withArgumentIndex(node, argIdx), argName)).withChildren(annotationAsts)
}

def astForStringTemplate(
Expand Down Expand Up @@ -86,12 +85,10 @@ trait AstForPrimitivesCreator(implicit withSchemaValidation: ValidationMode) {
argName: Option[String],
annotations: Seq[KtAnnotationEntry] = Seq()
)(implicit typeInfoProvider: TypeInfoProvider): Ast = {
val isReferencingMember =
scope.lookupVariable(expr.getIdentifier.getText) match {
case Some(_: NewMember) => true
case _ => false
}

val isReferencingMember = scope.lookupVariable(expr.getIdentifier.getText) match {
case Some(_: NewMember) => true
case _ => false
}
val outAst =
if (typeInfoProvider.isReferenceToClass(expr)) astForNameReferenceToType(expr, argIdx)
else if (isReferencingMember) astForNameReferenceToMember(expr, argIdx)
Expand Down Expand Up @@ -255,7 +252,7 @@ trait AstForPrimitivesCreator(implicit withSchemaValidation: ValidationMode) {
val children =
entry.getValueArguments.asScala.flatMap { varg =>
varg.getArgumentExpression match {
case ste: KtStringTemplateExpression if ste.getEntries.size == 1 =>
case ste: KtStringTemplateExpression if ste.getEntries.length == 1 =>
val node = NewAnnotationLiteral().code(ste.getText)
Some(Ast(node))
case ce: KtConstantExpression =>
Expand Down
Loading

0 comments on commit 0438342

Please sign in to comment.