diff --git a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstCreatorHelper.scala b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstCreatorHelper.scala index b84e252d0511..744f8c695c7b 100644 --- a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstCreatorHelper.scala +++ b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstCreatorHelper.scala @@ -1,12 +1,11 @@ package io.joern.rubysrc2cpg.astcreation import io.joern.rubysrc2cpg.astcreation.GlobalTypes.{builtinFunctions, builtinPrefix} -import io.joern.rubysrc2cpg.astcreation.RubyIntermediateAst.RubyNode -import io.joern.rubysrc2cpg.datastructures.{BlockScope, MethodLikeScope, RubyProgramSummary, RubyScope, TypeLikeScope} -import io.joern.x2cpg.datastructures.NamespaceLikeScope -import io.joern.x2cpg.datastructures.Stack.* -import io.joern.x2cpg.{Ast, Defines, ValidationMode} +import io.joern.rubysrc2cpg.astcreation.RubyIntermediateAst.{DummyNode, InstanceFieldIdentifier, MemberAccess, RubyNode} +import io.joern.rubysrc2cpg.datastructures.{BlockScope, FieldDecl} +import io.joern.x2cpg.{Ast, ValidationMode} import io.shiftleft.codepropertygraph.generated.{EdgeTypes, Operators} import io.shiftleft.codepropertygraph.generated.nodes.* +import io.joern.rubysrc2cpg.passes.Defines import io.joern.rubysrc2cpg.passes.Defines.RubyOperators trait AstCreatorHelper(implicit withSchemaValidation: ValidationMode) { this: AstCreator => @@ -28,23 +27,53 @@ trait AstCreatorHelper(implicit withSchemaValidation: ValidationMode) { this: As val name = code(node) val identifier = identifierNode(node, name, name, Defines.Any) val typeRef = scope.tryResolveTypeReference(name) - scope.lookupVariable(name) match { - case None if typeRef.isDefined => - Ast(identifier.typeFullName(typeRef.get.name)) - case None => - val local = localNode(node, name, name, Defines.Any) - scope.addToScope(name, local) match { - 
case BlockScope(block) => diffGraph.addEdge(block, local, EdgeTypes.AST) - case _ => + + node match { + case instanceField: InstanceFieldIdentifier => + scope.findFieldInScope(name) match { + case None => + scope.pushField(FieldDecl(name, Defines.Any, false, false, node)) + astForFieldAccess( + MemberAccess( + DummyNode(identifierNode(instanceField, Defines.This, Defines.This, Defines.Any))( + instanceField.span.spanStart(Defines.This) + ), + ".", + name + )(instanceField.span) + ) + case Some(field) => + val fieldNode = field.node + astForFieldAccess( + MemberAccess( + DummyNode(identifierNode(fieldNode, Defines.This, Defines.This, Defines.Any))( + instanceField.span.spanStart(Defines.This) + ), + ".", + name + )(fieldNode.span) + ) } - Ast(identifier).withRefEdge(identifier, local) - case Some(local) => - local match { - case x: NewLocal => identifier.dynamicTypeHintFullName(x.dynamicTypeHintFullName) - case x: NewMethodParameterIn => identifier.dynamicTypeHintFullName(x.dynamicTypeHintFullName) + case _ => + scope.lookupVariable(name) match { + case None if typeRef.isDefined => + Ast(identifier.typeFullName(typeRef.get.name)) + case None => + val local = localNode(node, name, name, Defines.Any) + scope.addToScope(name, local) match { + case BlockScope(block) => diffGraph.addEdge(block, local, EdgeTypes.AST) + case _ => + } + Ast(identifier).withRefEdge(identifier, local) + case Some(local) => + local match { + case x: NewLocal => identifier.dynamicTypeHintFullName(x.dynamicTypeHintFullName) + case x: NewMethodParameterIn => identifier.dynamicTypeHintFullName(x.dynamicTypeHintFullName) + } + Ast(identifier).withRefEdge(identifier, local) } - Ast(identifier).withRefEdge(identifier, local) } + } protected val UnaryOperatorNames: Map[String, String] = Map( diff --git a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForExpressionsCreator.scala 
b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForExpressionsCreator.scala index ff2c513b0fa1..e02e67f44aee 100644 --- a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForExpressionsCreator.scala +++ b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForExpressionsCreator.scala @@ -1,18 +1,12 @@ package io.joern.rubysrc2cpg.astcreation import io.joern.rubysrc2cpg.astcreation.RubyIntermediateAst.{Unknown, *} -import io.joern.rubysrc2cpg.datastructures.BlockScope -import io.joern.rubysrc2cpg.passes.Defines +import io.joern.rubysrc2cpg.datastructures.{BlockScope, FieldDecl} import io.joern.rubysrc2cpg.passes.Defines.{RubyOperators, getBuiltInType} -import io.joern.x2cpg.{Ast, ValidationMode, Defines as XDefines} +import io.joern.x2cpg.{Ast, Defines as XDefines, ValidationMode} +import io.joern.rubysrc2cpg.passes.Defines import io.shiftleft.codepropertygraph.generated.nodes.* -import io.shiftleft.codepropertygraph.generated.{ - ControlStructureTypes, - DiffGraphBuilder, - DispatchTypes, - Operators, - PropertyNames -} +import io.shiftleft.codepropertygraph.generated.{ControlStructureTypes, DispatchTypes, Operators, PropertyNames} trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) { this: AstCreator => @@ -28,7 +22,7 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) { case node: IndexAccess => astForIndexAccess(node) case node: SingleAssignment => astForSingleAssignment(node) case node: AttributeAssignment => astForAttributeAssignment(node) - case node: SimpleIdentifier => astForSimpleIdentifier(node) + case node: RubyIdentifier => astForSimpleIdentifier(node) case node: SimpleCall => astForSimpleCall(node) case node: RequireCall => astForRequireCall(node) case node: IncludeCall => astForIncludeCall(node) @@ -312,12 +306,12 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) { 
astForMemberCallWithoutBlock(call, memberAccess) } - protected def astForSimpleIdentifier(node: SimpleIdentifier): Ast = { + protected def astForSimpleIdentifier(node: RubyNode with RubyIdentifier): Ast = { val name = code(node) - if (name.startsWith("@")) { + if (name.startsWith("@@")) { logger.warn( - s"Class (@@) and instance (@) variables are not handled as members yet, but are instead handled as simple identifier declarations. Found: $name" + s"Class (@@) are not handled as members yet, but are instead handled as simple identifier declarations. Found: $name" ) } diff --git a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForTypesCreator.scala b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForTypesCreator.scala index bd9e6e3b7905..42d24d33e93e 100644 --- a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForTypesCreator.scala +++ b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForTypesCreator.scala @@ -62,11 +62,12 @@ trait AstForTypesCreator(implicit withSchemaValidation: ValidationMode) { this: node match { case _: ModuleDeclaration => scope.pushNewScope(ModuleScope(classFullName)) - case _: TypeDeclaration => scope.pushNewScope(TypeScope(classFullName)) + case _: TypeDeclaration => scope.pushNewScope(TypeScope(classFullName, List.empty)) } val classBody = node.body.asInstanceOf[StatementList] // for now (bodyStatement is a superset of stmtList) + val classBodyAsts = classBody.statements.flatMap(astsForStatement) match { case bodyAsts if scope.shouldGenerateDefaultConstructor && this.parseLevel == AstParseLevel.FULL_AST => val bodyStart = classBody.span.spanStart() @@ -77,9 +78,19 @@ trait AstForTypesCreator(implicit withSchemaValidation: ValidationMode) { this: methodDecl ++ bodyAsts case bodyAsts => bodyAsts } + + val fieldMemberNodes = node match { + case classDecl: ClassDeclaration => + classDecl.fields.map { x 
=> + val name = code(x) + Ast(memberNode(x, name, name, Defines.Any)) + } + case _ => Seq.empty + } + scope.popScope() - Ast(typeDecl).withChildren(classBodyAsts) + Ast(typeDecl).withChildren(fieldMemberNodes).withChildren(classBodyAsts) } protected def astsForFieldDeclarations(node: FieldsDeclaration): Seq[Ast] = { diff --git a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/RubyIntermediateAst.scala b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/RubyIntermediateAst.scala index dd4da9cd5efd..d387f9273947 100644 --- a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/RubyIntermediateAst.scala +++ b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/RubyIntermediateAst.scala @@ -56,7 +56,12 @@ object RubyIntermediateAst { def baseClass: Option[RubyNode] = None } - final case class ClassDeclaration(name: RubyNode, baseClass: Option[RubyNode], body: RubyNode)(span: TextSpan) + final case class ClassDeclaration( + name: RubyNode, + baseClass: Option[RubyNode], + body: RubyNode, + fields: List[InstanceFieldIdentifier] + )(span: TextSpan) extends RubyNode(span) with TypeDeclaration @@ -127,6 +132,8 @@ object RubyIntermediateAst { */ sealed trait ControlFlowClause + sealed trait RubyIdentifier + final case class RescueExpression( body: RubyNode, rescueClauses: List[RubyNode], @@ -196,8 +203,13 @@ object RubyIntermediateAst { final case class ReturnExpression(expressions: List[RubyNode])(span: TextSpan) extends RubyNode(span) - /** Represents an unqualified identifier e.g. `X`, `x`, `@x`, `@@x`, `$x`, `$<`, etc. */ - final case class SimpleIdentifier(typeFullName: Option[String] = None)(span: TextSpan) extends RubyNode(span) + /** Represents an unqualified identifier e.g. `X`, `x`, `@@x`, `$x`, `$<`, etc. 
*/
+  final case class SimpleIdentifier(typeFullName: Option[String] = None)(span: TextSpan)
+      extends RubyNode(span)
+      with RubyIdentifier
+
+  /** Represents an instance field identifier, e.g. `@x`. */
+  final case class InstanceFieldIdentifier()(span: TextSpan) extends RubyNode(span) with RubyIdentifier
 
   final case class SelfIdentifier()(span: TextSpan) extends RubyNode(span)
 
diff --git a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/datastructures/RubyScope.scala b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/datastructures/RubyScope.scala
index d8cfbbe08a7c..299911bf87d5 100644
--- a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/datastructures/RubyScope.scala
+++ b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/datastructures/RubyScope.scala
@@ -32,6 +32,34 @@ class RubyScope(summary: RubyProgramSummary, projectRoot: Option[String])
     */
   def newProgramScope: Option[ProgramScope] = surroundingScopeFullName.map(ProgramScope.apply)
 
+  /** Registers `field` on the closest enclosing [[TypeScope]].
+    *
+    * The type scope's stack entry is rewritten in place instead of popping and re-pushing every scope above it:
+    * `popScope` only hands back the scope node, so a pop/`pushNewScope` round trip drops the variables already
+    * registered in the intermediate scopes and re-runs the side effects of `pushNewScope`. If no type scope is on the
+    * stack, the stack is left untouched and the field is not recorded.
+    */
+  def pushField(field: FieldDecl): Unit = {
+    // NOTE(review): assumes the head of `stack` is the innermost scope and that `stack` is reassignable here.
+    val (inner, typeScopeAndOuter) = stack.span {
+      case ScopeElement(_: TypeScope, _) => false
+      case _                             => true
+    }
+    typeScopeAndOuter match {
+      case ScopeElement(typeScope: TypeScope, variables) :: outer =>
+        val updatedScope: TypedScopeElement = typeScope.copy(fields = typeScope.fields :+ field)
+        stack = inner ++ (ScopeElement(updatedScope, variables) :: outer)
+      case _ => // no enclosing type scope: nothing to attach the field to
+    }
+  }
+
+  /** All fields registered on the type scopes currently on the scope stack. */
+  def getFieldsInScope: List[FieldDecl] =
+    stack.collect { case ScopeElement(TypeScope(_, fields), _) => fields }.flatten
+
+  /** The first field named `fieldName` among [[getFieldsInScope]], if any. */
+  def findFieldInScope(fieldName: String): Option[FieldDecl] = {
+    getFieldsInScope.find(_.name == fieldName)
+  }
+
   override def pushNewScope(scopeNode: TypedScopeElement): Unit = {
     // Use the summary to determine if there is a constructor present
     val mappedScopeNode = scopeNode match {
@@ -41,6 +69,9 @@ class RubyScope(summary: RubyProgramSummary, projectRoot: Option[String])
       case n: ProgramScope =>
         typesInScope.addAll(summary.typesUnderNamespace(n.fullName))
         n
+      case TypeScope(name, _) =>
+        typesInScope.addAll(summary.matchingTypes(name))
+        scopeNode
       case _ => scopeNode
     }
 
diff --git
a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/datastructures/ScopeElement.scala b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/datastructures/ScopeElement.scala index d1229e1882a3..182b95a617d7 100644 --- a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/datastructures/ScopeElement.scala +++ b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/datastructures/ScopeElement.scala @@ -1,5 +1,6 @@ package io.joern.rubysrc2cpg.datastructures +import io.joern.rubysrc2cpg.astcreation.RubyIntermediateAst.RubyNode import io.joern.rubysrc2cpg.passes.Defines import io.joern.x2cpg.datastructures.{NamespaceLikeScope, TypedScopeElement} import io.shiftleft.codepropertygraph.generated.nodes.NewBlock @@ -10,6 +11,9 @@ import io.shiftleft.codepropertygraph.generated.nodes.NewBlock */ case class NamespaceScope(fullName: String) extends NamespaceLikeScope +case class FieldDecl(name: String, typeFullName: String, isStatic: Boolean, isInitialized: Boolean, node: RubyNode) + extends TypedScopeElement + /** A type-like scope with a full name. */ trait TypeLikeScope extends TypedScopeElement { @@ -40,7 +44,7 @@ case class ModuleScope(fullName: String) extends TypeLikeScope * @param fullName * the type full name. */ -case class TypeScope(fullName: String) extends TypeLikeScope +case class TypeScope(fullName: String, fields: List[FieldDecl]) extends TypeLikeScope /** Represents scope objects that map to a method node. 
*/
diff --git a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/parser/RubyNodeCreator.scala b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/parser/RubyNodeCreator.scala
index e53302411717..03b380bf06bd 100644
--- a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/parser/RubyNodeCreator.scala
+++ b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/parser/RubyNodeCreator.scala
@@ -535,7 +535,7 @@ class RubyNodeCreator extends RubyParserBaseVisitor[RubyNode] {
   }
 
   override def visitInstanceIdentifierVariable(ctx: RubyParser.InstanceIdentifierVariableContext): RubyNode = {
-    SimpleIdentifier()(ctx.toTextSpan)
+    InstanceFieldIdentifier()(ctx.toTextSpan)
   }
 
   override def visitLocalIdentifierVariable(ctx: RubyParser.LocalIdentifierVariableContext): RubyNode = {
@@ -693,12 +693,85 @@ class RubyNodeCreator extends RubyParserBaseVisitor[RubyNode] {
     )(ctx.toTextSpan)
   }
 
+  /** Collects the instance fields (e.g. `@x`) that are the direct target of an assignment found at the top level of
+    * the given methods' bodies; assignments nested in other statements are not inspected. A method whose body is not
+    * a [[StatementList]] contributes no fields.
+    */
+  private def findInstanceFieldsInMethodDecls(methodDecls: List[MethodDeclaration]): List[InstanceFieldIdentifier] = {
+    methodDecls
+      .flatMap { methodDecl =>
+        methodDecl.body match {
+          case body: StatementList => body.statements.collect { case assignment: SingleAssignment => assignment.lhs }
+          case _                   => Nil
+        }
+      }
+      .collect { case field: InstanceFieldIdentifier => field }
+  }
+
+  /** Lowers the class body and makes sure every instance field of the class is initialised to `nil` in the class'
+    * `initialize` method. Fields are the bare `@x` statements of the class body plus the fields reported by
+    * `findInstanceFieldsInMethodDecls`. An `initialize` method is synthesised if the class does not define one.
+    *
+    * @return
+    *   the (possibly rewritten) class body and the instance fields found; a body that is not a [[StatementList]] is
+    *   returned as-is with no fields.
+    */
+  private def genInitFieldStmts(
+    ctxBodyStatement: RubyParser.BodyStatementContext
+  ): (RubyNode, List[InstanceFieldIdentifier]) = {
+    lowerSingletonClassDeclarations(ctxBodyStatement) match {
+      case stmtList: StatementList =>
+        val declaredFields = stmtList.statements.collect { case field: InstanceFieldIdentifier => field }
+        val methodDecls    = stmtList.statements.collect { case method: MethodDeclaration => method }
+
+        // The same field may show up several times (declared and assigned, or assigned in several methods); keep the
+        // first occurrence only so that a single member and a single `nil` assignment is generated per field.
+        val combinedInstanceFields =
+          (declaredFields ++ findInstanceFieldsInMethodDecls(methodDecls)).distinctBy(_.span.text)
+
+        val initStmtListStatements = combinedInstanceFields.map { x =>
+          SingleAssignment(x, "=", StaticLiteral(getBuiltInType(Defines.NilClass))(x.span.spanStart("nil")))(
+            x.span.spanStart(s"${x.span.text} = nil")
+          )
+        }
+
+        val updatedStatements = methodDecls.find(_.methodName == "initialize") match {
+          case Some(initMethod) =>
+            initMethod.body match {
+              // TODO: Filter out instance fields that are assigned an initial value in the constructor method. Current
+              //  implementation leads to a "double" assignment (`@x = nil; @x = ...;`) for such fields.
+              case initBody: StatementList =>
+                val updatedInitMethod =
+                  initMethod.copy(body = StatementList(initStmtListStatements ++ initBody.statements)(initBody.span))(
+                    initMethod.span
+                  )
+                // Swap the constructor in place; all other class members must stay part of the class body.
+                stmtList.statements.map {
+                  case stmt if stmt eq initMethod => updatedInitMethod
+                  case stmt                       => stmt
+                }
+              // Nothing to prepend to: keep the class body unchanged.
+              case _ => stmtList.statements
+            }
+          case None =>
+            val newInitMethod =
+              MethodDeclaration("initialize", List.empty, StatementList(initStmtListStatements)(stmtList.span))(
+                stmtList.span
+              )
+            newInitMethod +: stmtList.statements
+        }
+
+        (StatementList(updatedStatements)(stmtList.span), combinedInstanceFields)
+      case decls => (decls, List.empty)
+    }
+  }
+
   override def visitClassDefinition(ctx: RubyParser.ClassDefinitionContext): RubyNode = {
-    ClassDeclaration(
-      visit(ctx.classPath()),
-      Option(ctx.commandOrPrimaryValue()).map(visit),
-      lowerSingletonClassDeclarations(ctx.bodyStatement())
-    )(ctx.toTextSpan)
+    val (stmts, fields) = genInitFieldStmts(ctx.bodyStatement())
+
+    ClassDeclaration(visit(ctx.classPath()), Option(ctx.commandOrPrimaryValue()).map(visit), stmts, fields)(
+      ctx.toTextSpan
+    )
   }
 
   /** Lowers all MethodDeclaration found in SingletonClassDeclaration to SingletonMethodDeclaration.
diff --git a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/passes/Defines.scala b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/passes/Defines.scala index d681ba363a45..e97935881e6d 100644 --- a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/passes/Defines.scala +++ b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/passes/Defines.scala @@ -21,6 +21,7 @@ object Defines { val Regexp: String = "Regexp" val Lambda: String = "lambda" val Proc: String = "proc" + val This: String = "this" val Program: String = ":program" diff --git a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/ClassTests.scala b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/ClassTests.scala index c1a3c27749ed..e00bc5c60d46 100644 --- a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/ClassTests.scala +++ b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/ClassTests.scala @@ -2,8 +2,10 @@ package io.joern.rubysrc2cpg.querying import io.joern.rubysrc2cpg.testfixtures.RubyCode2CpgFixture import io.joern.x2cpg.Defines -import io.shiftleft.codepropertygraph.generated.nodes.{Block, Call, Identifier, Literal, Return} +import io.shiftleft.codepropertygraph.generated.Operators +import io.shiftleft.codepropertygraph.generated.nodes.{Block, Call, FieldIdentifier, Identifier, Literal, Return} import io.shiftleft.semanticcpg.language.* +import io.joern.rubysrc2cpg.passes.Defines as RubyDefines class ClassTests extends RubyCode2CpgFixture { @@ -389,4 +391,81 @@ class ClassTests extends RubyCode2CpgFixture { } } } + + "Instance variables in a class and method defs" should { + val cpg = code(""" + |class Foo + | @a + | + | def foo + | @b = 10 + | end + | + | def foobar + | @c = 20 + | @d = 40 + | end + | + | def barfoo + | puts @a + | puts @c + | @o = "a" + | end + |end + |""".stripMargin) + + "create respective member 
nodes" in { + inside(cpg.typeDecl.name("Foo").l) { + case fooType :: Nil => + inside(fooType.member.l) { + case aMember :: bMember :: cMember :: dMember :: oMember :: Nil => + // Test that all members in class are present + aMember.code shouldBe "@a" + bMember.code shouldBe "@b" + cMember.code shouldBe "@c" + dMember.code shouldBe "@d" + oMember.code shouldBe "@o" + case _ => fail("Expected 5 members") + } + case xs => fail(s"Expected TypeDecl for Foo, instead got ${xs.name.mkString(", ")}") + } + } + + "create nil assignments under the class initializer" in { + inside(cpg.typeDecl.name("Foo").l) { + case fooType :: Nil => + inside(fooType.method.name(Defines.ConstructorMethodName).l) { + case clinitMethod :: Nil => + inside(clinitMethod.block.astChildren.isCall.name(Operators.assignment).l) { + case aAssignment :: bAssignment :: cAssignment :: dAssignment :: oAssignment :: Nil => + aAssignment.code shouldBe "@a = nil" + + bAssignment.code shouldBe "@b = nil" + cAssignment.code shouldBe "@c = nil" + dAssignment.code shouldBe "@d = nil" + oAssignment.code shouldBe "@o = nil" + + inside(aAssignment.argument.l) { + case (lhs: Call) :: (rhs: Literal) :: Nil => + lhs.code shouldBe "this.@a" + lhs.methodFullName shouldBe Operators.fieldAccess + + inside(lhs.argument.l) { + case (identifier: Identifier) :: (fieldIdentifier: FieldIdentifier) :: Nil => + identifier.code shouldBe RubyDefines.This + fieldIdentifier.code shouldBe "@a" + case _ => fail("Expected identifier and fieldIdentifier for fieldAccess") + } + + rhs.code shouldBe "nil" + case _ => fail("Expected only LHS and RHS for assignment call") + } + case _ => fail("") + } + case xs => fail(s"Expected one method for clinit, instead got ${xs.name.mkString(", ")}") + } + case xs => fail(s"Expected TypeDecl for Foo, instead got ${xs.name.mkString(", ")}") + } + } + } }