Skip to content

Commit

Permalink
[ruby] Initial handling of Instance Variables (Dynamic Fields) (#4340)
Browse files Browse the repository at this point in the history
 * Instance variables (`@var`).
 * The general structure for instance variables

TODO:
 * Handle instance variables that are assigned in methods whose body is not a `StatementList`.
 * Filter out instance variables that are already assigned a value in the constructor of the class. The current implementation does not account for this, so the constructor ends up with a double assignment (`<instanceField> = nil` followed by the explicit assignment).
  • Loading branch information
AndreiDreyer authored Mar 15, 2024
1 parent f3e60c6 commit 81b91eb
Show file tree
Hide file tree
Showing 9 changed files with 263 additions and 46 deletions.
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
package io.joern.rubysrc2cpg.astcreation
import io.joern.rubysrc2cpg.astcreation.GlobalTypes.{builtinFunctions, builtinPrefix}
import io.joern.rubysrc2cpg.astcreation.RubyIntermediateAst.RubyNode
import io.joern.rubysrc2cpg.datastructures.{BlockScope, MethodLikeScope, RubyProgramSummary, RubyScope, TypeLikeScope}
import io.joern.x2cpg.datastructures.NamespaceLikeScope
import io.joern.x2cpg.datastructures.Stack.*
import io.joern.x2cpg.{Ast, Defines, ValidationMode}
import io.joern.rubysrc2cpg.astcreation.RubyIntermediateAst.{DummyNode, InstanceFieldIdentifier, MemberAccess, RubyNode}
import io.joern.rubysrc2cpg.datastructures.{BlockScope, FieldDecl}
import io.joern.x2cpg.{Ast, ValidationMode}
import io.shiftleft.codepropertygraph.generated.{EdgeTypes, Operators}
import io.shiftleft.codepropertygraph.generated.nodes.*
import io.joern.rubysrc2cpg.passes.Defines
import io.joern.rubysrc2cpg.passes.Defines.RubyOperators

trait AstCreatorHelper(implicit withSchemaValidation: ValidationMode) { this: AstCreator =>
Expand All @@ -28,23 +27,53 @@ trait AstCreatorHelper(implicit withSchemaValidation: ValidationMode) { this: As
val name = code(node)
val identifier = identifierNode(node, name, name, Defines.Any)
val typeRef = scope.tryResolveTypeReference(name)
scope.lookupVariable(name) match {
case None if typeRef.isDefined =>
Ast(identifier.typeFullName(typeRef.get.name))
case None =>
val local = localNode(node, name, name, Defines.Any)
scope.addToScope(name, local) match {
case BlockScope(block) => diffGraph.addEdge(block, local, EdgeTypes.AST)
case _ =>

node match {
case instanceField: InstanceFieldIdentifier =>
scope.findFieldInScope(name) match {
case None =>
scope.pushField(FieldDecl(name, Defines.Any, false, false, node))
astForFieldAccess(
MemberAccess(
DummyNode(identifierNode(instanceField, Defines.This, Defines.This, Defines.Any))(
instanceField.span.spanStart(Defines.This)
),
".",
name
)(instanceField.span)
)
case Some(field) =>
val fieldNode = field.node
astForFieldAccess(
MemberAccess(
DummyNode(identifierNode(fieldNode, Defines.This, Defines.This, Defines.Any))(
instanceField.span.spanStart(Defines.This)
),
".",
name
)(fieldNode.span)
)
}
Ast(identifier).withRefEdge(identifier, local)
case Some(local) =>
local match {
case x: NewLocal => identifier.dynamicTypeHintFullName(x.dynamicTypeHintFullName)
case x: NewMethodParameterIn => identifier.dynamicTypeHintFullName(x.dynamicTypeHintFullName)
case _ =>
scope.lookupVariable(name) match {
case None if typeRef.isDefined =>
Ast(identifier.typeFullName(typeRef.get.name))
case None =>
val local = localNode(node, name, name, Defines.Any)
scope.addToScope(name, local) match {
case BlockScope(block) => diffGraph.addEdge(block, local, EdgeTypes.AST)
case _ =>
}
Ast(identifier).withRefEdge(identifier, local)
case Some(local) =>
local match {
case x: NewLocal => identifier.dynamicTypeHintFullName(x.dynamicTypeHintFullName)
case x: NewMethodParameterIn => identifier.dynamicTypeHintFullName(x.dynamicTypeHintFullName)
}
Ast(identifier).withRefEdge(identifier, local)
}
Ast(identifier).withRefEdge(identifier, local)
}

}

protected val UnaryOperatorNames: Map[String, String] = Map(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,18 +1,12 @@
package io.joern.rubysrc2cpg.astcreation

import io.joern.rubysrc2cpg.astcreation.RubyIntermediateAst.{Unknown, *}
import io.joern.rubysrc2cpg.datastructures.BlockScope
import io.joern.rubysrc2cpg.passes.Defines
import io.joern.rubysrc2cpg.datastructures.{BlockScope, FieldDecl}
import io.joern.rubysrc2cpg.passes.Defines.{RubyOperators, getBuiltInType}
import io.joern.x2cpg.{Ast, ValidationMode, Defines as XDefines}
import io.joern.x2cpg.{Ast, Defines as XDefines, ValidationMode}
import io.joern.rubysrc2cpg.passes.Defines
import io.shiftleft.codepropertygraph.generated.nodes.*
import io.shiftleft.codepropertygraph.generated.{
ControlStructureTypes,
DiffGraphBuilder,
DispatchTypes,
Operators,
PropertyNames
}
import io.shiftleft.codepropertygraph.generated.{ControlStructureTypes, DispatchTypes, Operators, PropertyNames}

trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) { this: AstCreator =>

Expand All @@ -28,7 +22,7 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) {
case node: IndexAccess => astForIndexAccess(node)
case node: SingleAssignment => astForSingleAssignment(node)
case node: AttributeAssignment => astForAttributeAssignment(node)
case node: SimpleIdentifier => astForSimpleIdentifier(node)
case node: RubyIdentifier => astForSimpleIdentifier(node)
case node: SimpleCall => astForSimpleCall(node)
case node: RequireCall => astForRequireCall(node)
case node: IncludeCall => astForIncludeCall(node)
Expand Down Expand Up @@ -312,12 +306,12 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) {
astForMemberCallWithoutBlock(call, memberAccess)
}

protected def astForSimpleIdentifier(node: SimpleIdentifier): Ast = {
protected def astForSimpleIdentifier(node: RubyNode with RubyIdentifier): Ast = {
val name = code(node)

if (name.startsWith("@")) {
if (name.startsWith("@@")) {
logger.warn(
s"Class (@@) and instance (@) variables are not handled as members yet, but are instead handled as simple identifier declarations. Found: $name"
s"Class (@@) are not handled as members yet, but are instead handled as simple identifier declarations. Found: $name"
)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,11 +62,12 @@ trait AstForTypesCreator(implicit withSchemaValidation: ValidationMode) { this:

node match {
case _: ModuleDeclaration => scope.pushNewScope(ModuleScope(classFullName))
case _: TypeDeclaration => scope.pushNewScope(TypeScope(classFullName))
case _: TypeDeclaration => scope.pushNewScope(TypeScope(classFullName, List.empty))
}

val classBody =
node.body.asInstanceOf[StatementList] // for now (bodyStatement is a superset of stmtList)

val classBodyAsts = classBody.statements.flatMap(astsForStatement) match {
case bodyAsts if scope.shouldGenerateDefaultConstructor && this.parseLevel == AstParseLevel.FULL_AST =>
val bodyStart = classBody.span.spanStart()
Expand All @@ -77,9 +78,19 @@ trait AstForTypesCreator(implicit withSchemaValidation: ValidationMode) { this:
methodDecl ++ bodyAsts
case bodyAsts => bodyAsts
}

val fieldMemberNodes = node match {
case classDecl: ClassDeclaration =>
classDecl.fields.map { x =>
val name = code(x)
Ast(memberNode(x, name, name, Defines.Any))
}
case _ => Seq.empty
}

scope.popScope()

Ast(typeDecl).withChildren(classBodyAsts)
Ast(typeDecl).withChildren(fieldMemberNodes).withChildren(classBodyAsts)
}

protected def astsForFieldDeclarations(node: FieldsDeclaration): Seq[Ast] = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,12 @@ object RubyIntermediateAst {
def baseClass: Option[RubyNode] = None
}

final case class ClassDeclaration(name: RubyNode, baseClass: Option[RubyNode], body: RubyNode)(span: TextSpan)
final case class ClassDeclaration(
name: RubyNode,
baseClass: Option[RubyNode],
body: RubyNode,
fields: List[InstanceFieldIdentifier]
)(span: TextSpan)
extends RubyNode(span)
with TypeDeclaration

Expand Down Expand Up @@ -127,6 +132,8 @@ object RubyIntermediateAst {
*/
sealed trait ControlFlowClause

/** Marker trait mixed into the nodes that represent identifiers, e.g. [[SimpleIdentifier]] and
  * [[InstanceFieldIdentifier]].
  */
sealed trait RubyIdentifier

final case class RescueExpression(
body: RubyNode,
rescueClauses: List[RubyNode],
Expand Down Expand Up @@ -196,8 +203,13 @@ object RubyIntermediateAst {

final case class ReturnExpression(expressions: List[RubyNode])(span: TextSpan) extends RubyNode(span)

/** Represents an unqualified identifier e.g. `X`, `x`, `@x`, `@@x`, `$x`, `$<`, etc. */
final case class SimpleIdentifier(typeFullName: Option[String] = None)(span: TextSpan) extends RubyNode(span)
/** Represents an unqualified identifier e.g. `X`, `x`, `@@x`, `$x`, `$<`, etc. */
final case class SimpleIdentifier(typeFullName: Option[String] = None)(span: TextSpan)
extends RubyNode(span)
with RubyIdentifier

/** Represents an instance field identifier, e.g. `@x`. */
final case class InstanceFieldIdentifier()(span: TextSpan) extends RubyNode(span) with RubyIdentifier

final case class SelfIdentifier()(span: TextSpan) extends RubyNode(span)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,23 @@ class RubyScope(summary: RubyProgramSummary, projectRoot: Option[String])
*/
def newProgramScope: Option[ProgramScope] = surroundingScopeFullName.map(ProgramScope.apply)

/** Records `field` on the nearest enclosing [[TypeScope]].
  *
  * Scopes are popped one at a time until a `TypeScope` is reached. That scope is pushed back with `field` appended to
  * its field list, after which the scopes popped on the way are pushed back in their original order. If the stack
  * contains no `TypeScope`, every scope is popped and pushed back again and the field is not recorded.
  *
  * NOTE(review): scopes are restored through `pushNewScope`, so whatever that method does on every push is repeated
  * for each restored scope, and any per-scope state held outside the scope node itself is presumably not carried
  * over — confirm against the base scope implementation.
  *
  * @param field
  *   the field declaration to register.
  */
def pushField(field: FieldDecl): Unit =
  popScope() match {
    case Some(TypeScope(typeFullName, declaredFields)) =>
      pushNewScope(TypeScope(typeFullName, declaredFields :+ field))
    case Some(innerScope) =>
      // Keep unwinding first, then restore this scope on top of whatever was rebuilt below it
      pushField(field)
      pushNewScope(innerScope)
    case None => ()
  }

/** @return
  *   the fields of every [[TypeScope]] currently on the scope stack, concatenated in stack order.
  */
def getFieldsInScope: List[FieldDecl] =
  stack.flatMap {
    case ScopeElement(TypeScope(_, declaredFields), _) => declaredFields
    case _                                             => Nil
  }

/** Looks up a field by name among the fields returned by `getFieldsInScope`.
  *
  * @param fieldName
  *   the name to search for; compared for exact equality with `FieldDecl.name`.
  * @return
  *   the first matching field declaration, or `None` if there is none.
  */
def findFieldInScope(fieldName: String): Option[FieldDecl] =
  getFieldsInScope.collectFirst { case decl if decl.name == fieldName => decl }

override def pushNewScope(scopeNode: TypedScopeElement): Unit = {
// Use the summary to determine if there is a constructor present
val mappedScopeNode = scopeNode match {
Expand All @@ -41,6 +58,9 @@ class RubyScope(summary: RubyProgramSummary, projectRoot: Option[String])
case n: ProgramScope =>
typesInScope.addAll(summary.typesUnderNamespace(n.fullName))
n
case TypeScope(name, _) =>
typesInScope.addAll(summary.matchingTypes(name))
scopeNode
case _ => scopeNode
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package io.joern.rubysrc2cpg.datastructures

import io.joern.rubysrc2cpg.astcreation.RubyIntermediateAst.RubyNode
import io.joern.rubysrc2cpg.passes.Defines
import io.joern.x2cpg.datastructures.{NamespaceLikeScope, TypedScopeElement}
import io.shiftleft.codepropertygraph.generated.nodes.NewBlock
Expand All @@ -10,6 +11,9 @@ import io.shiftleft.codepropertygraph.generated.nodes.NewBlock
*/
case class NamespaceScope(fullName: String) extends NamespaceLikeScope

/** Describes a field that has been recorded on a type's scope.
  *
  * @param name
  *   the field name (presumably as written in source, including any sigil such as `@` — TODO confirm).
  * @param typeFullName
  *   the type full name recorded for the field.
  * @param isStatic
  *   presumably whether the field belongs to the class rather than to an instance — TODO confirm.
  * @param isInitialized
  *   presumably whether the field has been given an initial value — TODO confirm.
  * @param node
  *   the intermediate AST node associated with this field.
  */
case class FieldDecl(name: String, typeFullName: String, isStatic: Boolean, isInitialized: Boolean, node: RubyNode)
extends TypedScopeElement

/** A type-like scope with a full name.
*/
trait TypeLikeScope extends TypedScopeElement {
Expand Down Expand Up @@ -40,7 +44,7 @@ case class ModuleScope(fullName: String) extends TypeLikeScope
* @param fullName
* the type full name.
*/
case class TypeScope(fullName: String) extends TypeLikeScope
case class TypeScope(fullName: String, fields: List[FieldDecl]) extends TypeLikeScope

/** Represents scope objects that map to a method node.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -535,7 +535,7 @@ class RubyNodeCreator extends RubyParserBaseVisitor[RubyNode] {
}

override def visitInstanceIdentifierVariable(ctx: RubyParser.InstanceIdentifierVariableContext): RubyNode = {
SimpleIdentifier()(ctx.toTextSpan)
InstanceFieldIdentifier()(ctx.toTextSpan)
}

override def visitLocalIdentifierVariable(ctx: RubyParser.LocalIdentifierVariableContext): RubyNode = {
Expand Down Expand Up @@ -693,12 +693,79 @@ class RubyNodeCreator extends RubyParserBaseVisitor[RubyNode] {
)(ctx.toTextSpan)
}

/** Collects the instance fields that are the direct target of an assignment in the given methods.
  *
  * Only the top-level statements of each method body are inspected. Methods whose body is not a `StatementList` are
  * skipped rather than cast, so such a method no longer aborts the whole conversion with a `ClassCastException`.
  *
  * @param methodDecls
  *   the method declarations to scan.
  * @return
  *   every `InstanceFieldIdentifier` that is the left-hand side of a `SingleAssignment`, in the order encountered. The
  *   same field may appear more than once.
  */
private def findInstanceFieldsInMethodDecls(methodDecls: List[MethodDeclaration]): List[InstanceFieldIdentifier] = {
  // TODO: Also extract assignments from method bodies that are not a StatementList; for now they are skipped
  methodDecls
    .map(_.body)
    .collect { case body: StatementList => body.statements }
    .flatten
    .collect { case assignment: SingleAssignment => assignment.lhs }
    .collect { case field: InstanceFieldIdentifier => field }
}

/** Lowers a class body and ensures that every instance field found in it is initialised to `nil` in `initialize`.
  *
  * Instance fields are gathered from bare `@x` statements in the class body and from assignments found by
  * `findInstanceFieldsInMethodDecls`. Each distinct field gets a `<field> = nil` assignment, which is
  *   - prepended to the body of an existing `initialize` method when that body is a `StatementList`, or
  *   - placed in a newly generated `initialize` method that is prepended to the class body when none exists.
  *
  * The class body is always returned in full: an existing `initialize` method is replaced in place by its updated
  * copy, and all other statements are kept as they were.
  *
  * @param ctxBodyStatement
  *   the parser context of the class body.
  * @return
  *   the (possibly updated) class body together with the distinct instance fields found. If the lowered body is not a
  *   `StatementList` it is returned unchanged with an empty field list.
  */
private def genInitFieldStmts(
  ctxBodyStatement: RubyParser.BodyStatementContext
): (RubyNode, List[InstanceFieldIdentifier]) = {
  lowerSingletonClassDeclarations(ctxBodyStatement) match {
    case classBody: StatementList =>
      val declaredFields = classBody.statements.collect { case x: InstanceFieldIdentifier => x }
      val methodDecls    = classBody.statements.collect { case x: MethodDeclaration => x }

      // De-duplicate on the source text so that a field assigned in several methods is declared and initialised once
      val instanceFields =
        (declaredFields ++ findInstanceFieldsInMethodDecls(methodDecls)).distinctBy(_.span.text)

      val nilInitializers = instanceFields.map { field =>
        SingleAssignment(field, "=", StaticLiteral(getBuiltInType(Defines.NilClass))(field.span.spanStart("nil")))(
          field.span.spanStart(s"${field.span.text} = nil")
        )
      }

      val updatedStatements = methodDecls.find(_.methodName == "initialize") match {
        case Some(initMethod) =>
          initMethod.body match {
            // TODO: Filter out instance fields that are assigned an initial value in the constructor method. Current
            //  implementation leads to "double" assignment happening when the instance field is assigned a value
            //  where you end up having
            //    <instanceField> = nil; <instanceField> = ...;
            case initBody: StatementList =>
              val updatedInitBody   = StatementList(nilInitializers ++ initBody.statements)(initBody.span)
              val updatedInitMethod = initMethod.copy(body = updatedInitBody)(initMethod.span)
              // Swap the constructor in place (by reference) so the remaining class statements are preserved
              classBody.statements.map {
                case stmt if stmt eq initMethod => updatedInitMethod
                case stmt                       => stmt
              }
            // TODO: Inject the initialisers when the constructor body is not a StatementList
            case _ => classBody.statements
          }
        case None =>
          val newInitMethod =
            MethodDeclaration("initialize", List.empty, StatementList(nilInitializers)(classBody.span))(
              classBody.span
            )
          newInitMethod +: classBody.statements
      }

      (StatementList(updatedStatements)(classBody.span), instanceFields)
    case decls => (decls, List.empty)
  }
}

override def visitClassDefinition(ctx: RubyParser.ClassDefinitionContext): RubyNode = {
ClassDeclaration(
visit(ctx.classPath()),
Option(ctx.commandOrPrimaryValue()).map(visit),
lowerSingletonClassDeclarations(ctx.bodyStatement())
)(ctx.toTextSpan)
val (stmts, fields) = genInitFieldStmts(ctx.bodyStatement())

ClassDeclaration(visit(ctx.classPath()), Option(ctx.commandOrPrimaryValue()).map(visit), stmts, fields)(
ctx.toTextSpan
)
}

/** Lowers all MethodDeclaration found in SingletonClassDeclaration to SingletonMethodDeclaration.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ object Defines {
val Regexp: String = "Regexp"
val Lambda: String = "lambda"
val Proc: String = "proc"
val This: String = "this"

val Program: String = ":program"

Expand Down
Loading

0 comments on commit 81b91eb

Please sign in to comment.