Skip to content

Commit

Permalink
[gosrc2cpg] - Global variable as member nodes of package level TypeDe…
Browse files Browse the repository at this point in the history
…cl (#3756)

Follow up PR to the PR #3734 to model global variables as member nodes of Package level TypeDecl, and its access is converted to `field access` call.
  • Loading branch information
pandurangpatil authored Oct 25, 2023
1 parent 06c7657 commit 860f398
Show file tree
Hide file tree
Showing 11 changed files with 156 additions and 88 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -63,15 +63,22 @@ class AstCreator(val relPathFileName: String, val parserResult: ParserResult, go
scope.pushNewScope(fakeGlobalMethodForFile)
val blockNode_ = blockNode(rootNode, Defines.empty, Defines.anyTypeName)
val methodReturn = methodReturnNode(rootNode, Defines.anyTypeName)
val declsAsts = rootNode.json(ParserKeys.Decls).arr.flatMap(item => astForNode(item)).toList
val declsAsts = rootNode
.json(ParserKeys.Decls)
.arr
.flatMap { item =>
val node = createParserNodeInfo(item)
astForNode(node, true)
}
.toList
methodAstParentStack.pop()
scope.popScope()
methodAst(fakeGlobalMethodForFile, Seq.empty, blockAst(blockNode_, declsAsts), methodReturn)
}

protected def astForNode(nodeInfo: ParserNodeInfo): Seq[Ast] = {
protected def astForNode(nodeInfo: ParserNodeInfo, globalStatements: Boolean = false): Seq[Ast] = {
nodeInfo.node match {
case GenDecl => astForGenDecl(nodeInfo)
case GenDecl => astForGenDecl(nodeInfo, globalStatements)
case FuncDecl => astForFuncDecl(nodeInfo)
case _: BasePrimitive => astForPrimitive(nodeInfo)
case _: BaseExpr => astsForExpression(nodeInfo)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,13 @@ import io.joern.gosrc2cpg.parser.ParserAst.*
import io.joern.gosrc2cpg.parser.{ParserKeys, ParserNodeInfo}
import io.joern.x2cpg
import io.joern.x2cpg.{Ast, ValidationMode}
import io.shiftleft.codepropertygraph.generated.{DispatchTypes, Operators}
import io.shiftleft.codepropertygraph.generated.{DispatchTypes, NodeTypes, Operators}
import ujson.Value

import scala.util.{Success, Try}

trait AstForGenDeclarationCreator(implicit withSchemaValidation: ValidationMode) { this: AstCreator =>
def astForGenDecl(genDecl: ParserNodeInfo): Seq[Ast] = {
def astForGenDecl(genDecl: ParserNodeInfo, globalStatements: Boolean = false): Seq[Ast] = {
Try(
genDecl
.json(ParserKeys.Specs)
Expand All @@ -24,7 +24,7 @@ trait AstForGenDeclarationCreator(implicit withSchemaValidation: ValidationMode)
genDeclNode.node match
case ImportSpec => astForImport(genDeclNode)
case TypeSpec => astForTypeSpec(genDeclNode)
case ValueSpec => astForValueSpec(genDeclNode)
case ValueSpec => astForValueSpec(genDeclNode, globalStatements = globalStatements)
case _ => Seq[Ast]()
}
.toSeq
Expand All @@ -41,7 +41,11 @@ trait AstForGenDeclarationCreator(implicit withSchemaValidation: ValidationMode)
Seq(Ast(newImportNode(s"import $importedAsReplacement$importedEntity", importedEntity, importedAs, basicLit)))
}

protected def astForValueSpec(valueSpec: ParserNodeInfo, recordVar: Boolean = false): Seq[Ast] = {
protected def astForValueSpec(
valueSpec: ParserNodeInfo,
recordVar: Boolean = false,
globalStatements: Boolean = false
): Seq[Ast] = {
val typeFullName = Try(valueSpec.json(ParserKeys.Type)) match
case Success(typeJson) =>
val (typeFullName, _, _, _) = processTypeInfo(createParserNodeInfo(typeJson))
Expand All @@ -54,17 +58,29 @@ trait AstForGenDeclarationCreator(implicit withSchemaValidation: ValidationMode)
(valueSpec.json(ParserKeys.Names).arr.toList zip valueSpec.json(ParserKeys.Values).arr.toList)
.map { case (lhs, rhs) => (createParserNodeInfo(lhs), createParserNodeInfo(rhs)) }
.map { case (lhsParserNode, rhsParserNode) =>
astForAssignmentCallNode(lhsParserNode, rhsParserNode, typeFullName, valueSpec.code, recordVar)
astForAssignmentCallNode(
lhsParserNode,
rhsParserNode,
typeFullName,
valueSpec.code,
recordVar,
globalStatements
)
}
.unzip
localAsts ++: assCallAsts
if globalStatements then Seq.empty else localAsts ++: assCallAsts
case _ =>
valueSpec
.json(ParserKeys.Names)
.arr
.flatMap { parserNode =>
val localParserNode = createParserNodeInfo(parserNode)
Seq(astForLocalNode(localParserNode, typeFullName, recordVar)) ++: astForNode(localParserNode)
if globalStatements then {
astForGlobalVarAndConstants(typeFullName.getOrElse(Defines.anyTypeName), localParserNode)
Seq.empty
} else {
Seq(astForLocalNode(localParserNode, typeFullName, recordVar)) ++: astForNode(localParserNode)
}
}
.toSeq

Expand All @@ -75,23 +91,43 @@ trait AstForGenDeclarationCreator(implicit withSchemaValidation: ValidationMode)
rhsParserNode: ParserNodeInfo,
typeFullName: Option[String],
code: String,
recordVar: Boolean = false
recordVar: Boolean = false,
globalStatements: Boolean = false
): (Ast, Ast) = {
val rhsAst = astForBooleanLiteral(rhsParserNode)
val rhsTypeFullName = typeFullName.getOrElse(getTypeFullNameFromAstNode(rhsAst))
val localAst = astForLocalNode(lhsParserNode, Some(rhsTypeFullName), recordVar)
val lhsAst = astForNode(lhsParserNode)
val arguments = lhsAst ++: rhsAst
val cNode = callNode(
rhsParserNode,
code,
Operators.assignment,
Operators.assignment,
DispatchTypes.STATIC_DISPATCH,
None,
Some(rhsTypeFullName)
if (globalStatements) {
astForGlobalVarAndConstants(rhsTypeFullName, lhsParserNode, Some(rhsAst))
(Ast(), Ast())
} else {
val localAst = astForLocalNode(lhsParserNode, Some(rhsTypeFullName), recordVar)
val lhsAst = astForNode(lhsParserNode)
val arguments = lhsAst ++: rhsAst
val cNode = callNode(
rhsParserNode,
code,
Operators.assignment,
Operators.assignment,
DispatchTypes.STATIC_DISPATCH,
None,
Some(rhsTypeFullName)
)
(callAst(cNode, arguments), localAst)
}
}

private def astForGlobalVarAndConstants(
typeFullName: String,
lhsParserNode: ParserNodeInfo,
rhsAst: Option[Seq[Ast]] = None
): Unit = {
val name = lhsParserNode.json(ParserKeys.Name).str
val memberAst = Ast(
memberNode(lhsParserNode, name, lhsParserNode.code, typeFullName)
.astParentType(NodeTypes.TYPE_DECL)
.astParentFullName(fullyQualifiedPackage)
)
(callAst(cNode, arguments), localAst)
Ast.storeInDiffGraph(memberAst, diffGraph)
}

protected def astForLocalNode(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
package io.joern.gosrc2cpg.astcreation

import io.joern.gosrc2cpg.datastructures.GoGlobal
import io.joern.gosrc2cpg.parser.ParserAst.*
import io.joern.gosrc2cpg.parser.{ParserKeys, ParserNodeInfo}
import io.joern.x2cpg.utils.NodeBuilders.newOperatorCallNode
import io.joern.x2cpg.{Ast, ValidationMode}
import io.shiftleft.codepropertygraph.generated.nodes.NewCall
import io.shiftleft.codepropertygraph.generated.nodes.{NewCall, NewFieldIdentifier}
import io.shiftleft.codepropertygraph.generated.{DispatchTypes, Operators}

import scala.util.{Success, Try}
Expand Down Expand Up @@ -68,14 +70,38 @@ trait AstForPrimitivesCreator(implicit withSchemaValidation: ValidationMode) { t
val node = identifierNode(ident, identifierName, ident.code, variableTypeName)
Ast(node).withRefEdge(node, variable)
case _ =>
// TODO: something is wrong here. Refer to SwitchTests -> "be correct for switch case 4"
Ast(identifierNode(ident, identifierName, ident.json(ParserKeys.Name).str, Defines.anyTypeName))
// If its not local node then check if its global member variable of package TypeDecl
Option(GoGlobal.structTypeMemberTypeMapping.get(s"$fullyQualifiedPackage${Defines.dot}$identifierName")) match
case Some(fieldTypeFullName) => astForPackageGlobalFieldAccess(fieldTypeFullName, identifierName, ident)
case _ =>
// TODO: something is wrong here. Refer to SwitchTests -> "be correct for switch case 4"
Ast(identifierNode(ident, identifierName, ident.json(ParserKeys.Name).str, Defines.anyTypeName))
}
} else {
Ast()
}
}

private def astForPackageGlobalFieldAccess(
fieldTypeFullName: String,
identifierName: String,
ident: ParserNodeInfo
): Ast = {
val identifierAsts = Seq(Ast(identifierNode(ident, declaredPackageName, ident.code, fullyQualifiedPackage)))
callAst(
newOperatorCallNode(Operators.fieldAccess, ident.code, Some(fieldTypeFullName), line(ident), column(ident)),
identifierAsts ++: Seq(
Ast(
NewFieldIdentifier()
.canonicalName(identifierName)
.lineNumber(line(ident))
.columnNumber(column(ident))
.code(identifierName)
)
)
)
}

protected def getTypeOfToken(basicLit: ParserNodeInfo): String = {
// TODO need to add more primitive types
Try(basicLit.json(ParserKeys.Kind).str match {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ import io.joern.gosrc2cpg.datastructures.GoGlobal
import io.joern.gosrc2cpg.parser.ParserAst.*
import io.joern.gosrc2cpg.parser.{ParserKeys, ParserNodeInfo}
import io.joern.x2cpg
import io.joern.x2cpg.utils.NodeBuilders.newOperatorCallNode
import io.joern.x2cpg.utils.NodeBuilders.{newFieldIdentifierNode, newOperatorCallNode}
import io.joern.x2cpg.{Ast, ValidationMode, Defines as XDefines}
import io.shiftleft.codepropertygraph.generated.Operators
import io.shiftleft.codepropertygraph.generated.nodes.NewFieldIdentifier
Expand Down Expand Up @@ -35,7 +35,7 @@ trait AstForTypeDeclCreator(implicit withSchemaValidation: ValidationMode) { thi
val fieldNodeInfo = createParserNodeInfo(fieldInfo)
val fieldName = fieldNodeInfo.json(ParserKeys.Name).str
GoGlobal.recordStructTypeMemberType(typeDeclFullName + Defines.dot + fieldName, typeFullName)
Ast(memberNode(typeInfo, fieldName, fieldNodeInfo.code, typeFullName, Seq()))
Ast(memberNode(typeInfo, fieldName, fieldNodeInfo.code, typeFullName))
})
})
.toSeq
Expand Down Expand Up @@ -79,14 +79,10 @@ trait AstForTypeDeclCreator(implicit withSchemaValidation: ValidationMode) { thi
protected def astForFieldAccess(info: ParserNodeInfo): Seq[Ast] = {
val (identifierAsts, fieldTypeFullName) = processReceiver(info)
val fieldIdentifier = info.json(ParserKeys.Sel)(ParserKeys.Name).str
val fieldIdentifierNode = NewFieldIdentifier()
.canonicalName(fieldIdentifier)
.lineNumber(line(info))
.columnNumber(column(info))
.code(fieldIdentifier)
val fieldIdAst = Ast(fieldIdentifierNode)
val callNode =
newOperatorCallNode(Operators.fieldAccess, info.code, Some(fieldTypeFullName), line(info), column(info))
Seq(callAst(callNode, identifierAsts ++ Seq(fieldIdAst)))
Seq(
callAst(callNode, identifierAsts ++ Seq(Ast(newFieldIdentifierNode(fieldIdentifier, line(info), column(info)))))
)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ trait CacheBuilder(implicit withSchemaValidation: ValidationMode) { this: AstCre
val diffGraph = new DiffGraphBuilder
try {

cpgOpt.map(_ => {
cpgOpt.map { _ =>
// We don't want to process this part when third party dependencies are being processed.
val result = GoGlobal.recordAliasToNamespaceMapping(declaredPackageName, fullyQualifiedPackage)
if (result == null) {
Expand All @@ -28,8 +28,7 @@ trait CacheBuilder(implicit withSchemaValidation: ValidationMode) { this: AstCre
val ast = astForPackage(rootNode)
Ast.storeInDiffGraph(ast, diffGraph)
}
})

}
findAndProcess(parserResult.json)
processPackageLevelGolbalVaraiblesAndConstants(parserResult.json)
} catch {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package io.joern.gosrc2cpg.datastructures
import io.joern.gosrc2cpg.astcreation.Defines
import io.joern.x2cpg.datastructures.Global

import java.util.Map
import java.util.concurrent.ConcurrentHashMap
import scala.jdk.CollectionConverters.EnumerationHasAsScala

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,13 @@ class MethodAndTypeCacheBuilderPass(cpgOpt: Option[Cpg], astFiles: List[String],
val allResults: Future[List[(AstCreator, DiffGraphBuilder)]] = Future.sequence(futures)
val results = Await.result(allResults, Duration.Inf)
val (astCreators, diffGraphs) = results.unzip
cpgOpt.map(cpg => {
diffGraphs.foreach(diffGraph => {
cpgOpt.map { cpg =>
diffGraphs.foreach { diffGraph =>
overflowdb.BatchedUpdate
.applyDiff(cpg.graph, diffGraph, null, null)
.transitiveModifications()
})
})
}
}
astCreators
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -55,16 +55,11 @@ class ArraysAndMapTests extends GoCodeToCpgSuite {
|func main() {
|}
|""".stripMargin)
"check LOCAL node" in {
cpg.local("a").size shouldBe 1
val List(x) = cpg.local("a").l
x.typeFullName shouldBe "[]int"
}

"Check IDENTIFIER node" in {
cpg.identifier("a").size shouldBe 1
val List(x) = cpg.identifier("a").l
x.typeFullName shouldBe "[]int"
"check Global member node" in {
val List(x) = cpg.typeDecl("main").l
val List(a) = x.member.l
a.name shouldBe "a"
a.typeFullName shouldBe "[]int"
}
}

Expand All @@ -76,27 +71,23 @@ class ArraysAndMapTests extends GoCodeToCpgSuite {
|}
|""".stripMargin)

"check LOCAL node" in {
cpg.local("a").size shouldBe 1
val List(x) = cpg.local("a").l
x.typeFullName shouldBe "[]int"
}

"Check IDENTIFIER node" in {
cpg.identifier("a").size shouldBe 1
val List(x) = cpg.identifier("a").l
x.typeFullName shouldBe "[]int"
"check Global Member node" in {
val List(x) = cpg.typeDecl("main").l
val List(a) = x.member.l
a.name shouldBe "a"
a.typeFullName shouldBe "[]int"
}

"Check Array initializer CALL node" in {
// TODO need to be handled as part of initializer constructor implementation for package TypeDecl
"Check Array initializer CALL node" ignore {
val List(x) = cpg.call(Operators.arrayInitializer).l
x.typeFullName shouldBe "[]int"
val List(arg1: Literal, arg2: Literal) = x.argument.l: @unchecked
arg1.code shouldBe "1"
arg2.code shouldBe "2"
}

"Check assignment call node" in {
"Check assignment call node" ignore {
val List(assignmentCallNode) = cpg.call(Operators.assignment).l
assignmentCallNode.typeFullName shouldBe "[]int"
val List(arg1: Identifier, arg2: Call) = assignmentCallNode.argument.l: @unchecked
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,24 @@ class GlobalVariableAndConstantTests extends GoCodeToCpgSuite {
|}
|""".stripMargin)

"Check LOCAL node" in {
val List(a, b) = cpg.local.l
"Check package Type Decl" in {
val List(x) = cpg.typeDecl("main").l
x.fullName shouldBe "main"
}

"Traversal from package type decl to global variable member nodes" in {
val List(x) = cpg.typeDecl("main").l
val List(a, b) = x.member.l
a.name shouldBe "FooConst"
a.typeFullName shouldBe "string"
b.name shouldBe "BarVar"
b.typeFullName shouldBe "int"
}

"Be correct for Field Access CALL Node for Global variable access" in {
val List(x) = cpg.call(Operators.fieldAccess).l
x.typeFullName shouldBe "string"
}
}

"Var defined(with type mentioned) in one package used in another package" should {
Expand Down
Loading

0 comments on commit 860f398

Please sign in to comment.