diff --git a/joern-cli/frontends/csharpsrc2cpg/src/main/scala/io/joern/csharpsrc2cpg/astcreation/AstCreatorHelper.scala b/joern-cli/frontends/csharpsrc2cpg/src/main/scala/io/joern/csharpsrc2cpg/astcreation/AstCreatorHelper.scala index f7e3956f1343..33df482e9045 100644 --- a/joern-cli/frontends/csharpsrc2cpg/src/main/scala/io/joern/csharpsrc2cpg/astcreation/AstCreatorHelper.scala +++ b/joern-cli/frontends/csharpsrc2cpg/src/main/scala/io/joern/csharpsrc2cpg/astcreation/AstCreatorHelper.scala @@ -183,7 +183,7 @@ object AstCreatorHelper { case IdentifierName | Parameter | _: DeclarationExpr | GenericName => nameFromIdentifier(node) case QualifiedName => nameFromQualifiedName(node) - case SimpleMemberAccessExpression | MemberBindingExpression => + case SimpleMemberAccessExpression | MemberBindingExpression | SuppressNullableWarningExpression => nameFromIdentifier(createDotNetNodeInfo(node.json(ParserKeys.Name))) case ObjectCreationExpression | CastExpression => nameFromNode(createDotNetNodeInfo(node.json(ParserKeys.Type))) case ThisExpression => "this" diff --git a/joern-cli/frontends/csharpsrc2cpg/src/main/scala/io/joern/csharpsrc2cpg/astcreation/AstForExpressionsCreator.scala b/joern-cli/frontends/csharpsrc2cpg/src/main/scala/io/joern/csharpsrc2cpg/astcreation/AstForExpressionsCreator.scala index 03fe53879700..2a1819a0206a 100644 --- a/joern-cli/frontends/csharpsrc2cpg/src/main/scala/io/joern/csharpsrc2cpg/astcreation/AstForExpressionsCreator.scala +++ b/joern-cli/frontends/csharpsrc2cpg/src/main/scala/io/joern/csharpsrc2cpg/astcreation/AstForExpressionsCreator.scala @@ -9,6 +9,7 @@ import io.joern.x2cpg.{Ast, Defines, ValidationMode} import io.shiftleft.codepropertygraph.generated.nodes.{NewFieldIdentifier, NewLiteral, NewTypeRef} import io.shiftleft.codepropertygraph.generated.{DispatchTypes, Operators} import ujson.Value +import io.joern.csharpsrc2cpg.Constants import scala.collection.mutable.ArrayBuffer import scala.util.{Failure, Success, Try} @@ -21,22 +22,23 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) { def astForExpression(expr: DotNetNodeInfo): Seq[Ast] = { expr.node match - case _: UnaryExpr => astForUnaryExpression(expr) - case _: BinaryExpr => astForBinaryExpression(expr) - case _: LiteralExpr => astForLiteralExpression(expr) - case InvocationExpression => astForInvocationExpression(expr) - case AwaitExpression => astForAwaitExpression(expr) - case ObjectCreationExpression => astForObjectCreationExpression(expr) - case SimpleMemberAccessExpression => astForSimpleMemberAccessExpression(expr) - case ImplicitArrayCreationExpression => astForImplicitArrayCreationExpression(expr) - case ConditionalExpression => astForConditionalExpression(expr) - case _: IdentifierNode => astForIdentifier(expr) :: Nil - case ThisExpression => astForThisReceiver(expr) :: Nil - case CastExpression => astForCastExpression(expr) - case InterpolatedStringExpression => astForInterpolatedStringExpression(expr) - case ConditionalAccessExpression => astForConditionalAccessExpression(expr) - case _: BaseLambdaExpression => astForSimpleLambdaExpression(expr) - case _ => notHandledYet(expr) + case _: UnaryExpr => astForUnaryExpression(expr) + case _: BinaryExpr => astForBinaryExpression(expr) + case _: LiteralExpr => astForLiteralExpression(expr) + case InvocationExpression => astForInvocationExpression(expr) + case AwaitExpression => astForAwaitExpression(expr) + case ObjectCreationExpression => astForObjectCreationExpression(expr) + case SimpleMemberAccessExpression => astForSimpleMemberAccessExpression(expr) + case ImplicitArrayCreationExpression => astForImplicitArrayCreationExpression(expr) + case ConditionalExpression => astForConditionalExpression(expr) + case _: IdentifierNode => astForIdentifier(expr) :: Nil + case ThisExpression => astForThisReceiver(expr) :: Nil + case CastExpression => astForCastExpression(expr) + case InterpolatedStringExpression => astForInterpolatedStringExpression(expr) + case ConditionalAccessExpression => astForConditionalAccessExpression(expr) + case SuppressNullableWarningExpression => astForSuppressNullableWarningExpression(expr) + case _: BaseLambdaExpression => astForSimpleLambdaExpression(expr) + case _ => notHandledYet(expr) } private def astForAwaitExpression(awaitExpr: DotNetNodeInfo): Seq[Ast] = { @@ -210,7 +212,7 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) { methodMetaData: Option[CSharpMethod], arguments: Seq[Ast] ) = expression.node match { - case SimpleMemberAccessExpression => + case SimpleMemberAccessExpression | SuppressNullableWarningExpression => val baseNode = createDotNetNodeInfo( createDotNetNodeInfo(invocationExpr.json(ParserKeys.Expression)).json(ParserKeys.Expression) ) @@ -257,6 +259,7 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) { case Some(m) => s"${m.returnType}(${m.parameterTypes.filterNot(_._1 == "this").map(_._2).mkString(",")})" case None => Defines.UnresolvedSignature } + val methodFullName = baseTypeFullName match { case Some(typeFullName) => s"$typeFullName.$callName:$methodSignature" @@ -290,12 +293,42 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) { protected def astForSimpleMemberAccessExpression(accessExpr: DotNetNodeInfo): Seq[Ast] = { val fieldIdentifierName = nameFromNode(accessExpr) - val fieldInScope = scope.findFieldInScope(fieldIdentifierName) - - val typeFullName = fieldInScope.map(_.typeFullName).getOrElse(Defines.Any) - val identifierName = - if (fieldInScope.nonEmpty && fieldInScope.get.isStatic) scope.surroundingTypeDeclFullName.getOrElse(Defines.Any) - else "this" + val (identifierName, typeFullName) = accessExpr.node match { + case SimpleMemberAccessExpression => { + createDotNetNodeInfo(accessExpr.json(ParserKeys.Expression)).node match + case SuppressNullableWarningExpression => + val baseNode = createDotNetNodeInfo(accessExpr.json(ParserKeys.Expression)(ParserKeys.Operand)) + val baseAst = astForNode(baseNode) + val baseTypeFullName = getTypeFullNameFromAstNode(baseAst) + + val fieldInScope = scope.tryResolveFieldAccess(fieldIdentifierName, typeFullName = Option(baseTypeFullName)) + + ( + nameFromNode(baseNode), + fieldInScope + .map(_.typeName) + .getOrElse(Defines.Any) + ) + case _ => { + val fieldInScope = scope.findFieldInScope(fieldIdentifierName) + val _identifierName = + if (fieldInScope.nonEmpty && fieldInScope.map(_.isStatic).contains(true)) + scope.surroundingTypeDeclFullName.getOrElse(Defines.Any) + else Constants.This + val _typeFullName = fieldInScope.map(_.typeFullName).getOrElse(Defines.Any) + (_identifierName, _typeFullName) + } + } + case _ => { + val fieldInScope = scope.findFieldInScope(fieldIdentifierName) + val _identifierName = + if (fieldInScope.nonEmpty && fieldInScope.map(_.isStatic).contains(true)) + scope.surroundingTypeDeclFullName.getOrElse(Defines.Any) + else Constants.This + val _typeFullName = fieldInScope.map(_.typeFullName).getOrElse(Defines.Any) + (_identifierName, _typeFullName) + } + } val identifier = newIdentifierNode(identifierName, typeFullName) @@ -487,4 +520,8 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) { } } + private def astForSuppressNullableWarningExpression(suppressNullableExpr: DotNetNodeInfo): Seq[Ast] = { + val _identifierNode = createDotNetNodeInfo(suppressNullableExpr.json(ParserKeys.Operand)) + Seq(astForIdentifier(_identifierNode)) + } } diff --git a/joern-cli/frontends/csharpsrc2cpg/src/main/scala/io/joern/csharpsrc2cpg/parser/DotNetJsonAst.scala b/joern-cli/frontends/csharpsrc2cpg/src/main/scala/io/joern/csharpsrc2cpg/parser/DotNetJsonAst.scala index 8702af90cc8d..b0ad7bee9dd9 100644 --- a/joern-cli/frontends/csharpsrc2cpg/src/main/scala/io/joern/csharpsrc2cpg/parser/DotNetJsonAst.scala +++ b/joern-cli/frontends/csharpsrc2cpg/src/main/scala/io/joern/csharpsrc2cpg/parser/DotNetJsonAst.scala @@ -248,6 +248,8 @@ object DotNetJsonAst { object MemberBindingExpression extends BaseExpr + object SuppressNullableWarningExpression extends BaseExpr + object Unknown extends DotNetParserNode } diff --git a/joern-cli/frontends/csharpsrc2cpg/src/test/scala/io/joern/csharpsrc2cpg/querying/ast/MemberAccessTests.scala b/joern-cli/frontends/csharpsrc2cpg/src/test/scala/io/joern/csharpsrc2cpg/querying/ast/MemberAccessTests.scala index 46d40be29a7c..2240ebda649b 100644 --- a/joern-cli/frontends/csharpsrc2cpg/src/test/scala/io/joern/csharpsrc2cpg/querying/ast/MemberAccessTests.scala +++ b/joern-cli/frontends/csharpsrc2cpg/src/test/scala/io/joern/csharpsrc2cpg/querying/ast/MemberAccessTests.scala @@ -80,4 +80,77 @@ class MemberAccessTests extends CSharpCode2CpgFixture { } } } + + "null-forgiving member access expressions" should { + val cpg = code(""" + |namespace Foo { + | public class Baz { + | public int Qux {get;} + | } + | public class Bar { + | public static void Main() { + | var baz = new Baz(); + | var a = baz!.Qux; + | } + | } + |} + |""".stripMargin) + + "have correct types both on the LHS and RHS" in { + inside(cpg.assignment.l.sortBy(_.lineNumber).drop(1)) { + case a :: Nil => + inside(a.argument.l) { case (lhs: Identifier) :: (rhs: Call) :: Nil => + lhs.typeFullName shouldBe BuiltinTypes.DotNetTypeMap(BuiltinTypes.Int) + rhs.typeFullName shouldBe BuiltinTypes.DotNetTypeMap(BuiltinTypes.Int) + } + case _ => fail("Expected 1 assignment call.") + } + } + } + + "null-forgiving method access expressions" should { + val cpg = code(""" + |namespace Foo { + | public class Baz { + | public int Qux() {} + | public string Fred(int a) {} + | } + | public class Bar { + | public static void Main() { + | var baz = new Baz(); + | var a = baz!.Qux(); + | var b = baz!.Fred(1); + | } + | } + |} + |""".stripMargin) + + "have correct types and attributes both on the LHS and RHS" in { + inside(cpg.assignment.l.sortBy(_.lineNumber).drop(1)) { + case a :: b :: Nil => + inside(a.argument.l) { + case (lhs: Identifier) :: (rhs: Call) :: Nil => + lhs.typeFullName shouldBe BuiltinTypes.DotNetTypeMap(BuiltinTypes.Int) + rhs.typeFullName shouldBe BuiltinTypes.DotNetTypeMap(BuiltinTypes.Int) + case _ => fail("Expected 2 arguments under the assignment call") + } + + inside(b.argument.l) { + case (lhs: Identifier) :: (rhs: Call) :: Nil => + lhs.typeFullName shouldBe BuiltinTypes.DotNetTypeMap(BuiltinTypes.String) + rhs.typeFullName shouldBe BuiltinTypes.DotNetTypeMap(BuiltinTypes.String) + case _ => fail("Expected 2 arguments under the assignment call") + } + + inside(cpg.call.nameExact("Fred").l) { + case fred :: Nil => + fred.typeFullName shouldBe BuiltinTypes.DotNetTypeMap(BuiltinTypes.String) + fred.methodFullName shouldBe "Foo.Baz.Fred:System.String(System.Int32)" + case _ => fail("Expected a call named `Fred`") + + } + case _ => fail("Expected 2 assignment call.") + } + } + } }