Skip to content

Commit

Permalink
Revert "[rubysrc2cpg] Fixed Bug with Higher-Order Functions (joernio#…
Browse files Browse the repository at this point in the history
…3708)"

This reverts commit 3e9ea97.
  • Loading branch information
khemrajrathore committed Oct 10, 2023
1 parent 0a024af commit a4edfce
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 67 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -240,23 +240,23 @@ class AstCreator(

def astForExpressionContext(ctx: ExpressionContext): Seq[Ast] = ctx match {
case ctx: PrimaryExpressionContext => astForPrimaryContext(ctx.primary())
case ctx: UnaryExpressionContext => astForUnaryExpression(ctx)
case ctx: PowerExpressionContext => astForPowerExpression(ctx)
case ctx: UnaryMinusExpressionContext => astForUnaryMinusExpression(ctx)
case ctx: MultiplicativeExpressionContext => astForMultiplicativeExpression(ctx)
case ctx: AdditiveExpressionContext => astForAdditiveExpression(ctx)
case ctx: BitwiseShiftExpressionContext => astForBitwiseShiftExpression(ctx)
case ctx: BitwiseAndExpressionContext => astForBitwiseAndExpression(ctx)
case ctx: BitwiseOrExpressionContext => astForBitwiseOrExpression(ctx)
case ctx: RelationalExpressionContext => astForRelationalExpression(ctx)
case ctx: EqualityExpressionContext => astForEqualityExpression(ctx)
case ctx: OperatorAndExpressionContext => astForAndExpression(ctx)
case ctx: OperatorOrExpressionContext => astForOrExpression(ctx)
case ctx: UnaryExpressionContext => Seq(astForUnaryExpression(ctx))
case ctx: PowerExpressionContext => Seq(astForPowerExpression(ctx))
case ctx: UnaryMinusExpressionContext => Seq(astForUnaryMinusExpression(ctx))
case ctx: MultiplicativeExpressionContext => Seq(astForMultiplicativeExpression(ctx))
case ctx: AdditiveExpressionContext => Seq(astForAdditiveExpression(ctx))
case ctx: BitwiseShiftExpressionContext => Seq(astForBitwiseShiftExpression(ctx))
case ctx: BitwiseAndExpressionContext => Seq(astForBitwiseAndExpression(ctx))
case ctx: BitwiseOrExpressionContext => Seq(astForBitwiseOrExpression(ctx))
case ctx: RelationalExpressionContext => Seq(astForRelationalExpression(ctx))
case ctx: EqualityExpressionContext => Seq(astForEqualityExpression(ctx))
case ctx: OperatorAndExpressionContext => Seq(astForAndExpression(ctx))
case ctx: OperatorOrExpressionContext => Seq(astForOrExpression(ctx))
case ctx: RangeExpressionContext => astForRangeExpressionContext(ctx)
case ctx: ConditionalOperatorExpressionContext => Seq(astForTernaryConditionalOperator(ctx))
case ctx: SingleAssignmentExpressionContext => astForSingleAssignmentExpressionContext(ctx)
case ctx: MultipleAssignmentExpressionContext => astForMultipleAssignmentExpressionContext(ctx)
case ctx: IsDefinedExpressionContext => astForIsDefinedExpression(ctx)
case ctx: IsDefinedExpressionContext => Seq(astForIsDefinedExpression(ctx))
case _ =>
logger.error(s"astForExpressionContext() $relativeFilename, ${text(ctx)} All contexts mismatched.")
Seq(Ast())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,7 @@ import io.joern.rubysrc2cpg.parser.RubyParser.*
import io.joern.rubysrc2cpg.passes.Defines
import io.joern.rubysrc2cpg.passes.Defines.getBuiltInType
import io.joern.x2cpg.{Ast, ValidationMode}
import io.shiftleft.codepropertygraph.generated.nodes.{
AstNodeNew,
NewCall,
NewIdentifier,
NewMethod,
NewType,
NewTypeDecl
}
import io.shiftleft.codepropertygraph.generated.nodes.{AstNodeNew, NewCall, NewIdentifier, NewMethod}
import io.shiftleft.codepropertygraph.generated.{ControlStructureTypes, DispatchTypes, ModifierTypes, Operators}
import org.antlr.v4.runtime.ParserRuleContext
import org.slf4j.LoggerFactory
Expand All @@ -24,45 +17,45 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) {
private val logger = LoggerFactory.getLogger(this.getClass)
protected var lastModifier: Option[String] = None

protected def astForPowerExpression(ctx: PowerExpressionContext): Seq[Ast] =
protected def astForPowerExpression(ctx: PowerExpressionContext): Ast =
astForBinaryOperatorExpression(ctx, Operators.exponentiation, ctx.expression().asScala)

protected def astForOrExpression(ctx: OperatorOrExpressionContext): Seq[Ast] =
protected def astForOrExpression(ctx: OperatorOrExpressionContext): Ast =
astForBinaryOperatorExpression(ctx, Operators.or, ctx.expression().asScala)

protected def astForAndExpression(ctx: OperatorAndExpressionContext): Seq[Ast] =
protected def astForAndExpression(ctx: OperatorAndExpressionContext): Ast =
astForBinaryOperatorExpression(ctx, Operators.and, ctx.expression().asScala)

protected def astForUnaryExpression(ctx: UnaryExpressionContext): Seq[Ast] = ctx.op.getType match {
protected def astForUnaryExpression(ctx: UnaryExpressionContext): Ast = ctx.op.getType match {
case TILDE => astForBinaryOperatorExpression(ctx, Operators.not, Seq(ctx.expression()))
case PLUS => astForBinaryOperatorExpression(ctx, Operators.plus, Seq(ctx.expression()))
case EMARK => astForBinaryOperatorExpression(ctx, Operators.not, Seq(ctx.expression()))
}

protected def astForUnaryMinusExpression(ctx: UnaryMinusExpressionContext): Seq[Ast] =
protected def astForUnaryMinusExpression(ctx: UnaryMinusExpressionContext): Ast =
astForBinaryOperatorExpression(ctx, Operators.minus, Seq(ctx.expression()))

protected def astForAdditiveExpression(ctx: AdditiveExpressionContext): Seq[Ast] = ctx.op.getType match {
protected def astForAdditiveExpression(ctx: AdditiveExpressionContext): Ast = ctx.op.getType match {
case PLUS => astForBinaryOperatorExpression(ctx, Operators.addition, ctx.expression().asScala)
case MINUS => astForBinaryOperatorExpression(ctx, Operators.subtraction, ctx.expression().asScala)
}

protected def astForMultiplicativeExpression(ctx: MultiplicativeExpressionContext): Seq[Ast] = ctx.op.getType match {
protected def astForMultiplicativeExpression(ctx: MultiplicativeExpressionContext): Ast = ctx.op.getType match {
case STAR => astForMultiplicativeStarExpression(ctx)
case SLASH => astForMultiplicativeSlashExpression(ctx)
case PERCENT => astForMultiplicativePercentExpression(ctx)
}

protected def astForMultiplicativeStarExpression(ctx: MultiplicativeExpressionContext): Seq[Ast] =
protected def astForMultiplicativeStarExpression(ctx: MultiplicativeExpressionContext): Ast =
astForBinaryOperatorExpression(ctx, Operators.multiplication, ctx.expression().asScala)

protected def astForMultiplicativeSlashExpression(ctx: MultiplicativeExpressionContext): Seq[Ast] =
protected def astForMultiplicativeSlashExpression(ctx: MultiplicativeExpressionContext): Ast =
astForBinaryOperatorExpression(ctx, Operators.division, ctx.expression().asScala)

protected def astForMultiplicativePercentExpression(ctx: MultiplicativeExpressionContext): Seq[Ast] =
protected def astForMultiplicativePercentExpression(ctx: MultiplicativeExpressionContext): Ast =
astForBinaryOperatorExpression(ctx, Operators.modulo, ctx.expression().asScala)

protected def astForEqualityExpression(ctx: EqualityExpressionContext): Seq[Ast] = ctx.op.getType match {
protected def astForEqualityExpression(ctx: EqualityExpressionContext): Ast = ctx.op.getType match {
case LTEQGT => astForBinaryOperatorExpression(ctx, Operators.compare, ctx.expression().asScala)
case EQ2 => astForBinaryOperatorExpression(ctx, Operators.equals, ctx.expression().asScala)
case EQ3 => astForBinaryOperatorExpression(ctx, Operators.is, ctx.expression().asScala)
Expand All @@ -71,22 +64,22 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) {
case EMARKTILDE => astForBinaryOperatorExpression(ctx, RubyOperators.notPatternMatch, ctx.expression().asScala)
}

protected def astForRelationalExpression(ctx: RelationalExpressionContext): Seq[Ast] = ctx.op.getType match {
protected def astForRelationalExpression(ctx: RelationalExpressionContext): Ast = ctx.op.getType match {
case GT => astForBinaryOperatorExpression(ctx, Operators.greaterThan, ctx.expression().asScala)
case GTEQ => astForBinaryOperatorExpression(ctx, Operators.greaterEqualsThan, ctx.expression().asScala)
case LT => astForBinaryOperatorExpression(ctx, Operators.lessThan, ctx.expression().asScala)
case LTEQ => astForBinaryOperatorExpression(ctx, Operators.lessEqualsThan, ctx.expression().asScala)
}

protected def astForBitwiseOrExpression(ctx: BitwiseOrExpressionContext): Seq[Ast] = ctx.op.getType match {
protected def astForBitwiseOrExpression(ctx: BitwiseOrExpressionContext): Ast = ctx.op.getType match {
case BAR => astForBinaryOperatorExpression(ctx, Operators.logicalOr, ctx.expression().asScala)
case CARET => astForBinaryOperatorExpression(ctx, Operators.logicalOr, ctx.expression().asScala)
}

protected def astForBitwiseAndExpression(ctx: BitwiseAndExpressionContext): Seq[Ast] =
protected def astForBitwiseAndExpression(ctx: BitwiseAndExpressionContext): Ast =
astForBinaryOperatorExpression(ctx, Operators.logicalAnd, ctx.expression().asScala)

protected def astForBitwiseShiftExpression(ctx: BitwiseShiftExpressionContext): Seq[Ast] = ctx.op.getType match {
protected def astForBitwiseShiftExpression(ctx: BitwiseShiftExpressionContext): Ast = ctx.op.getType match {
case LT2 => astForBinaryOperatorExpression(ctx, Operators.shiftLeft, ctx.expression().asScala)
case GT2 => astForBinaryOperatorExpression(ctx, Operators.logicalShiftRight, ctx.expression().asScala)
}
Expand All @@ -95,19 +88,13 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) {
ctx: ParserRuleContext,
name: String,
arguments: Iterable[ExpressionContext]
): Seq[Ast] = {
val (argsAst, otherAst) = arguments
.flatMap(astForExpressionContext)
.partition(_.root match
case Some(_: NewMethod) => false
case Some(_: NewTypeDecl) => false
case _ => true
)
val call = callNode(ctx, text(ctx), name, name, DispatchTypes.STATIC_DISPATCH)
otherAst.toSeq :+ callAst(call, argsAst.toList)
): Ast = {
val argsAst = arguments.flatMap(astForExpressionContext)
val call = callNode(ctx, text(ctx), name, name, DispatchTypes.STATIC_DISPATCH)
callAst(call, argsAst.toList)
}

protected def astForIsDefinedExpression(ctx: IsDefinedExpressionContext): Seq[Ast] =
protected def astForIsDefinedExpression(ctx: IsDefinedExpressionContext): Ast =
astForBinaryOperatorExpression(ctx, RubyOperators.defined, Seq(ctx.expression()))

// TODO: Maybe merge (in RubyParser.g4) isDefinedExpression with isDefinedPrimaryExpression?
Expand Down Expand Up @@ -375,7 +362,7 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) {
}

def astForRangeExpressionContext(ctx: RangeExpressionContext): Seq[Ast] =
astForBinaryOperatorExpression(ctx, Operators.range, ctx.expression().asScala)
Seq(astForBinaryOperatorExpression(ctx, Operators.range, ctx.expression().asScala))

protected def astForSuperExpression(ctx: SuperExpressionPrimaryContext): Ast = {
val argsAst = Option(ctx.argumentsWithParentheses()) match
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,23 +31,4 @@ class DoBlockTest extends RubyCode2CpgFixture {
}
}

"a do-block function used as a higher-order function" should {
val cpg = code("""class TransactionsController < ApplicationController
| def permitted_column_name(column_name)
| %w[trx_date description amount].find { |permitted| column_name == permitted } || 'trx_date'
| end
|end
|
|""".stripMargin)

"create a do-block method named from the surrounding function" in {
val findMethod :: _ = cpg.method.name("find.*").l: @unchecked
findMethod.name should startWith("find")
findMethod.parameter.size shouldBe 1
val permitParam :: _ = findMethod.parameter.l: @unchecked
permitParam.name shouldBe "permitted"
}

}

}

0 comments on commit a4edfce

Please sign in to comment.