From 5780ed7e900dfb235122d924ac0a3acc6c67e9f4 Mon Sep 17 00:00:00 2001 From: Lars Reimann Date: Tue, 6 Feb 2024 16:20:09 +0100 Subject: [PATCH] feat: apply type parameter substitutions of receiver type for member accesses (#859) Closes partially #23 ### Summary of Changes Let's look at the following example: ``` class C { attr member: T } segment mySegment(p: C) { val a = p.member; ``` Previously, the type of the placeholder `a` was computed as `T` (type parameter type). Now, the type parameter substitutions that were computed for the receiver get applied, so `a` now gets the correct type `Int`. --- .../safe-ds-lang/src/language/typing/model.ts | 10 ++++- .../language/typing/safe-ds-type-computer.ts | 11 ++++- .../tests/language/typing/model.test.ts | 5 +++ .../main.sdstest | 40 +++++++++++++++++++ .../member accesses/to other/main.sdstest | 11 +++++ 5 files changed, 75 insertions(+), 2 deletions(-) create mode 100644 packages/safe-ds-lang/tests/resources/typing/expressions/member accesses/on class with type parameters/main.sdstest diff --git a/packages/safe-ds-lang/src/language/typing/model.ts b/packages/safe-ds-lang/src/language/typing/model.ts index 7043f7b42..dd3fe8b71 100644 --- a/packages/safe-ds-lang/src/language/typing/model.ts +++ b/packages/safe-ds-lang/src/language/typing/model.ts @@ -434,7 +434,15 @@ export class TypeParameterType extends NamedType { } override substituteTypeParameters(substitutions: TypeParameterSubstitutions): Type { - return substitutions.get(this.declaration) ?? this; + const substitution = substitutions.get(this.declaration); + + if (!substitution) { + return this; + } else if (this.isNullable) { + return substitution.updateNullability(true); + } else { + return substitution; + } } override updateNullability(isNullable: boolean): TypeParameterType { diff --git a/packages/safe-ds-lang/src/language/typing/safe-ds-type-computer.ts b/packages/safe-ds-lang/src/language/typing/safe-ds-type-computer.ts index 1cb591e23..0c913269b 100644 --- a/packages/safe-ds-lang/src/language/typing/safe-ds-type-computer.ts +++ b/packages/safe-ds-lang/src/language/typing/safe-ds-type-computer.ts @@ -507,7 +507,16 @@ export class SafeDsTypeComputer { } const receiverType = this.computeType(node.receiver); - return memberType.updateNullability((receiverType.isNullable && node.isNullSafe) || memberType.isNullable); + const unsubstitutedResult = memberType.updateNullability( + (receiverType.isNullable && node.isNullSafe) || memberType.isNullable, + ); + + // Substitute type parameters + if (receiverType instanceof ClassType) { + return unsubstitutedResult.substituteTypeParameters(receiverType.substitutions); + } else { + return unsubstitutedResult; + } } private computeTypeOfArithmeticPrefixOperation(node: SdsPrefixOperation): Type { diff --git a/packages/safe-ds-lang/tests/language/typing/model.test.ts b/packages/safe-ds-lang/tests/language/typing/model.test.ts index d72158885..e1dc63efc 100644 --- a/packages/safe-ds-lang/tests/language/typing/model.test.ts +++ b/packages/safe-ds-lang/tests/language/typing/model.test.ts @@ -291,6 +291,11 @@ describe('type model', async () => { substitutions: substitutions1, expectedType: new LiteralType(new IntConstant(1n)), }, + { + type: new TypeParameterType(typeParameter1, true), + substitutions: substitutions1, + expectedType: new LiteralType(new IntConstant(1n), NullConstant), + }, { type: new TypeParameterType(typeParameter2, false), substitutions: substitutions1, diff --git a/packages/safe-ds-lang/tests/resources/typing/expressions/member accesses/on class with type parameters/main.sdstest b/packages/safe-ds-lang/tests/resources/typing/expressions/member accesses/on class with type parameters/main.sdstest new file mode 100644 index 000000000..f4e36a6e5 --- /dev/null +++ b/packages/safe-ds-lang/tests/resources/typing/expressions/member accesses/on class with type parameters/main.sdstest @@ -0,0 +1,40 @@ +package tests.typing.expressions.memberAccesses.onClassWithTypeParameters + +class C { + attr nonNullableMember: T + attr nullableMember: T? + @Pure fun method() -> r: T +} + +@Pure fun nullableC() -> result: C? + +segment mySegment(p: C) { + // $TEST$ serialization Int + »p.nonNullableMember«; + // $TEST$ serialization Int? + »p.nullableMember«; + // $TEST$ serialization () -> (r: Int) + »p.method«; + + // $TEST$ serialization Int + »p?.nonNullableMember«; + // $TEST$ serialization Int? + »p?.nullableMember«; + // $TEST$ serialization () -> (r: Int) + »p?.method«; + + + // $TEST$ serialization Int + »nullableC().nonNullableMember«; + // $TEST$ serialization Int? + »nullableC().nullableMember«; + // $TEST$ serialization () -> (r: Int) + »nullableC().method«; + + // $TEST$ serialization Int? + »nullableC()?.nonNullableMember«; + // $TEST$ serialization Int? + »nullableC()?.nullableMember«; + // $TEST$ serialization union<() -> (r: Int), literal> + »nullableC()?.method«; +} diff --git a/packages/safe-ds-lang/tests/resources/typing/expressions/member accesses/to other/main.sdstest b/packages/safe-ds-lang/tests/resources/typing/expressions/member accesses/to other/main.sdstest index 0da3c4671..d855a00cc 100644 --- a/packages/safe-ds-lang/tests/resources/typing/expressions/member accesses/to other/main.sdstest +++ b/packages/safe-ds-lang/tests/resources/typing/expressions/member accesses/to other/main.sdstest @@ -6,6 +6,9 @@ class C() { // $TEST$ equivalence_class nullableMember attr »nullableMember«: Any? + + // $TEST$ equivalence_class method + @Pure fun »method«() -> r: Int } fun nullableC() -> result: C? @@ -15,20 +18,28 @@ pipeline myPipeline { »C().nonNullableMember«; // $TEST$ equivalence_class nullableMember »C().nullableMember«; + // $TEST$ equivalence_class method + »C().method«; // $TEST$ equivalence_class nonNullableMember »C()?.nonNullableMember«; // $TEST$ equivalence_class nullableMember »C()?.nullableMember«; + // $TEST$ equivalence_class method + »C()?.method«; // $TEST$ equivalence_class nonNullableMember »nullableC().nonNullableMember«; // $TEST$ equivalence_class nullableMember »nullableC().nullableMember«; + // $TEST$ equivalence_class method + »nullableC().method«; // $TEST$ serialization Int? »nullableC()?.nonNullableMember«; // $TEST$ equivalence_class nullableMember »nullableC()?.nullableMember«; + // $TEST$ serialization union<() -> (r: Int), literal> + »nullableC()?.method«; }