From 3200b19660e5e83b687574aa79f664e89fdd4257 Mon Sep 17 00:00:00 2001 From: Andrei Dreyer Date: Tue, 3 Dec 2024 11:40:33 +0200 Subject: [PATCH] Added implicit return handling for MethodAccessModifiers (#5150) --- .../astcreation/AstForStatementsCreator.scala | 17 +++++++++++ .../astcreation/RubyIntermediateAst.scala | 9 ++++-- .../rubysrc2cpg/querying/ClassTests.scala | 28 +++++++++++++++++++ 3 files changed, 52 insertions(+), 2 deletions(-) diff --git a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForStatementsCreator.scala b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForStatementsCreator.scala index b71ecbb51290..214938a7fef8 100644 --- a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForStatementsCreator.scala +++ b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForStatementsCreator.scala @@ -240,6 +240,23 @@ trait AstForStatementsCreator(implicit withSchemaValidation: ValidationMode) { t val simpleIdent = node.toSimpleIdentifier val simpleCall = SimpleCall(simpleIdent, List.empty)(simpleIdent.span) astForReturnExpression(ReturnExpression(List(simpleCall))(node.span)) :: Nil + case node: MethodAccessModifier => + val simpleIdent = node.toSimpleIdentifier + + val methodIdentName = node.method match { + case x: StaticLiteral => x.span.text + case x: MethodDeclaration => x.methodName + case x => + logger.warn(s"Unknown node type for method identifier name: ${x.getClass} (${this.relativeFileName})") + x.span.text + } + + val methodIdent = SimpleIdentifier(None)(simpleIdent.span.spanStart(methodIdentName)) + + val simpleCall = SimpleCall(simpleIdent, List(methodIdent))( + simpleIdent.span.spanStart(s"${simpleIdent.span.text} ${methodIdent.span.text}") + ) + astForReturnExpression(ReturnExpression(List(simpleCall))(node.span)) :: Nil case node: FieldsDeclaration => val nilReturnSpan = node.span.spanStart("return nil") val nilReturnLiteral = StaticLiteral(Defines.NilClass)(nilReturnSpan) diff --git a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/RubyIntermediateAst.scala b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/RubyIntermediateAst.scala index 3f1ee8b86154..406893179b8e 100644 --- a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/RubyIntermediateAst.scala +++ b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/RubyIntermediateAst.scala @@ -489,6 +489,7 @@ object RubyIntermediateAst { } sealed trait MethodAccessModifier extends AllowedTypeDeclarationChild { + def toSimpleIdentifier: SimpleIdentifier def method: RubyExpression } @@ -506,11 +507,15 @@ object RubyIntermediateAst { final case class PrivateMethodModifier(method: RubyExpression)(span: TextSpan) extends RubyExpression(span) - with MethodAccessModifier + with MethodAccessModifier { + override def toSimpleIdentifier: SimpleIdentifier = SimpleIdentifier(None)(span.spanStart("private_class_method")) + } final case class PublicMethodModifier(method: RubyExpression)(span: TextSpan) extends RubyExpression(span) - with MethodAccessModifier + with MethodAccessModifier { + override def toSimpleIdentifier: SimpleIdentifier = SimpleIdentifier(None)(span.spanStart("public_class_method")) + } /** Represents standalone `proc { ... }` or `lambda { ... }` expressions */ diff --git a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/ClassTests.scala b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/ClassTests.scala index 8d83e57c65d7..5208f203f472 100644 --- a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/ClassTests.scala +++ b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/ClassTests.scala @@ -1215,4 +1215,32 @@ class ClassTests extends RubyCode2CpgFixture { } } } + + "Implicit return for call to `private_class_method`" in { + val cpg = code(""" + |class Foo + | def case_sensitive_find_by() + | end + | + | included do + | private_class_method :case_sensitive_find_by + | end + |end + |""".stripMargin) + + inside(cpg.typeDecl.name("Foo").astChildren.isMethod.l) { + case lambdaMethod :: _ :: _ :: _ :: Nil => + val List(lambdaReturn) = lambdaMethod.body.astChildren.isReturn.l + + lambdaReturn.code shouldBe "private_class_method :case_sensitive_find_by" + + val List(returnCall) = lambdaReturn.astChildren.isCall.l + returnCall.code shouldBe "private_class_method :case_sensitive_find_by" + + val List(_, methodNameArg) = returnCall.argument.l + methodNameArg.code shouldBe "self.:case_sensitive_find_by" + + case xs => fail(s"Expected 5 methods, got [${xs.code.mkString(",")}]") + } + } }