Skip to content

Commit

Permalink
[C#] Handling for null-forgiving expressions, and refactors (#4350)
Browse files Browse the repository at this point in the history
Includes, handling for null-forgiving expressions. Refactor Member Access as parser emits the same token for multiple cases other than accessing class members alone. 

This is a follow-up PR for #4345. Please merge #4345 before this one.
Resolves #4349
  • Loading branch information
karan-batavia authored Mar 19, 2024
1 parent 7fffff2 commit bfbc9c8
Show file tree
Hide file tree
Showing 4 changed files with 136 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ object AstCreatorHelper {
case IdentifierName | Parameter | _: DeclarationExpr | GenericName =>
nameFromIdentifier(node)
case QualifiedName => nameFromQualifiedName(node)
case SimpleMemberAccessExpression | MemberBindingExpression =>
case SimpleMemberAccessExpression | MemberBindingExpression | SuppressNullableWarningExpression =>
nameFromIdentifier(createDotNetNodeInfo(node.json(ParserKeys.Name)))
case ObjectCreationExpression | CastExpression => nameFromNode(createDotNetNodeInfo(node.json(ParserKeys.Type)))
case ThisExpression => "this"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import io.joern.x2cpg.{Ast, Defines, ValidationMode}
import io.shiftleft.codepropertygraph.generated.nodes.{NewFieldIdentifier, NewLiteral, NewTypeRef}
import io.shiftleft.codepropertygraph.generated.{DispatchTypes, Operators}
import ujson.Value
import io.joern.csharpsrc2cpg.Constants

import scala.collection.mutable.ArrayBuffer
import scala.util.{Failure, Success, Try}
Expand All @@ -21,22 +22,23 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) {

def astForExpression(expr: DotNetNodeInfo): Seq[Ast] = {
expr.node match
case _: UnaryExpr => astForUnaryExpression(expr)
case _: BinaryExpr => astForBinaryExpression(expr)
case _: LiteralExpr => astForLiteralExpression(expr)
case InvocationExpression => astForInvocationExpression(expr)
case AwaitExpression => astForAwaitExpression(expr)
case ObjectCreationExpression => astForObjectCreationExpression(expr)
case SimpleMemberAccessExpression => astForSimpleMemberAccessExpression(expr)
case ImplicitArrayCreationExpression => astForImplicitArrayCreationExpression(expr)
case ConditionalExpression => astForConditionalExpression(expr)
case _: IdentifierNode => astForIdentifier(expr) :: Nil
case ThisExpression => astForThisReceiver(expr) :: Nil
case CastExpression => astForCastExpression(expr)
case InterpolatedStringExpression => astForInterpolatedStringExpression(expr)
case ConditionalAccessExpression => astForConditionalAccessExpression(expr)
case _: BaseLambdaExpression => astForSimpleLambdaExpression(expr)
case _ => notHandledYet(expr)
case _: UnaryExpr => astForUnaryExpression(expr)
case _: BinaryExpr => astForBinaryExpression(expr)
case _: LiteralExpr => astForLiteralExpression(expr)
case InvocationExpression => astForInvocationExpression(expr)
case AwaitExpression => astForAwaitExpression(expr)
case ObjectCreationExpression => astForObjectCreationExpression(expr)
case SimpleMemberAccessExpression => astForSimpleMemberAccessExpression(expr)
case ImplicitArrayCreationExpression => astForImplicitArrayCreationExpression(expr)
case ConditionalExpression => astForConditionalExpression(expr)
case _: IdentifierNode => astForIdentifier(expr) :: Nil
case ThisExpression => astForThisReceiver(expr) :: Nil
case CastExpression => astForCastExpression(expr)
case InterpolatedStringExpression => astForInterpolatedStringExpression(expr)
case ConditionalAccessExpression => astForConditionalAccessExpression(expr)
case SuppressNullableWarningExpression => astForSuppressNullableWarningExpression(expr)
case _: BaseLambdaExpression => astForSimpleLambdaExpression(expr)
case _ => notHandledYet(expr)
}

private def astForAwaitExpression(awaitExpr: DotNetNodeInfo): Seq[Ast] = {
Expand Down Expand Up @@ -210,7 +212,7 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) {
methodMetaData: Option[CSharpMethod],
arguments: Seq[Ast]
) = expression.node match {
case SimpleMemberAccessExpression =>
case SimpleMemberAccessExpression | SuppressNullableWarningExpression =>
val baseNode = createDotNetNodeInfo(
createDotNetNodeInfo(invocationExpr.json(ParserKeys.Expression)).json(ParserKeys.Expression)
)
Expand Down Expand Up @@ -257,6 +259,7 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) {
case Some(m) => s"${m.returnType}(${m.parameterTypes.filterNot(_._1 == "this").map(_._2).mkString(",")})"
case None => Defines.UnresolvedSignature
}

val methodFullName = baseTypeFullName match {
case Some(typeFullName) =>
s"$typeFullName.$callName:$methodSignature"
Expand Down Expand Up @@ -290,12 +293,42 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) {
protected def astForSimpleMemberAccessExpression(accessExpr: DotNetNodeInfo): Seq[Ast] = {
val fieldIdentifierName = nameFromNode(accessExpr)

val fieldInScope = scope.findFieldInScope(fieldIdentifierName)

val typeFullName = fieldInScope.map(_.typeFullName).getOrElse(Defines.Any)
val identifierName =
if (fieldInScope.nonEmpty && fieldInScope.get.isStatic) scope.surroundingTypeDeclFullName.getOrElse(Defines.Any)
else "this"
val (identifierName, typeFullName) = accessExpr.node match {
case SimpleMemberAccessExpression => {
createDotNetNodeInfo(accessExpr.json(ParserKeys.Expression)).node match
case SuppressNullableWarningExpression =>
val baseNode = createDotNetNodeInfo(accessExpr.json(ParserKeys.Expression)(ParserKeys.Operand))
val baseAst = astForNode(baseNode)
val baseTypeFullName = getTypeFullNameFromAstNode(baseAst)

val fieldInScope = scope.tryResolveFieldAccess(fieldIdentifierName, typeFullName = Option(baseTypeFullName))

(
nameFromNode(baseNode),
fieldInScope
.map(_.typeName)
.getOrElse(Defines.Any)
)
case _ => {
val fieldInScope = scope.findFieldInScope(fieldIdentifierName)
val _identifierName =
if (fieldInScope.nonEmpty && fieldInScope.map(_.isStatic).contains(true))
scope.surroundingTypeDeclFullName.getOrElse(Defines.Any)
else Constants.This
val _typeFullName = fieldInScope.map(_.typeFullName).getOrElse(Defines.Any)
(_identifierName, _typeFullName)
}
}
case _ => {
val fieldInScope = scope.findFieldInScope(fieldIdentifierName)
val _identifierName =
if (fieldInScope.nonEmpty && fieldInScope.map(_.isStatic).contains(true))
scope.surroundingTypeDeclFullName.getOrElse(Defines.Any)
else Constants.This
val _typeFullName = fieldInScope.map(_.typeFullName).getOrElse(Defines.Any)
(_identifierName, _typeFullName)
}
}

val identifier = newIdentifierNode(identifierName, typeFullName)

Expand Down Expand Up @@ -487,4 +520,8 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) {
}
}

private def astForSuppressNullableWarningExpression(suppressNullableExpr: DotNetNodeInfo): Seq[Ast] = {
val _identifierNode = createDotNetNodeInfo(suppressNullableExpr.json(ParserKeys.Operand))
Seq(astForIdentifier(_identifierNode))
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,8 @@ object DotNetJsonAst {

object MemberBindingExpression extends BaseExpr

object SuppressNullableWarningExpression extends BaseExpr

object Unknown extends DotNetParserNode

}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,4 +80,77 @@ class MemberAccessTests extends CSharpCode2CpgFixture {
}
}
}

"null-forgiving member access expressions" should {
val cpg = code("""
|namespace Foo {
| public class Baz {
| public int Qux {get;}
| }
| public class Bar {
| public static void Main() {
| var baz = new Baz();
| var a = baz!.Qux;
| }
| }
|}
|""".stripMargin)

"have correct types both on the LHS and RHS" in {
inside(cpg.assignment.l.sortBy(_.lineNumber).drop(1)) {
case a :: Nil =>
inside(a.argument.l) { case (lhs: Identifier) :: (rhs: Call) :: Nil =>
lhs.typeFullName shouldBe BuiltinTypes.DotNetTypeMap(BuiltinTypes.Int)
rhs.typeFullName shouldBe BuiltinTypes.DotNetTypeMap(BuiltinTypes.Int)
}
case _ => fail("Expected 1 assignment call.")
}
}
}

"null-forgiving method access expressions" should {
val cpg = code("""
|namespace Foo {
| public class Baz {
| public int Qux() {}
| public string Fred(int a) {}
| }
| public class Bar {
| public static void Main() {
| var baz = new Baz();
| var a = baz!.Qux();
| var b = baz!.Fred(1);
| }
| }
|}
|""".stripMargin)

"have correct types and attributes both on the LHS and RHS" in {
inside(cpg.assignment.l.sortBy(_.lineNumber).drop(1)) {
case a :: b :: Nil =>
inside(a.argument.l) {
case (lhs: Identifier) :: (rhs: Call) :: Nil =>
lhs.typeFullName shouldBe BuiltinTypes.DotNetTypeMap(BuiltinTypes.Int)
rhs.typeFullName shouldBe BuiltinTypes.DotNetTypeMap(BuiltinTypes.Int)
case _ => fail("Expected 2 arguments under the assignment call")
}

inside(b.argument.l) {
case (lhs: Identifier) :: (rhs: Call) :: Nil =>
lhs.typeFullName shouldBe BuiltinTypes.DotNetTypeMap(BuiltinTypes.String)
rhs.typeFullName shouldBe BuiltinTypes.DotNetTypeMap(BuiltinTypes.String)
case _ => fail("Expected 2 arguments under the assignment call")
}

inside(cpg.call.nameExact("Fred").l) {
case fred :: Nil =>
fred.typeFullName shouldBe BuiltinTypes.DotNetTypeMap(BuiltinTypes.String)
fred.methodFullName shouldBe "Foo.Baz.Fred:System.String(System.Int32)"
case _ => fail("Expected a call named `Fred`")

}
case _ => fail("Expected 2 assignment call.")
}
}
}
}

0 comments on commit bfbc9c8

Please sign in to comment.