Skip to content

Commit

Permalink
[ruby] Class Variables (#4343)
Browse files Browse the repository at this point in the history
* [ruby] Initial handling of class variables

* [ruby] Handling of class vars done, added test as well

* code cleanup

* fixed broken test
  • Loading branch information
AndreiDreyer authored Mar 15, 2024
1 parent 81b91eb commit 3b4dcc5
Show file tree
Hide file tree
Showing 6 changed files with 180 additions and 53 deletions.
Original file line number Diff line number Diff line change
@@ -1,6 +1,13 @@
package io.joern.rubysrc2cpg.astcreation
import io.joern.rubysrc2cpg.astcreation.GlobalTypes.{builtinFunctions, builtinPrefix}
import io.joern.rubysrc2cpg.astcreation.RubyIntermediateAst.{DummyNode, InstanceFieldIdentifier, MemberAccess, RubyNode}
import io.joern.rubysrc2cpg.astcreation.RubyIntermediateAst.{
ClassFieldIdentifier,
DummyNode,
InstanceFieldIdentifier,
MemberAccess,
RubyFieldIdentifier,
RubyNode
}
import io.joern.rubysrc2cpg.datastructures.{BlockScope, FieldDecl}
import io.joern.x2cpg.{Ast, ValidationMode}
import io.shiftleft.codepropertygraph.generated.{EdgeTypes, Operators}
Expand All @@ -23,36 +30,34 @@ trait AstCreatorHelper(implicit withSchemaValidation: ValidationMode) { this: As
protected def prefixAsBuiltin(x: String): String = s"$builtinPrefix$pathSep$x"
protected def pathSep = "."

/** Builds the field-access AST `<base>.<name>` for an instance (`@x`) or class (`@@x`) field.
  *
  * The base identifier is `Defines.This` for instance fields. For class fields it is the last dot-separated segment
  * of the surrounding type's full name, or `Defines.Any` when the scope reports no surrounding type.
  *
  * @param name
  *   the field name used as the member part of the access
  * @param node
  *   the field identifier node; its span is reused for the generated base identifier and the member access
  * @return
  *   the AST produced by `astForFieldAccess` for the synthesised member access
  */
private def astForFieldInstance(name: String, node: RubyNode with RubyFieldIdentifier): Ast = {
  val baseName = node match {
    case _: InstanceFieldIdentifier => Defines.This
    case _: ClassFieldIdentifier =>
      scope.surroundingTypeFullName
        .map(fullName => fullName.split("[.]").last)
        .getOrElse(Defines.Any)
  }

  // Synthesise the receiver (`this` or the class name) as a dummy node located at the start of the field's span.
  val baseIdentifier = identifierNode(node, baseName, baseName, Defines.Any)
  val base           = DummyNode(baseIdentifier)(node.span.spanStart(baseName))
  val fieldAccess    = MemberAccess(base, ".", name)(node.span)

  astForFieldAccess(fieldAccess)
}

protected def handleVariableOccurrence(node: RubyNode): Ast = {
val name = code(node)
val identifier = identifierNode(node, name, name, Defines.Any)
val typeRef = scope.tryResolveTypeReference(name)

node match {
case instanceField: InstanceFieldIdentifier =>
case fieldVariable: RubyFieldIdentifier =>
scope.findFieldInScope(name) match {
case None =>
scope.pushField(FieldDecl(name, Defines.Any, false, false, node))
astForFieldAccess(
MemberAccess(
DummyNode(identifierNode(instanceField, Defines.This, Defines.This, Defines.Any))(
instanceField.span.spanStart(Defines.This)
),
".",
name
)(instanceField.span)
)
scope.pushField(FieldDecl(name, Defines.Any, false, false, fieldVariable))
astForFieldInstance(name, fieldVariable)
case Some(field) =>
val fieldNode = field.node
astForFieldAccess(
MemberAccess(
DummyNode(identifierNode(fieldNode, Defines.This, Defines.This, Defines.Any))(
instanceField.span.spanStart(Defines.This)
),
".",
name
)(fieldNode.span)
)
astForFieldInstance(name, field.node)
}
case _ =>
scope.lookupVariable(name) match {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -309,12 +309,6 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) {
protected def astForSimpleIdentifier(node: RubyNode with RubyIdentifier): Ast = {
val name = code(node)

if (name.startsWith("@@")) {
logger.warn(
s"Class (@@) are not handled as members yet, but are instead handled as simple identifier declarations. Found: $name"
)
}

scope.lookupVariable(name) match {
case None if scope.tryResolveMethodInvocation(node.text, List.empty).isDefined =>
astForSimpleCall(SimpleCall(node, List())(node.span))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ object RubyIntermediateAst {
name: RubyNode,
baseClass: Option[RubyNode],
body: RubyNode,
fields: List[InstanceFieldIdentifier]
fields: List[RubyNode with RubyFieldIdentifier]
)(span: TextSpan)
extends RubyNode(span)
with TypeDeclaration
Expand Down Expand Up @@ -132,8 +132,14 @@ object RubyIntermediateAst {
*/
sealed trait ControlFlowClause

/** Marker trait for any structure that is an identifier, except `self`, e.g. `a`, `@a`, `@@a`.
  */
sealed trait RubyIdentifier

/** Marker trait for Ruby field identifiers, i.e. instance variables (`@a`) and class variables (`@@a`).
  */
sealed trait RubyFieldIdentifier extends RubyIdentifier

final case class RescueExpression(
body: RubyNode,
rescueClauses: List[RubyNode],
Expand Down Expand Up @@ -209,7 +215,10 @@ object RubyIntermediateAst {
with RubyIdentifier

/** Represents an InstanceFieldIdentifier, e.g. `@x` */
final case class InstanceFieldIdentifier()(span: TextSpan) extends RubyNode(span) with RubyIdentifier
final case class InstanceFieldIdentifier()(span: TextSpan) extends RubyNode(span) with RubyFieldIdentifier

/** Represents a ClassFieldIdentifier, i.e. a class variable such as `@@x`. */
final case class ClassFieldIdentifier()(span: TextSpan) extends RubyNode(span) with RubyFieldIdentifier

final case class SelfIdentifier()(span: TextSpan) extends RubyNode(span)

Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
package io.joern.rubysrc2cpg.datastructures

import io.joern.rubysrc2cpg.astcreation.RubyIntermediateAst.RubyNode
import io.joern.rubysrc2cpg.astcreation.RubyIntermediateAst.{RubyFieldIdentifier, RubyNode}
import io.joern.rubysrc2cpg.passes.Defines
import io.joern.x2cpg.datastructures.{NamespaceLikeScope, TypedScopeElement}
import io.shiftleft.codepropertygraph.generated.nodes.NewBlock
Expand All @@ -11,8 +11,13 @@ import io.shiftleft.codepropertygraph.generated.nodes.NewBlock
*/
case class NamespaceScope(fullName: String) extends NamespaceLikeScope

case class FieldDecl(name: String, typeFullName: String, isStatic: Boolean, isInitialized: Boolean, node: RubyNode)
extends TypedScopeElement
/** A field declaration tracked as a typed scope element.
  *
  * @param name
  *   the field's name
  * @param typeFullName
  *   the full name of the field's type
  * @param isStatic
  *   NOTE(review): presumably marks class-level (static) fields; confirm against the code that constructs this
  * @param isInitialized
  *   NOTE(review): presumably whether the field has been given an initial value; confirm against callers
  * @param node
  *   the instance- or class-field identifier node this declaration refers to
  */
case class FieldDecl(
name: String,
typeFullName: String,
isStatic: Boolean,
isInitialized: Boolean,
node: RubyNode with RubyFieldIdentifier
) extends TypedScopeElement

/** A type-like scope with a full name.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import io.joern.rubysrc2cpg.parser.RubyParser.RangeOperatorContext
import io.joern.rubysrc2cpg.passes.Defines
import io.joern.rubysrc2cpg.passes.Defines.getBuiltInType
import org.antlr.v4.runtime.tree.{ParseTree, RuleNode}
import io.joern.x2cpg.Defines as XDefines;

import scala.jdk.CollectionConverters.*

Expand Down Expand Up @@ -531,7 +532,7 @@ class RubyNodeCreator extends RubyParserBaseVisitor[RubyNode] {
}

override def visitClassIdentifierVariable(ctx: RubyParser.ClassIdentifierVariableContext): RubyNode = {
SimpleIdentifier()(ctx.toTextSpan)
ClassFieldIdentifier()(ctx.toTextSpan)
}

override def visitInstanceIdentifierVariable(ctx: RubyParser.InstanceIdentifierVariableContext): RubyNode = {
Expand Down Expand Up @@ -693,48 +694,80 @@ class RubyNodeCreator extends RubyParserBaseVisitor[RubyNode] {
)(ctx.toTextSpan)
}

private def findInstanceFieldsInMethodDecls(methodDecls: List[MethodDeclaration]): List[InstanceFieldIdentifier] = {
private def findFieldsInmethodDecls(methodDecls: List[MethodDeclaration]): List[RubyNode with RubyFieldIdentifier] = {
// TODO: Handle case where body of method is not a StatementList
methodDecls
.flatMap {
_.body.asInstanceOf[StatementList].statements.collect { case x: SingleAssignment =>
x.lhs
}
}
.collect { case x: InstanceFieldIdentifier =>
.collect { case x: RubyNode with RubyFieldIdentifier =>
x
}
}

private def genInitFieldStmts(
ctxBodyStatement: RubyParser.BodyStatementContext
): (RubyNode, List[InstanceFieldIdentifier]) = {
): (RubyNode, List[RubyNode with RubyFieldIdentifier]) = {
val loweredClassDecls = lowerSingletonClassDeclarations(ctxBodyStatement)

/** Generates a `<field> = nil` SingleAssignment for every given field.
  *
  * @param fields
  *   the first group of field nodes to initialise
  * @param fieldsInMethodDecls
  *   the second group of field nodes to initialise (named for fields found in method declarations); appended after
  *   `fields`
  * @return
  *   one SingleAssignment per input node, in input order, assigning the built-in `NilClass` literal to it
  */
def genSingleAssignmentStmtList(
  fields: List[RubyNode],
  fieldsInMethodDecls: List[RubyNode]
): List[SingleAssignment] = {
  val allFields = fields ++ fieldsInMethodDecls
  for (field <- allFields) yield {
    // Both the literal and the assignment are synthetic, so their spans are derived from the field's own span.
    val nilLiteral = StaticLiteral(getBuiltInType(Defines.NilClass))(field.span.spanStart("nil"))
    SingleAssignment(field, "=", nilLiteral)(field.span.spanStart(s"${field.span.text} = nil"))
  }
}

/** Splits the given nodes into InstanceFieldIdentifiers and everything else.
  *
  * @param fields
  *   the nodes to split
  * @return
  *   a pair whose first element holds the InstanceFieldIdentifier nodes and whose second element holds all remaining
  *   nodes, each in original order
  */
def partitionRubyFields(fields: List[RubyNode]): (List[RubyNode], List[RubyNode]) = {
  val isInstanceField: RubyNode => Boolean = {
    case _: InstanceFieldIdentifier => true
    case _                          => false
  }
  fields.partition(isInstanceField)
}

loweredClassDecls match {
case stmtList: StatementList =>
val (instanceFields, rest) = stmtList.statements.partition {
case x: InstanceFieldIdentifier => true
case _ => false
val (rubyFieldIdentifiers, rest) = stmtList.statements.partition {
case x: RubyNode with RubyFieldIdentifier => true
case _ => false
}

val (instanceFields, classFields) = partitionRubyFields(rubyFieldIdentifiers)

val methodDecls = rest.collect { case x: MethodDeclaration =>
x
}

val fieldsInMethodDecls = findInstanceFieldsInMethodDecls(methodDecls)
val fieldsInMethodDecls = findFieldsInmethodDecls(methodDecls)

val (instanceFieldsInMethodDecls, classFieldsInMethodDecls) = partitionRubyFields(fieldsInMethodDecls)

val initializeMethod = methodDecls.collectFirst { x =>
x.methodName match
case "initialize" => x
}

val combinedInstanceFields = instanceFields ++ fieldsInMethodDecls
val initStmtListStatements = genSingleAssignmentStmtList(instanceFields, instanceFieldsInMethodDecls)
val clinitStmtList = genSingleAssignmentStmtList(classFields, classFieldsInMethodDecls)

val initStmtListStatements = combinedInstanceFields.map { x =>
SingleAssignment(x, "=", StaticLiteral(getBuiltInType(Defines.NilClass))(x.span.spanStart("nil")))(
x.span.spanStart(s"${x.span.text} = nil")
val clinitMethod =
MethodDeclaration(XDefines.StaticInitMethodName, List.empty, StatementList(clinitStmtList)(stmtList.span))(
stmtList.span
)
}

val updatedStmtList = initializeMethod match {
case Some(initMethod) =>
Expand All @@ -744,18 +777,21 @@ class RubyNodeCreator extends RubyParserBaseVisitor[RubyNode] {
// where you end up having
// <instanceField> = nil; <instanceField> = ...;
case stmtList: StatementList =>
StatementList(initStmtListStatements ++ stmtList.statements)(stmtList.span)
val initializers = initStmtListStatements :+ clinitMethod
StatementList(initializers ++ rest)(stmtList.span)
case x => x
}
case None =>
val newInitMethod =
MethodDeclaration("initialize", List.empty, StatementList(initStmtListStatements)(stmtList.span))(
stmtList.span
)
StatementList(newInitMethod +: stmtList.statements)(stmtList.span)
val initializers = newInitMethod :: clinitMethod :: Nil
StatementList(initializers ++ rest)(stmtList.span)
}
val combinedFields = rubyFieldIdentifiers ++ fieldsInMethodDecls

(updatedStmtList, combinedInstanceFields.asInstanceOf[List[InstanceFieldIdentifier]])
(updatedStmtList, combinedFields.asInstanceOf[List[RubyNode with RubyFieldIdentifier]])
case decls => (decls, List.empty)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ class ClassTests extends RubyCode2CpgFixture {
classC.lineNumber shouldBe Some(2)
classC.baseType.l shouldBe List()
classC.member.l shouldBe List()
classC.method.name.l shouldBe List("<init>")
classC.method.name.l shouldBe List("<init>", "<clinit>")
}

"`class C < D` is represented by a TYPE_DECL node inheriting from `D`" in {
Expand All @@ -37,7 +37,7 @@ class ClassTests extends RubyCode2CpgFixture {
classC.fullName shouldBe "Test0.rb:<global>::program.C"
classC.lineNumber shouldBe Some(2)
classC.member.l shouldBe List()
classC.method.name.l shouldBe List("<init>")
classC.method.name.l shouldBe List("<init>", "<clinit>")

val List(typeD) = classC.baseType.l
typeD.name shouldBe "D"
Expand Down Expand Up @@ -415,6 +415,7 @@ class ClassTests extends RubyCode2CpgFixture {
|""".stripMargin)

"create respective member nodes" in {
cpg.typeDecl.nameExact("Foo").dotAst.l.foreach(println)
inside(cpg.typeDecl.name("Foo").l) {
case fooType :: Nil =>
inside(fooType.member.l) {
Expand All @@ -435,8 +436,8 @@ class ClassTests extends RubyCode2CpgFixture {
inside(cpg.typeDecl.name("Foo").l) {
case fooType :: Nil =>
inside(fooType.method.name(Defines.ConstructorMethodName).l) {
case clinitMethod :: Nil =>
inside(clinitMethod.block.astChildren.isCall.name(Operators.assignment).l) {
case initMethod :: Nil =>
inside(initMethod.block.astChildren.isCall.name(Operators.assignment).l) {
case aAssignment :: bAssignment :: cAssignment :: dAssignment :: oAssignment :: Nil =>
aAssignment.code shouldBe "@a = nil"

Expand All @@ -457,6 +458,83 @@ class ClassTests extends RubyCode2CpgFixture {
case _ => fail("Expected identifier and fieldIdentifier for fieldAccess")
}

rhs.code shouldBe "nil"
case _ => fail("Expected only LHS and RHS for assignment call")
}
case _ => fail("")
}
case xs => fail(s"Expected one method for init, instead got ${xs.name.mkString(", ")}")
}
case xs => fail(s"Expected TypeDecl for Foo, instead got ${xs.name.mkString(", ")}")
}
}
}

"Class Variables in Class and Methods" should {
val cpg = code("""
|class Foo
| @@a
|
| def foo
| @@b = 10
| end
|
| def foobar
| @@c = 20
| @@d = 40
| end
|
| def barfoo
| puts @@a
| puts @@c
| @@o = "a"
| end
|end
|""".stripMargin)

"create respective member nodes" in {
inside(cpg.typeDecl.name("Foo").l) {
case fooType :: Nil =>
inside(fooType.member.l) {
case aMember :: bMember :: cMember :: dMember :: oMember :: Nil =>
// Test that all members in class are present
aMember.code shouldBe "@@a"
bMember.code shouldBe "@@b"
cMember.code shouldBe "@@c"
dMember.code shouldBe "@@d"
oMember.code shouldBe "@@o"
case _ => fail("Expected 5 members")
}
case xs => fail(s"Expected TypeDecl for Foo, instead got ${xs.name.mkString(", ")}")
}
}

"create nil assignments under the class initializer" in {
inside(cpg.typeDecl.name("Foo").l) {
case fooType :: Nil =>
inside(fooType.method.name(Defines.StaticInitMethodName).l) {
case clinitMethod :: Nil =>
inside(clinitMethod.block.astChildren.isCall.name(Operators.assignment).l) {
case aAssignment :: bAssignment :: cAssignment :: dAssignment :: oAssignment :: Nil =>
aAssignment.code shouldBe "@@a = nil"

bAssignment.code shouldBe "@@b = nil"
cAssignment.code shouldBe "@@c = nil"
dAssignment.code shouldBe "@@d = nil"
oAssignment.code shouldBe "@@o = nil"

inside(aAssignment.argument.l) {
case (lhs: Call) :: (rhs: Literal) :: Nil =>
lhs.code shouldBe "Foo.@@a"
lhs.methodFullName shouldBe Operators.fieldAccess

inside(lhs.argument.l) {
case (identifier: Identifier) :: (fieldIdentifier: FieldIdentifier) :: Nil =>
identifier.code shouldBe "Foo"
fieldIdentifier.code shouldBe "@@a"
case _ => fail("Expected identifier and fieldIdentifier for fieldAccess")
}

rhs.code shouldBe "nil"
case _ => fail("Expected only LHS and RHS for assignment call")
}
Expand Down

0 comments on commit 3b4dcc5

Please sign in to comment.