Skip to content

Commit

Permalink
[swiftsrc2cpg] Added support for Enum, Protocol, and Structure (#4023)
Browse files Browse the repository at this point in the history
  • Loading branch information
max-leuthaeuser authored Jan 5, 2024
1 parent a5f3198 commit ba1e065
Show file tree
Hide file tree
Showing 5 changed files with 739 additions and 45 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,15 @@ trait AstForDeclSyntaxCreator(implicit withSchemaValidation: ValidationMode) {
}

private def declMembers(
decl: ClassDeclSyntax | ExtensionDeclSyntax,
decl: ClassDeclSyntax | ExtensionDeclSyntax | ProtocolDeclSyntax | StructDeclSyntax | EnumDeclSyntax,
withConstructor: Boolean = true
): Seq[DeclSyntax] = {
val memberBlock = decl match {
case c: ClassDeclSyntax => c.memberBlock
case e: ExtensionDeclSyntax => e.memberBlock
case p: ProtocolDeclSyntax => p.memberBlock
case s: StructDeclSyntax => s.memberBlock
case e: EnumDeclSyntax => e.memberBlock
}
val allMembers = memberBlock.members.children.map(_.decl)
if (withConstructor) {
Expand All @@ -48,6 +51,7 @@ trait AstForDeclSyntaxCreator(implicit withSchemaValidation: ValidationMode) {

private def isInitializedMember(node: DeclSyntax): Boolean = node match {
case v: VariableDeclSyntax => v.bindings.children.exists(c => c.initializer.isDefined || c.accessorBlock.isDefined)
case e: EnumCaseDeclSyntax => e.elements.children.exists(c => c.rawValue.isDefined)
case _ => false
}

Expand Down Expand Up @@ -96,7 +100,7 @@ trait AstForDeclSyntaxCreator(implicit withSchemaValidation: ValidationMode) {
}

private def createFakeConstructor(
node: ClassDeclSyntax | ExtensionDeclSyntax,
node: ClassDeclSyntax | ExtensionDeclSyntax | ProtocolDeclSyntax | StructDeclSyntax | EnumDeclSyntax,
methodBlockContent: List[Ast] = List.empty
): AstAndMethod = {
val constructorName = io.joern.x2cpg.Defines.ConstructorMethodName
Expand All @@ -107,10 +111,7 @@ trait AstForDeclSyntaxCreator(implicit withSchemaValidation: ValidationMode) {

methodAstParentStack.push(methodNode_)

val name = node match {
case c: ClassDeclSyntax => code(c.name)
case e: ExtensionDeclSyntax => code(e.extendedType)
}
val name = typeNameForDeclSyntax(node)
val returnType = calcTypeNameAndFullName(name)._2
val methodReturnNode =
newMethodReturnNode(returnType, dynamicTypeHintFullName = None, line = line(node), column = column(node))
Expand All @@ -134,9 +135,9 @@ trait AstForDeclSyntaxCreator(implicit withSchemaValidation: ValidationMode) {
AstAndMethod(Ast(), methodNode_, bAst)
}

private def astForClassMember(classElement: DeclSyntax, typeDeclNode: NewTypeDecl): Ast = {
val typeFullName = typeNameForDeclSyntax(classElement)
classElement match {
private def astForDeclMember(node: DeclSyntax, typeDeclNode: NewTypeDecl): Ast = {
val typeFullName = typeNameForDeclSyntax(node)
node match {
case d: (AccessorDeclSyntax | InitializerDeclSyntax | DeinitializerDeclSyntax | FunctionDeclSyntax) =>
val function = astForFunctionLike(d).method
val bindingNode = newBindingNode("", "", "")
Expand All @@ -149,12 +150,18 @@ trait AstForDeclSyntaxCreator(implicit withSchemaValidation: ValidationMode) {
ImportDeclSyntax | ProtocolDeclSyntax | StructDeclSyntax | MacroDeclSyntax | MacroExpansionDeclSyntax |
OperatorDeclSyntax | PoundSourceLocationSyntax | PrecedenceGroupDeclSyntax | SubscriptDeclSyntax |
TypeAliasDeclSyntax | IfConfigDeclSyntax) =>
val ast = astForNode(classElement)
val ast = astForNode(node)
Ast.storeInDiffGraph(ast, diffGraph)
ast.root.foreach(r => diffGraph.addEdge(typeDeclNode, r, EdgeTypes.AST))
Ast()
case d: EnumCaseDeclSyntax =>
notHandledYet(d)
val ast = astForNode(d)
d.elements.children.foreach { c =>
val cCode = code(c.name)
val memberNode_ = memberNode(c, cCode, cCode, typeFullName)
diffGraph.addEdge(typeDeclNode, memberNode_, EdgeTypes.AST)
}
ast
case d: VariableDeclSyntax =>
val ast = astForNode(d)
d.bindings.children.foreach { c =>
Expand All @@ -167,11 +174,13 @@ trait AstForDeclSyntaxCreator(implicit withSchemaValidation: ValidationMode) {
}
}

private def findDeclConstructor(decl: ClassDeclSyntax | ExtensionDeclSyntax): Option[DeclSyntax] =
private def findDeclConstructor(
decl: ClassDeclSyntax | ExtensionDeclSyntax | ProtocolDeclSyntax | StructDeclSyntax | EnumDeclSyntax
): Option[DeclSyntax] =
declMembers(decl).find(isConstructor)

private def createDeclConstructor(
node: ClassDeclSyntax | ExtensionDeclSyntax,
node: ClassDeclSyntax | ExtensionDeclSyntax | ProtocolDeclSyntax | StructDeclSyntax | EnumDeclSyntax,
constructorContent: List[Ast],
constructorBlock: Ast = Ast()
): Option[AstAndMethod] =
Expand Down Expand Up @@ -199,15 +208,30 @@ trait AstForDeclSyntaxCreator(implicit withSchemaValidation: ValidationMode) {
node.isInstanceOf[FunctionDeclSyntax] ||
!isInitializedMember(node)

private def astForClassDeclSyntax(node: ClassDeclSyntax): Ast = {
private def astForDeclAttributes(
node: ClassDeclSyntax | ProtocolDeclSyntax | VariableDeclSyntax | StructDeclSyntax | EnumDeclSyntax
): Seq[Ast] = {
node match {
case c: ClassDeclSyntax => c.attributes.children.map(astForNode)
case p: ProtocolDeclSyntax => p.attributes.children.map(astForNode)
case v: VariableDeclSyntax => v.attributes.children.map(astForNode)
case s: StructDeclSyntax => s.attributes.children.map(astForNode)
case e: EnumDeclSyntax => e.attributes.children.map(astForNode)
}
}

private def astForTypeDeclSyntax(
node: ClassDeclSyntax | ProtocolDeclSyntax | StructDeclSyntax | EnumDeclSyntax
): Ast = {
// TODO:
// - handle genericParameterClause
// - handle genericWhereClause
val attributes = node.attributes.children.map(astForNode)
val attributes = astForDeclAttributes(node)
val modifiers = modifiersForDecl(node)
val inherits = inheritsFrom(node)

val (typeName, typeFullName) = calcTypeNameAndFullName(code(node.name))
val name = typeNameForDeclSyntax(node)
val (typeName, typeFullName) = calcTypeNameAndFullName(name)
val existingTypeDecl = global.seenTypeDecls.keys().asScala.find(_.name == typeName)

if (existingTypeDecl.isEmpty) {
Expand Down Expand Up @@ -252,18 +276,18 @@ trait AstForDeclSyntaxCreator(implicit withSchemaValidation: ValidationMode) {
// adding all other members and retrieving their initialization calls
val memberInitCalls = allClassMembers
.filter(m => !isStaticMember(m) && isInitializedMember(m))
.map(m => astForClassMember(m, typeDeclNode_))
.map(m => astForDeclMember(m, typeDeclNode_))

val constructor = createDeclConstructor(node, memberInitCalls)

// adding all class methods / functions and uninitialized, non-static members
allClassMembers
.filter(member => isClassMethodOrUninitializedMember(member) && !isStaticMember(member))
.foreach(m => astForClassMember(m, typeDeclNode_))
.foreach(m => astForDeclMember(m, typeDeclNode_))

// adding all static members and retrieving their initialization calls
val staticMemberInitCalls =
allClassMembers.filter(isStaticMember).map(m => astForClassMember(m, typeDeclNode_))
allClassMembers.filter(isStaticMember).map(m => astForDeclMember(m, typeDeclNode_))

methodAstParentStack.pop()
dynamicInstanceTypeStack.pop()
Expand Down Expand Up @@ -307,18 +331,18 @@ trait AstForDeclSyntaxCreator(implicit withSchemaValidation: ValidationMode) {
// adding all other members and retrieving their initialization calls
val memberInitCalls = allClassMembers
.filter(m => !isStaticMember(m) && isInitializedMember(m))
.map(m => astForClassMember(m, typeDeclNode_))
.map(m => astForDeclMember(m, typeDeclNode_))

createDeclConstructor(node, memberInitCalls, constructorBlock)

// adding all class methods / functions and uninitialized, non-static members
allClassMembers
.filter(member => isClassMethodOrUninitializedMember(member) && !isStaticMember(member))
.foreach(m => astForClassMember(m, typeDeclNode_))
.foreach(m => astForDeclMember(m, typeDeclNode_))

// adding all static members and retrieving their initialization calls
val staticMemberInitCalls =
allClassMembers.filter(isStaticMember).map(m => astForClassMember(m, typeDeclNode_))
allClassMembers.filter(isStaticMember).map(m => astForDeclMember(m, typeDeclNode_))

methodAstParentStack.pop()
dynamicInstanceTypeStack.pop()
Expand Down Expand Up @@ -349,13 +373,58 @@ trait AstForDeclSyntaxCreator(implicit withSchemaValidation: ValidationMode) {

private def astForDeinitializerDeclSyntax(node: DeinitializerDeclSyntax): Ast = notHandledYet(node)
private def astForEditorPlaceholderDeclSyntax(node: EditorPlaceholderDeclSyntax): Ast = notHandledYet(node)
private def astForEnumCaseDeclSyntax(node: EnumCaseDeclSyntax): Ast = notHandledYet(node)
private def astForEnumDeclSyntax(node: EnumDeclSyntax): Ast = notHandledYet(node)

private def inheritsFrom(node: ClassDeclSyntax | ExtensionDeclSyntax): Seq[String] = {
private def astForEnumCaseDeclSyntax(node: EnumCaseDeclSyntax): Ast = {
val attributeAsts = node.attributes.children.map(astForNode)
val modifiers = modifiersForDecl(node)
val scopeType = BlockScope

val bindingAsts = node.elements.children.map { binding =>
val name = code(binding.name)
val nLocalNode = localNode(binding, name, name, Defines.Any).order(0)
scope.addVariable(name, nLocalNode, scopeType)
diffGraph.addEdge(localAstParentStack.head, nLocalNode, EdgeTypes.AST)

val initAsts = binding.rawValue.map(astForNode).toList
if (initAsts.isEmpty) {
Ast()
} else {
val patternAst = astForNode(binding.name)
modifiers.foreach { mod =>
patternAst.root.foreach { r => diffGraph.addEdge(r, mod, EdgeTypes.AST) }
}
attributeAsts.foreach { attrAst =>
patternAst.root.foreach { r => attrAst.root.foreach { attr => diffGraph.addEdge(r, attr, EdgeTypes.AST) } }
}
createAssignmentCallAst(
patternAst,
initAsts.head,
code(binding).stripSuffix(","),
line = line(binding),
column = column(binding)
)
}
}

bindingAsts match {
case Nil => Ast()
case head :: Nil => head
case _ =>
val block = blockNode(node, code(node), Defines.Any)
setArgumentIndices(bindingAsts)
blockAst(block, bindingAsts.toList)
}
}

private def inheritsFrom(
node: ClassDeclSyntax | ExtensionDeclSyntax | ProtocolDeclSyntax | StructDeclSyntax | EnumDeclSyntax
): Seq[String] = {
val clause = node match {
case c: ClassDeclSyntax => c.inheritanceClause
case e: ExtensionDeclSyntax => e.inheritanceClause
case p: ProtocolDeclSyntax => p.inheritanceClause
case s: StructDeclSyntax => s.inheritanceClause
case e: EnumDeclSyntax => e.inheritanceClause
}
val inheritsFrom = clause match {
case Some(value) => value.inheritedTypes.children.map(c => code(c.`type`))
Expand All @@ -376,7 +445,8 @@ trait AstForDeclSyntaxCreator(implicit withSchemaValidation: ValidationMode) {
val modifiers = modifiersForDecl(node)
val inherits = inheritsFrom(node)

val (typeName, typeFullName) = calcTypeNameAndFullName(code(node.extendedType))
val name = typeNameForDeclSyntax(node)
val (typeName, typeFullName) = calcTypeNameAndFullName(name)
val existingTypeDecl = global.seenTypeDecls.keys().asScala.find(_.name == typeName)

if (existingTypeDecl.isEmpty) {
Expand Down Expand Up @@ -421,18 +491,18 @@ trait AstForDeclSyntaxCreator(implicit withSchemaValidation: ValidationMode) {
// adding all other members and retrieving their initialization calls
val memberInitCalls = allClassMembers
.filter(m => !isStaticMember(m) && isInitializedMember(m))
.map(m => astForClassMember(m, typeDeclNode_))
.map(m => astForDeclMember(m, typeDeclNode_))

val constructor = createDeclConstructor(node, memberInitCalls)

// adding all class methods / functions and uninitialized, non-static members
allClassMembers
.filter(member => isClassMethodOrUninitializedMember(member) && !isStaticMember(member))
.foreach(m => astForClassMember(m, typeDeclNode_))
.foreach(m => astForDeclMember(m, typeDeclNode_))

// adding all static members and retrieving their initialization calls
val staticMemberInitCalls =
allClassMembers.filter(isStaticMember).map(m => astForClassMember(m, typeDeclNode_))
allClassMembers.filter(isStaticMember).map(m => astForDeclMember(m, typeDeclNode_))

methodAstParentStack.pop()
dynamicInstanceTypeStack.pop()
Expand Down Expand Up @@ -480,18 +550,18 @@ trait AstForDeclSyntaxCreator(implicit withSchemaValidation: ValidationMode) {
// adding all other members and retrieving their initialization calls
val memberInitCalls = allClassMembers
.filter(m => !isStaticMember(m) && isInitializedMember(m))
.map(m => astForClassMember(m, typeDeclNode_))
.map(m => astForDeclMember(m, typeDeclNode_))

createDeclConstructor(node, memberInitCalls, constructorBlock)

// adding all class methods / functions and uninitialized, non-static members
allClassMembers
.filter(member => isClassMethodOrUninitializedMember(member) && !isStaticMember(member))
.foreach(m => astForClassMember(m, typeDeclNode_))
.foreach(m => astForDeclMember(m, typeDeclNode_))

// adding all static members and retrieving their initialization calls
val staticMemberInitCalls =
allClassMembers.filter(isStaticMember).map(m => astForClassMember(m, typeDeclNode_))
allClassMembers.filter(isStaticMember).map(m => astForDeclMember(m, typeDeclNode_))

methodAstParentStack.pop()
dynamicInstanceTypeStack.pop()
Expand Down Expand Up @@ -520,10 +590,17 @@ trait AstForDeclSyntaxCreator(implicit withSchemaValidation: ValidationMode) {
}
}

private def modifiersForDecl(node: ClassDeclSyntax | ExtensionDeclSyntax): Seq[NewModifier] = {
private def modifiersForDecl(
node: ClassDeclSyntax | ExtensionDeclSyntax | ProtocolDeclSyntax | StructDeclSyntax | EnumDeclSyntax |
EnumCaseDeclSyntax
): Seq[NewModifier] = {
val modifierList = node match {
case c: ClassDeclSyntax => c.modifiers.children
case e: ExtensionDeclSyntax => e.modifiers.children
case p: ProtocolDeclSyntax => p.modifiers.children
case s: StructDeclSyntax => s.modifiers.children
case e: EnumDeclSyntax => e.modifiers.children
case ec: EnumCaseDeclSyntax => ec.modifiers.children
}
val modifiers = modifierList.flatMap(c => astForNode(c).root.map(_.asInstanceOf[NewModifier]))
val allModifier = if (modifiers.isEmpty) {
Expand Down Expand Up @@ -699,12 +776,9 @@ trait AstForDeclSyntaxCreator(implicit withSchemaValidation: ValidationMode) {
private def astForInitializerDeclSyntax(node: InitializerDeclSyntax): Ast = notHandledYet(node)
private def astForMacroDeclSyntax(node: MacroDeclSyntax): Ast = notHandledYet(node)
private def astForMacroExpansionDeclSyntax(node: MacroExpansionDeclSyntax): Ast = notHandledYet(node)
private def astForMissingDeclSyntax(node: MissingDeclSyntax): Ast = notHandledYet(node)
private def astForOperatorDeclSyntax(node: OperatorDeclSyntax): Ast = notHandledYet(node)
private def astForPoundSourceLocationSyntax(node: PoundSourceLocationSyntax): Ast = notHandledYet(node)
private def astForPrecedenceGroupDeclSyntax(node: PrecedenceGroupDeclSyntax): Ast = notHandledYet(node)
private def astForProtocolDeclSyntax(node: ProtocolDeclSyntax): Ast = notHandledYet(node)
private def astForStructDeclSyntax(node: StructDeclSyntax): Ast = notHandledYet(node)
private def astForSubscriptDeclSyntax(node: SubscriptDeclSyntax): Ast = notHandledYet(node)
private def astForTypeAliasDeclSyntax(node: TypeAliasDeclSyntax): Ast = notHandledYet(node)

Expand Down Expand Up @@ -781,24 +855,24 @@ trait AstForDeclSyntaxCreator(implicit withSchemaValidation: ValidationMode) {
case node: AccessorDeclSyntax => astForAccessorDeclSyntax(node)
case node: ActorDeclSyntax => astForActorDeclSyntax(node)
case node: AssociatedTypeDeclSyntax => astForAssociatedTypeDeclSyntax(node)
case node: ClassDeclSyntax => astForClassDeclSyntax(node)
case node: ClassDeclSyntax => astForTypeDeclSyntax(node)
case node: DeinitializerDeclSyntax => astForDeinitializerDeclSyntax(node)
case node: EditorPlaceholderDeclSyntax => astForEditorPlaceholderDeclSyntax(node)
case node: EnumCaseDeclSyntax => astForEnumCaseDeclSyntax(node)
case node: EnumDeclSyntax => astForEnumDeclSyntax(node)
case node: EnumDeclSyntax => astForTypeDeclSyntax(node)
case node: ExtensionDeclSyntax => astForExtensionDeclSyntax(node)
case node: FunctionDeclSyntax => astForFunctionDeclSyntax(node)
case node: IfConfigDeclSyntax => astForIfConfigDeclSyntax(node)
case node: ImportDeclSyntax => astForImportDeclSyntax(node)
case node: InitializerDeclSyntax => astForInitializerDeclSyntax(node)
case node: MacroDeclSyntax => astForMacroDeclSyntax(node)
case node: MacroExpansionDeclSyntax => astForMacroExpansionDeclSyntax(node)
case node: MissingDeclSyntax => astForMissingDeclSyntax(node)
case _: MissingDeclSyntax => Ast()
case node: OperatorDeclSyntax => astForOperatorDeclSyntax(node)
case node: PoundSourceLocationSyntax => astForPoundSourceLocationSyntax(node)
case node: PrecedenceGroupDeclSyntax => astForPrecedenceGroupDeclSyntax(node)
case node: ProtocolDeclSyntax => astForProtocolDeclSyntax(node)
case node: StructDeclSyntax => astForStructDeclSyntax(node)
case node: ProtocolDeclSyntax => astForTypeDeclSyntax(node)
case node: StructDeclSyntax => astForTypeDeclSyntax(node)
case node: SubscriptDeclSyntax => astForSubscriptDeclSyntax(node)
case node: TypeAliasDeclSyntax => astForTypeAliasDeclSyntax(node)
case node: VariableDeclSyntax => astForVariableDeclSyntax(node)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,17 @@ import io.shiftleft.codepropertygraph.generated.*
import io.shiftleft.codepropertygraph.generated.nodes.*
import io.shiftleft.semanticcpg.language.*

class ExtensionTests extends AbstractPassTest {
class ClassExtensionTests extends AbstractPassTest {

"ExtensionTests" should {
"ClassExtensionTests" should {

"test Class and Extension defined afterwards" in AstFixture("""
|public class A {}
|private class B {
| var b = 0.0
|}
|
|class Foo: Bar { // implicitly internal
|class Foo: Bar { // implicitly internal (private)
| public var a = A()
| private var b = false
| var c = 0.0
Expand Down Expand Up @@ -113,7 +113,7 @@ class ExtensionTests extends AbstractPassTest {
| func someOtherFunc() {}
|}
|
|class Foo: Bar { // implicitly internal
|class Foo: Bar { // implicitly internal (private)
| public var a = A()
| private var b = false
| var c = 0.0
Expand Down
Loading

0 comments on commit ba1e065

Please sign in to comment.