diff --git a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForExpressionsCreator.scala b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForExpressionsCreator.scala index e976fc14d260..707fbb3354b3 100644 --- a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForExpressionsCreator.scala +++ b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForExpressionsCreator.scala @@ -372,10 +372,16 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) { protected def astForYield(node: YieldExpr): Ast = { scope.useProcParam match { case Some(param) => - astForReturnStatement( - ReturnExpression(List(SimpleCall(SimpleIdentifier()(node.span.spanStart(param)), node.arguments)(node.span)))( - node.span - ) + val call = astForExpression( + SimpleCall(SimpleIdentifier()(node.span.spanStart(param)), node.arguments)(node.span) + ) + val ret = returnAst(returnNode(node, code(node))) + val cond = astForExpression( + SimpleCall(SimpleIdentifier()(node.span.spanStart(tmpGen.fresh)), List())(node.span.spanStart("")) + ) + callAst( + callNode(node, code(node), Operators.conditional, Operators.conditional, DispatchTypes.STATIC_DISPATCH), + List(cond, call, ret) ) case None => logger.warn(s"Yield expression outside of method scope: ${code(node)} ($relativeFileName), skipping") diff --git a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForStatementsCreator.scala b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForStatementsCreator.scala index 812f141df392..96bf0ead605a 100644 --- a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForStatementsCreator.scala +++ b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForStatementsCreator.scala @@ -252,7 +252,7 @@ trait AstForStatementsCreator(implicit withSchemaValidation: ValidationMode) { t case node: MemberCallWithBlock => returnAstForRubyCall(node) case node: SimpleCallWithBlock => returnAstForRubyCall(node) case _: (LiteralExpr | BinaryExpression | UnaryExpression | SimpleIdentifier | IndexAccess | Association | - RubyCall) => + YieldExpr | RubyCall) => astForReturnStatement(ReturnExpression(List(node))(node.span)) :: Nil case node: SingleAssignment => astForSingleAssignment(node) :: List(astForReturnStatement(ReturnExpression(List(node.lhs))(node.span))) @@ -266,9 +266,6 @@ trait AstForStatementsCreator(implicit withSchemaValidation: ValidationMode) { t case node: MethodDeclaration => (astForMethodDeclaration(node) :+ astForReturnMethodDeclarationSymbolName(node)).toList - case node: YieldExpr => // Yield is already a return expression to handle do block returns - astForExpression(node) :: Nil - case node => logger.warn( s"Implicit return here not supported yet: ${node.text} (${node.getClass.getSimpleName}), only generating statement" diff --git a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/ProcParameterAndYieldTests.scala b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/ProcParameterAndYieldTests.scala index 2b86db295f50..d03d7d33239b 100644 --- a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/ProcParameterAndYieldTests.scala +++ b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/ProcParameterAndYieldTests.scala @@ -3,6 +3,7 @@ package io.joern.rubysrc2cpg.querying import io.joern.rubysrc2cpg.testfixtures.RubyCode2CpgFixture import org.scalatest.Inspectors import io.shiftleft.semanticcpg.language.* +import io.shiftleft.codepropertygraph.generated.nodes.* class ProcParameterAndYieldTests extends RubyCode2CpgFixture with Inspectors { "Methods" should { @@ -16,8 +17,17 @@ class ProcParameterAndYieldTests extends RubyCode2CpgFixture with Inspectors { forAll(cpgs)(_.method("foo").parameter.code("&.*").name.l shouldBe List("b")) } - "replace the yield with a call to the block parameter" in { - forAll(cpgs)(_.call.code("yield").name.l shouldBe List("b")) + "represent the yield as a conditional with a call and return node as children" in { + forAll(cpgs) { cpg => + inside(cpg.method("foo").call(".conditional").code("yield").astChildren.l) { + case List(cond: Expression, call: Call, ret: Return) => { + cond.code shouldBe "" + call.name shouldBe "b" + call.code shouldBe "yield" + ret.code shouldBe "yield" + } + } + } } } @@ -26,8 +36,8 @@ class ProcParameterAndYieldTests extends RubyCode2CpgFixture with Inspectors { val cpg2 = code("def self.foo() yield end") val cpgs = List(cpg1, cpg2) - "replace the yield with a call to a block parameter" in { - forAll(cpgs)(_.call.code("yield").name.l shouldBe List("")) + "have a call to a block parameter" in { + forAll(cpgs)(_.call.code("yield").astChildren.isCall.code("yield").name.l shouldBe List("")) } "add a block argument" in { @@ -42,8 +52,8 @@ class ProcParameterAndYieldTests extends RubyCode2CpgFixture with Inspectors { "with yield arguments" should { val cpg = code("def foo(x) yield(x) end") "replace the yield with a call to the block parameter with arguments" in { - val List(call) = cpg.call.codeExact("yield(x)").l - call.name shouldBe "b" + val List(call) = cpg.call.codeExact("yield(x)").astChildren.isCall.codeExact("yield(x)").l + call.name shouldBe "" call.argument.code.l shouldBe List("x") } @@ -60,31 +70,6 @@ class ProcParameterAndYieldTests extends RubyCode2CpgFixture with Inspectors { } } - "with non-implicitly returned yield" should { - val cpg = code(""" - |def foo() - | yield - | 1 - |end - |""".stripMargin) - "have a return node for the yield" in { - val List(ret) = cpg.method("foo").call.code("yield").astParent.isReturn.l - ret.code shouldBe "yield" - } - } - - "with implicitly returned yield" should { - val cpg = code(""" - |def foo() - | yield - |end - |""".stripMargin) - "have a return node for the yield" in { - val List(ret) = cpg.method("foo").call.code("yield").astParent.isReturn.l - ret.code shouldBe "yield" - } - } - } }