Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ruby] Method Aliases #4373

Merged
merged 7 commits into from
Mar 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -592,7 +592,7 @@ CONSTANT_IDENTIFIER
;

fragment METHOD_ONLY_IDENTIFIER
: (CONSTANT_IDENTIFIER | LOCAL_VARIABLE_IDENTIFIER) ('!' | '?')
: (CONSTANT_IDENTIFIER | LOCAL_VARIABLE_IDENTIFIER) (EMARK | QMARK | EQ)
;


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ methodName
;

methodOnlyIdentifier
: (CONSTANT_IDENTIFIER | LOCAL_VARIABLE_IDENTIFIER | pseudoVariable) (EMARK | QMARK)
: (CONSTANT_IDENTIFIER | LOCAL_VARIABLE_IDENTIFIER | pseudoVariable) (EMARK | QMARK | EQ)
;

methodInvocationWithoutParentheses
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,7 @@ trait AstForStatementsCreator(implicit withSchemaValidation: ValidationMode) { t
case node: MemberCallWithBlock => returnAstForRubyCall(node)
case node: SimpleCallWithBlock => returnAstForRubyCall(node)
case _: (LiteralExpr | BinaryExpression | UnaryExpression | SimpleIdentifier | IndexAccess | Association |
YieldExpr | RubyCall) =>
YieldExpr | RubyCall | RubyFieldIdentifier) =>
astForReturnStatement(ReturnExpression(List(node))(node.span)) :: Nil
case node: SingleAssignment =>
astForSingleAssignment(node) :: List(astForReturnStatement(ReturnExpression(List(node.lhs))(node.span)))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -380,4 +380,7 @@ object RubyIntermediateAst {
final case class BinaryExpression(lhs: RubyNode, op: String, rhs: RubyNode)(span: TextSpan) extends RubyNode(span)

final case class HereDocNode(content: String)(span: TextSpan) extends RubyNode(span)

final case class AliasStatement(oldName: String, newName: String)(span: TextSpan) extends RubyNode(span)

}
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,20 @@ import io.joern.rubysrc2cpg.astcreation.RubyIntermediateAst.*
import io.joern.rubysrc2cpg.parser.AntlrContextHelpers.*
import io.joern.rubysrc2cpg.passes.Defines
import io.joern.rubysrc2cpg.passes.Defines.getBuiltInType
import io.joern.rubysrc2cpg.utils.FreshNameGenerator
import io.joern.x2cpg.Defines as XDefines
import org.antlr.v4.runtime.tree.{ParseTree, RuleNode}
import io.joern.x2cpg.Defines as XDefines;
import org.slf4j.LoggerFactory

import scala.jdk.CollectionConverters.*
import io.joern.rubysrc2cpg.utils.FreshNameGenerator

