Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[rubysrc2cpg] Fixed Bug with Higher-Order Functions #3708

Merged
merged 2 commits into from
Oct 2, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -240,23 +240,23 @@ class AstCreator(

def astForExpressionContext(ctx: ExpressionContext): Seq[Ast] = ctx match {
case ctx: PrimaryExpressionContext => astForPrimaryContext(ctx.primary())
case ctx: UnaryExpressionContext => Seq(astForUnaryExpression(ctx))
case ctx: PowerExpressionContext => Seq(astForPowerExpression(ctx))
case ctx: UnaryMinusExpressionContext => Seq(astForUnaryMinusExpression(ctx))
case ctx: MultiplicativeExpressionContext => Seq(astForMultiplicativeExpression(ctx))
case ctx: AdditiveExpressionContext => Seq(astForAdditiveExpression(ctx))
case ctx: BitwiseShiftExpressionContext => Seq(astForBitwiseShiftExpression(ctx))
case ctx: BitwiseAndExpressionContext => Seq(astForBitwiseAndExpression(ctx))
case ctx: BitwiseOrExpressionContext => Seq(astForBitwiseOrExpression(ctx))
case ctx: RelationalExpressionContext => Seq(astForRelationalExpression(ctx))
case ctx: EqualityExpressionContext => Seq(astForEqualityExpression(ctx))
case ctx: OperatorAndExpressionContext => Seq(astForAndExpression(ctx))
case ctx: OperatorOrExpressionContext => Seq(astForOrExpression(ctx))
case ctx: UnaryExpressionContext => astForUnaryExpression(ctx)
case ctx: PowerExpressionContext => astForPowerExpression(ctx)
case ctx: UnaryMinusExpressionContext => astForUnaryMinusExpression(ctx)
case ctx: MultiplicativeExpressionContext => astForMultiplicativeExpression(ctx)
case ctx: AdditiveExpressionContext => astForAdditiveExpression(ctx)
case ctx: BitwiseShiftExpressionContext => astForBitwiseShiftExpression(ctx)
case ctx: BitwiseAndExpressionContext => astForBitwiseAndExpression(ctx)
case ctx: BitwiseOrExpressionContext => astForBitwiseOrExpression(ctx)
case ctx: RelationalExpressionContext => astForRelationalExpression(ctx)
case ctx: EqualityExpressionContext => astForEqualityExpression(ctx)
case ctx: OperatorAndExpressionContext => astForAndExpression(ctx)
case ctx: OperatorOrExpressionContext => astForOrExpression(ctx)
case ctx: RangeExpressionContext => astForRangeExpressionContext(ctx)
case ctx: ConditionalOperatorExpressionContext => Seq(astForTernaryConditionalOperator(ctx))
case ctx: SingleAssignmentExpressionContext => astForSingleAssignmentExpressionContext(ctx)
case ctx: MultipleAssignmentExpressionContext => astForMultipleAssignmentExpressionContext(ctx)
case ctx: IsDefinedExpressionContext => Seq(astForIsDefinedExpression(ctx))
case ctx: IsDefinedExpressionContext => astForIsDefinedExpression(ctx)
case _ =>
logger.error(s"astForExpressionContext() $relativeFilename, ${text(ctx)} All contexts mismatched.")
Seq(Ast())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,14 @@ import io.joern.rubysrc2cpg.parser.RubyParser.*
import io.joern.rubysrc2cpg.passes.Defines
import io.joern.rubysrc2cpg.passes.Defines.getBuiltInType
import io.joern.x2cpg.{Ast, ValidationMode}
import io.shiftleft.codepropertygraph.generated.nodes.{AstNodeNew, NewCall, NewIdentifier, NewMethod}
import io.shiftleft.codepropertygraph.generated.nodes.{
AstNodeNew,
NewCall,
NewIdentifier,
NewMethod,
NewType,
NewTypeDecl
}
import io.shiftleft.codepropertygraph.generated.{ControlStructureTypes, DispatchTypes, ModifierTypes, Operators}
import org.antlr.v4.runtime.ParserRuleContext
import org.slf4j.LoggerFactory
Expand All @@ -17,45 +24,45 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) {
private val logger = LoggerFactory.getLogger(this.getClass)
protected var lastModifier: Option[String] = None

protected def astForPowerExpression(ctx: PowerExpressionContext): Ast =
protected def astForPowerExpression(ctx: PowerExpressionContext): Seq[Ast] =
astForBinaryOperatorExpression(ctx, Operators.exponentiation, ctx.expression().asScala)

protected def astForOrExpression(ctx: OperatorOrExpressionContext): Ast =
protected def astForOrExpression(ctx: OperatorOrExpressionContext): Seq[Ast] =
astForBinaryOperatorExpression(ctx, Operators.or, ctx.expression().asScala)

protected def astForAndExpression(ctx: OperatorAndExpressionContext): Ast =
protected def astForAndExpression(ctx: OperatorAndExpressionContext): Seq[Ast] =
astForBinaryOperatorExpression(ctx, Operators.and, ctx.expression().asScala)

protected def astForUnaryExpression(ctx: UnaryExpressionContext): Ast = ctx.op.getType match {
protected def astForUnaryExpression(ctx: UnaryExpressionContext): Seq[Ast] = ctx.op.getType match {
case TILDE => astForBinaryOperatorExpression(ctx, Operators.not, Seq(ctx.expression()))
case PLUS => astForBinaryOperatorExpression(ctx, Operators.plus, Seq(ctx.expression()))
case EMARK => astForBinaryOperatorExpression(ctx, Operators.not, Seq(ctx.expression()))
}

protected def astForUnaryMinusExpression(ctx: UnaryMinusExpressionContext): Ast =
protected def astForUnaryMinusExpression(ctx: UnaryMinusExpressionContext): Seq[Ast] =
astForBinaryOperatorExpression(ctx, Operators.minus, Seq(ctx.expression()))

protected def astForAdditiveExpression(ctx: AdditiveExpressionContext): Ast = ctx.op.getType match {
protected def astForAdditiveExpression(ctx: AdditiveExpressionContext): Seq[Ast] = ctx.op.getType match {
case PLUS => astForBinaryOperatorExpression(ctx, Operators.addition, ctx.expression().asScala)
case MINUS => astForBinaryOperatorExpression(ctx, Operators.subtraction, ctx.expression().asScala)
}

protected def astForMultiplicativeExpression(ctx: MultiplicativeExpressionContext): Ast = ctx.op.getType match {
protected def astForMultiplicativeExpression(ctx: MultiplicativeExpressionContext): Seq[Ast] = ctx.op.getType match {
case STAR => astForMultiplicativeStarExpression(ctx)
case SLASH => astForMultiplicativeSlashExpression(ctx)
case PERCENT => astForMultiplicativePercentExpression(ctx)
}

protected def astForMultiplicativeStarExpression(ctx: MultiplicativeExpressionContext): Ast =
protected def astForMultiplicativeStarExpression(ctx: MultiplicativeExpressionContext): Seq[Ast] =
astForBinaryOperatorExpression(ctx, Operators.multiplication, ctx.expression().asScala)

protected def astForMultiplicativeSlashExpression(ctx: MultiplicativeExpressionContext): Ast =
protected def astForMultiplicativeSlashExpression(ctx: MultiplicativeExpressionContext): Seq[Ast] =
astForBinaryOperatorExpression(ctx, Operators.division, ctx.expression().asScala)

protected def astForMultiplicativePercentExpression(ctx: MultiplicativeExpressionContext): Ast =
protected def astForMultiplicativePercentExpression(ctx: MultiplicativeExpressionContext): Seq[Ast] =
astForBinaryOperatorExpression(ctx, Operators.modulo, ctx.expression().asScala)

protected def astForEqualityExpression(ctx: EqualityExpressionContext): Ast = ctx.op.getType match {
protected def astForEqualityExpression(ctx: EqualityExpressionContext): Seq[Ast] = ctx.op.getType match {
case LTEQGT => astForBinaryOperatorExpression(ctx, Operators.compare, ctx.expression().asScala)
case EQ2 => astForBinaryOperatorExpression(ctx, Operators.equals, ctx.expression().asScala)
case EQ3 => astForBinaryOperatorExpression(ctx, Operators.is, ctx.expression().asScala)
Expand All @@ -64,22 +71,22 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) {
case EMARKTILDE => astForBinaryOperatorExpression(ctx, RubyOperators.notPatternMatch, ctx.expression().asScala)
}

protected def astForRelationalExpression(ctx: RelationalExpressionContext): Ast = ctx.op.getType match {
protected def astForRelationalExpression(ctx: RelationalExpressionContext): Seq[Ast] = ctx.op.getType match {
case GT => astForBinaryOperatorExpression(ctx, Operators.greaterThan, ctx.expression().asScala)
case GTEQ => astForBinaryOperatorExpression(ctx, Operators.greaterEqualsThan, ctx.expression().asScala)
case LT => astForBinaryOperatorExpression(ctx, Operators.lessThan, ctx.expression().asScala)
case LTEQ => astForBinaryOperatorExpression(ctx, Operators.lessEqualsThan, ctx.expression().asScala)
}

protected def astForBitwiseOrExpression(ctx: BitwiseOrExpressionContext): Ast = ctx.op.getType match {
protected def astForBitwiseOrExpression(ctx: BitwiseOrExpressionContext): Seq[Ast] = ctx.op.getType match {
case BAR => astForBinaryOperatorExpression(ctx, Operators.logicalOr, ctx.expression().asScala)
case CARET => astForBinaryOperatorExpression(ctx, Operators.logicalOr, ctx.expression().asScala)
}

protected def astForBitwiseAndExpression(ctx: BitwiseAndExpressionContext): Ast =
protected def astForBitwiseAndExpression(ctx: BitwiseAndExpressionContext): Seq[Ast] =
astForBinaryOperatorExpression(ctx, Operators.logicalAnd, ctx.expression().asScala)

protected def astForBitwiseShiftExpression(ctx: BitwiseShiftExpressionContext): Ast = ctx.op.getType match {
protected def astForBitwiseShiftExpression(ctx: BitwiseShiftExpressionContext): Seq[Ast] = ctx.op.getType match {
case LT2 => astForBinaryOperatorExpression(ctx, Operators.shiftLeft, ctx.expression().asScala)
case GT2 => astForBinaryOperatorExpression(ctx, Operators.logicalShiftRight, ctx.expression().asScala)
}
Expand All @@ -88,13 +95,19 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) {
ctx: ParserRuleContext,
name: String,
arguments: Iterable[ExpressionContext]
): Ast = {
val argsAst = arguments.flatMap(astForExpressionContext)
val call = callNode(ctx, text(ctx), name, name, DispatchTypes.STATIC_DISPATCH)
callAst(call, argsAst.toList)
): Seq[Ast] = {
val (argsAst, otherAst) = arguments
.flatMap(astForExpressionContext)
.partition(_.root match
case Some(_: NewMethod) => false
case Some(_: NewTypeDecl) => false
case _ => true
)
val call = callNode(ctx, text(ctx), name, name, DispatchTypes.STATIC_DISPATCH)
otherAst.toSeq :+ callAst(call, argsAst.toList)
}

protected def astForIsDefinedExpression(ctx: IsDefinedExpressionContext): Ast =
protected def astForIsDefinedExpression(ctx: IsDefinedExpressionContext): Seq[Ast] =
astForBinaryOperatorExpression(ctx, RubyOperators.defined, Seq(ctx.expression()))

// TODO: Maybe merge (in RubyParser.g4) isDefinedExpression with isDefinedPrimaryExpression?
Expand Down Expand Up @@ -362,7 +375,7 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) {
}

def astForRangeExpressionContext(ctx: RangeExpressionContext): Seq[Ast] =
Seq(astForBinaryOperatorExpression(ctx, Operators.range, ctx.expression().asScala))
astForBinaryOperatorExpression(ctx, Operators.range, ctx.expression().asScala)

protected def astForSuperExpression(ctx: SuperExpressionPrimaryContext): Ast = {
val argsAst = Option(ctx.argumentsWithParentheses()) match
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,4 +31,24 @@ class DoBlockTest extends RubyCode2CpgFixture {
}
}

"a do-block function used as a higher-order function" should {
val cpg = code(
"""class TransactionsController < ApplicationController
| def permitted_column_name(column_name)
| %w[trx_date description amount].find { |permitted| column_name == permitted } || 'trx_date'
| end
|end
|
|""".stripMargin)

"create a do-block method named from the surrounding function" in {
val findMethod :: _ = cpg.method.name("find.*").l: @unchecked
findMethod.name should startWith("find")
findMethod.parameter.size shouldBe 1
val permitParam :: _ = findMethod.parameter.l: @unchecked
permitParam.name shouldBe "permitted"
}

}

}