Skip to content

Commit

Permalink
[ruby] Overapproximate Static Type References & Handle self (#4263)
Browse files Browse the repository at this point in the history
Frameworks such as Ruby on Rails implicitly load classes in a way that requires more careful analysis, which is the goal of [#3940](#3940). For now, however, we can choose a layered approach of first looking at types in scope, then looking at the rest of the program.

This PR implements that layered approach, which allows `astForMemberCall` to resolve method calls, and also adds simple handling for occurrences of the `self` identifier.

Replaced the `STATIC_DISPATCH` of member calls with `DYNAMIC_DISPATCH`, as this can very easily be a call that needs to process a polymorphic relationship.
  • Loading branch information
DavidBakerEffendi authored Mar 5, 2024
1 parent d83129b commit 7e5913c
Show file tree
Hide file tree
Showing 5 changed files with 81 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,10 @@ trait AstCreatorHelper(implicit withSchemaValidation: ValidationMode) { this: As
protected def handleVariableOccurrence(node: RubyNode): Ast = {
val name = code(node)
val identifier = identifierNode(node, name, name, Defines.Any)
val typeRef = scope.tryResolveTypeReference(name)
scope.lookupVariable(name) match {
case None if typeRef.isDefined =>
Ast(identifier.typeFullName(typeRef.get.name))
case None =>
val local = localNode(node, name, name, Defines.Any)
scope.addToScope(name, local) match {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import io.joern.rubysrc2cpg.passes.Defines
import io.joern.rubysrc2cpg.passes.Defines.{RubyOperators, getBuiltInType}
import io.joern.x2cpg.{Ast, ValidationMode, Defines as XDefines}
import io.shiftleft.codepropertygraph.generated.nodes.*
import io.shiftleft.codepropertygraph.generated.{ControlStructureTypes, DispatchTypes, Operators}
import io.shiftleft.codepropertygraph.generated.{ControlStructureTypes, DispatchTypes, Operators, PropertyNames}

trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) { this: AstCreator =>

Expand All @@ -33,6 +33,7 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) {
case node: SplattingRubyNode => astForSplattingRubyNode(node)
case node: AnonymousTypeDeclaration => astForAnonymousTypeDeclaration(node)
case node: ProcOrLambdaExpr => astForProcOrLambdaExpr(node)
case node: SelfIdentifier => astForSelfIdentifier(node)
case node: DummyNode => Ast(node.node)
case _ => astForUnknown(node)

Expand Down Expand Up @@ -109,11 +110,24 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) {
astForMemberCall(MemberCall(node.target, node.op, node.methodName, List.empty)(node.span))
}

/** Tries to determine the type of a member call's base expression.
  *
  * The AST for `baseNode` is built solely so that its nodes can be inspected; the AST itself is discarded.
  *
  * @param baseNode
  *   the expression on which a member is being called.
  * @return
  *   the first `TYPE_FULL_NAME` property found among the AST's nodes that is not the `Any` placeholder, or `None` if
  *   there is no such property.
  */
protected def typeFromCallTarget(baseNode: RubyNode): Option[String] = {
  val candidateTypes = for {
    astNode      <- astForExpression(baseNode).nodes
    typeFullName <- astNode.properties.get(PropertyNames.TYPE_FULL_NAME)
  } yield typeFullName.toString
  // The `Any` placeholder carries no information, so skip over it
  candidateTypes.find(_ != XDefines.Any)
}

/** Creates the AST for a member call, i.e. a call of `node.methodName` on `node.target`.
  *
  * The method full name is `<type>:<methodName>` when `typeFromCallTarget` yields a type for the target, and the
  * bare method name otherwise. The call uses dynamic dispatch.
  *
  * @param node
  *   the member call to lower.
  * @return
  *   the call AST built from the call node, the argument ASTs and the field access AST of the target.
  */
protected def astForMemberCall(node: MemberCall): Ast = {
  // Use the scope type recovery to attempt to obtain a receiver type for the call
  // TODO: Type recovery should potentially resolve this
  val fullName = typeFromCallTarget(node.target)
    .map(x => s"$x:${node.methodName}")
    .getOrElse(node.methodName)
  val fieldAccessAst  = astForFieldAccess(MemberAccess(node.target, node.op, node.methodName)(node.span))
  val argumentAsts    = node.arguments.map(astForMethodCallArgument)
  val fieldAccessCall = callNode(node, code(node), node.methodName, fullName, DispatchTypes.DYNAMIC_DISPATCH)
  callAst(fieldAccessCall, argumentAsts, Some(fieldAccessAst))
}

Expand Down Expand Up @@ -413,6 +427,11 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) {
)
}

/** Creates the AST for an occurrence of `self`.
  *
  * The resulting identifier is named `this`, keeps the original source text as its code, and is typed with the
  * surrounding type's full name as reported by the scope (falling back to `Defines.Any` when there is none).
  */
private def astForSelfIdentifier(node: SelfIdentifier): Ast = {
  val selfCode      = code(node)
  val enclosingType = scope.surroundingTypeFullName.getOrElse(Defines.Any)
  Ast(identifierNode(node, "this", selfCode, enclosingType))
}

protected def astForUnknown(node: RubyNode): Ast = {
val className = node.getClass.getSimpleName
val text = code(node)
Expand All @@ -421,11 +440,14 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) {
}

/** Creates the AST for a simple call whose target is a member access, e.g. `x.foo(args)`.
  *
  * The method full name is `<type>:<methodName>` when `typeFromCallTarget` yields a type for the member access
  * target, and the bare method name otherwise. The call uses dynamic dispatch.
  *
  * @param node
  *   the call being lowered; supplies the code and the arguments.
  * @param memberAccess
  *   the member access that forms the call target; supplies the method name and the receiver.
  * @return
  *   the call AST with the field access of `memberAccess` as receiver.
  */
private def astForMemberCallWithoutBlock(node: SimpleCall, memberAccess: MemberAccess): Ast = {
  val receiverAst = astForFieldAccess(memberAccess)
  val methodName  = memberAccess.methodName
  // TODO: Type recovery should potentially resolve this
  val methodFullName = typeFromCallTarget(memberAccess.target)
    .map(x => s"$x:$methodName")
    .getOrElse(methodName)
  val argumentAsts = node.arguments.map(astForMethodCallArgument)
  val call         = callNode(node, code(node), methodName, methodFullName, DispatchTypes.DYNAMIC_DISPATCH)
  callAst(call, argumentAsts, None, Some(receiverAst))
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,4 +92,18 @@ class RubyScope(summary: RubyProgramSummary)
case _ =>
}
}

/** Resolves a type name, first via the parent implementation and then against every type in the program summary.
  *
  * In the fallback, a type matches when the last `.`-separated segment of its name equals `typeName`. The first
  * match is added to `typesInScope` before being returned.
  */
override def tryResolveTypeReference(typeName: String): Option[RubyType] = {
  // TODO: While we find better ways to understand how the implicit class loading works,
  //  we can approximate that all types are in scope in the meantime.
  super.tryResolveTypeReference(typeName).orElse {
    summary.namespaceToType
      .flatMap(_._2)
      .find(_.name.split("[.]").lastOption.contains(typeName))
      .map { resolved =>
        typesInScope.addOne(resolved)
        resolved
      }
  }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ class CaseTests extends RubyCode2CpgFixture {
case u: Unknown => "unknown"
case mExpr =>
val call @ List(_) = List(mExpr).isCall.l
call.methodFullName.l shouldBe List("===")
call.methodFullName.l shouldBe List("__builtin.Integer:===")
val List(lhs, rhs) = call.argument.l
rhs.code shouldBe "<tmp-0>"
val List(code) = List(lhs).isCall.argument(1).code.l
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ package io.joern.rubysrc2cpg.querying

import io.joern.rubysrc2cpg.testfixtures.RubyCode2CpgFixture
import io.shiftleft.codepropertygraph.generated.DispatchTypes
import io.shiftleft.codepropertygraph.generated.nodes.Call
import io.shiftleft.codepropertygraph.generated.nodes.{Call, Identifier}
import io.shiftleft.semanticcpg.language.*

class FieldAccessTests extends RubyCode2CpgFixture {
Expand All @@ -26,4 +26,36 @@ class FieldAccessTests extends RubyCode2CpgFixture {
fieldAccessCall.receiver.l shouldBe List(fieldAccess)
}

"`self.x` should correctly create a `this` node field base" in {

  // Example from railsgoat
  val cpg = code("""
      |class PaidTimeOff < ApplicationRecord
      | belongs_to :user
      | has_many :schedule, foreign_key: :user_id, primary_key: :user_id, dependent: :destroy
      |
      | def sick_days_remaining
      | self.sick_days_earned - self.sick_days_taken
      | end
      |end
      |""".stripMargin)

  // All field accesses whose code matches `self.*`; only the first one is inspected
  val selfFieldAccesses = cpg.fieldAccess.code("self.*").l

  inside(selfFieldAccesses) {
    case Nil => fail("Expected at least one field access with `self` base, but got none.")
    case firstAccess :: _ =>
      firstAccess.code shouldBe "self.sick_days_earned"
      firstAccess.dispatchType shouldBe DispatchTypes.STATIC_DISPATCH

      // The field access must have exactly two arguments: the `self` base and the field name
      inside(firstAccess.argument.l) {
        case (selfBase: Identifier) :: fieldName :: Nil =>
          selfBase.name shouldBe "this"
          selfBase.code shouldBe "self"
          selfBase.typeFullName should endWith("PaidTimeOff")

          fieldName.code shouldBe "sick_days_earned"
        case unexpected =>
          fail(s"Expected exactly two field access arguments, instead got [${unexpected.code.mkString(", ")}]")
      }
  }
}

}

0 comments on commit 7e5913c

Please sign in to comment.