/** Converts an ANTLR Ruby Parse Tree into the intermediate Ruby AST.
*/
class RubyNodeCreator extends RubyParserBaseVisitor[RubyNode] {

private val logger = LoggerFactory.getLogger(getClass)
private val classNameGen = FreshNameGenerator(id => s"<anon-class-$id>")

protected def freshClassName(span: TextSpan): SimpleIdentifier = {
SimpleIdentifier(None)(span.spanStart(classNameGen.fresh))
}
Expand Down Expand Up @@ -474,6 +477,15 @@ class RubyNodeCreator extends RubyParserBaseVisitor[RubyNode] {
RequireCall(visit(identifierCtx), argument, true)(ctx.toTextSpan)
case ("include", List(argument)) =>
IncludeCall(visit(identifierCtx), argument)(ctx.toTextSpan)
case (idAssign, arguments) if idAssign.endsWith("=") =>
// fixme: This workaround handles a parser ambiguity with method identifiers having `=` and assignments.
// The Ruby parser gives precedence to assignments over methods called with this suffix however
val lhsIdentifier = SimpleIdentifier(None)(identifierCtx.toTextSpan.spanStart(idAssign.stripSuffix("=")))
val argNode = arguments match {
case arg :: Nil => arg
case xs => ArrayLiteral(xs)(ctx.commandArgument().toTextSpan)
}
SingleAssignment(lhsIdentifier, "=", argNode)(ctx.toTextSpan)
case _ =>
SimpleCall(visit(identifierCtx), arguments)(ctx.toTextSpan)
}
Expand Down Expand Up @@ -724,9 +736,6 @@ class RubyNodeCreator extends RubyParserBaseVisitor[RubyNode] {
val loweredClassDecls = lowerSingletonClassDeclarations(ctxBodyStatement)

/** Generates SingleAssignment RubyNodes for list of fields and fields found in method decls
* @param fields
* @param fieldsInMethodDecls
* @return
*/
def genSingleAssignmentStmtList(
fields: List[RubyNode],
Expand All @@ -740,8 +749,6 @@ class RubyNodeCreator extends RubyParserBaseVisitor[RubyNode] {
}

/** Partition RubyFields into InstanceFieldIdentifiers and ClassFieldIdentifiers
* @param fields
* @return
*/
def partitionRubyFields(fields: List[RubyNode]): (List[RubyNode], List[RubyNode]) = {
fields.partition {
Expand Down Expand Up @@ -807,8 +814,59 @@ class RubyNodeCreator extends RubyParserBaseVisitor[RubyNode] {
}
}

/** Detects the alias statements and creates methods that reference the aliased method as a call.
* @param classBody
* the class body node
* @return
* the class body as a statement list.
*/
private def lowerAliasStatementsToMethods(classBody: RubyNode): StatementList = {

val classBodyStmts = classBody match {
case StatementList(stmts) => stmts
case x => List(x)
}

val methodParamMap = classBodyStmts.collect { case method: MethodDeclaration =>
method.methodName -> method.parameters
}.toMap

val loweredMethods = classBodyStmts.collect { case alias: AliasStatement =>
methodParamMap.get(alias.oldName) match {
case Some(aliasingMethodParams) =>
val argsCode = aliasingMethodParams.map(_.text).mkString(", ")
val callCode = s"${alias.oldName}($argsCode)"
MethodDeclaration(
alias.newName,
aliasingMethodParams,
StatementList(
SimpleCall(
SimpleIdentifier(None)(alias.span.spanStart(alias.oldName)),
aliasingMethodParams.map { x => SimpleIdentifier(None)(alias.span.spanStart(x.span.text)) }
)(alias.span.spanStart(callCode)) :: Nil
)(alias.span.spanStart(callCode))
)(alias.span.spanStart(s"def ${alias.newName}($argsCode)"))
case None =>
logger.warn(
s"Unable to correctly lower aliased method ${alias.oldName}, the result will be in degraded parameter/argument flows"
)
MethodDeclaration(
alias.newName,
List.empty,
StatementList(
SimpleCall(SimpleIdentifier(None)(alias.span.spanStart(alias.oldName)), List.empty)(alias.span) :: Nil
)(alias.span)
)(alias.span)
}
}

StatementList(classBodyStmts.filterNot(_.isInstanceOf[AliasStatement]) ++ loweredMethods)(classBody.span)
}

override def visitClassDefinition(ctx: RubyParser.ClassDefinitionContext): RubyNode = {
val (stmts, fields) = genInitFieldStmts(ctx.bodyStatement())
val (nonFieldStmts, fields) = genInitFieldStmts(ctx.bodyStatement())

val stmts = lowerAliasStatementsToMethods(nonFieldStmts)

ClassDeclaration(visit(ctx.classPath()), Option(ctx.commandOrPrimaryValue()).map(visit), stmts, fields)(
ctx.toTextSpan
Expand Down Expand Up @@ -969,4 +1027,8 @@ class RubyNodeCreator extends RubyParserBaseVisitor[RubyNode] {
}
}

override def visitAliasStatement(ctx: RubyParser.AliasStatementContext): RubyNode = {
AliasStatement(ctx.oldName.getText, ctx.newName.getText)(ctx.toTextSpan)
}

}
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package io.joern.rubysrc2cpg.querying

import io.joern.rubysrc2cpg.testfixtures.RubyCode2CpgFixture
import io.joern.x2cpg.Defines
import io.shiftleft.codepropertygraph.generated.Operators
import io.shiftleft.codepropertygraph.generated.nodes.Return
import io.shiftleft.semanticcpg.language.*
Expand Down Expand Up @@ -256,6 +257,53 @@ class MethodTests extends RubyCode2CpgFixture {

}

"aliased methods" should {

val cpg = code("""
|class Foo
| @x = 0
| def x=(z)
| @x = z
| end
|
| def x
| @x
| end
|
| alias x= bar=
|end
|
|foo = Foo.new
|
|foo.bar= 1
|
|puts foo.x # => 1
|""".stripMargin)

"create a method under `Foo` for both `x=`, `x`, and `bar=`, where `bar=` forwards parameters to a call to `x=`" in {
inside(cpg.typeDecl("Foo").l) {
case foo :: Nil =>
inside(foo.method.nameNot(Defines.ConstructorMethodName, Defines.StaticInitMethodName).l) {
case xeq :: x :: bar :: Nil =>
xeq.name shouldBe "x="
x.name shouldBe "x"
bar.name shouldBe "bar="

xeq.parameter.name.l shouldBe bar.parameter.name.l
// bar forwards parameters to a call to the aliased method
inside(bar.call.name("x=").l) {
case barCall :: Nil =>
barCall.argument.isIdentifier.name.head shouldBe "z"
barCall.code shouldBe "x=(z)"
case xs => fail(s"Expected a single call to `bar=`, instead got [${xs.code.mkString(",")}]")
}
case xs => fail(s"Expected a three virtual methods under `Foo`, instead got [${xs.code.mkString(",")}]")
}
case xs => fail(s"Expected a single type decl for `Foo`, instead got [${xs.code.mkString(",")}]")
}
}
}

"Singleton Methods for module scope" should {
val cpg = code("""
|module F
Expand Down Expand Up @@ -308,6 +356,7 @@ class MethodTests extends RubyCode2CpgFixture {
case _ => fail("Expected one Method for :program")
}
}

}

}
Loading