From 8b885e1c92798dd74fafab0f0b4a6f950e96af24 Mon Sep 17 00:00:00 2001 From: Lars Reimann Date: Sat, 4 Jan 2025 17:10:19 +0100 Subject: [PATCH] fix: type inference for lambdas and their parameters --- .../language/typing/safe-ds-type-computer.ts | 91 ++++++++----------- .../src/language/validation/types.ts | 2 +- .../tests/helpers/nodeFinder.test.ts | 4 +- 3 files changed, 39 insertions(+), 58 deletions(-) diff --git a/packages/safe-ds-lang/src/language/typing/safe-ds-type-computer.ts b/packages/safe-ds-lang/src/language/typing/safe-ds-type-computer.ts index 03d4bb312..1ac4b3088 100644 --- a/packages/safe-ds-lang/src/language/typing/safe-ds-type-computer.ts +++ b/packages/safe-ds-lang/src/language/typing/safe-ds-type-computer.ts @@ -125,11 +125,11 @@ export class SafeDsTypeComputer { private readonly typeChecker: SafeDsTypeChecker; /** - * Contains all lambda parameters that are currently being computed. When computing the types of lambda parameters, - * they must only access the type of the containing lambda, if they are not contained in this set themselves. - * Otherwise, this would cause endless recursion. + * Contains all calls for which we currently compute substitutions. This prevents endless recursion, since the + * substitutions of a call depend on the inferred types of their arguments, which may be lambdas. The inferred type + * of a lambda in turn depends on the substitutions of the call it is passed to. */ - private readonly incompleteLambdaParameters = new Set(); + private readonly incompleteCalls = new Set(); private readonly nodeTypeCache: WorkspaceCache; constructor(services: SafeDsServices) { @@ -301,18 +301,21 @@ export class SafeDsTypeComputer { // Lambda passed as argument if (isSdsArgument(containerOfLambda)) { - // Lookup parameter type in lambda unless the lambda is being computed. These contain the correct - // substitutions for type parameters. - if (!this.incompleteLambdaParameters.has(node)) { - return this.computeType(containingCallable); - } - const parameter = this.nodeMapper.argumentToParameter(containerOfLambda); if (!parameter) { return UnknownType; } - return this.computeType(parameter); + const unsubstitutedType = this.computeType(parameter); + + // Substitute type parameters + const call = AstUtils.getContainerOfType(containerOfLambda, isSdsCall); + if (call) { + const substitutions = this.computeSubstitutionsForCall(call, containerOfLambda.$containerIndex); + return unsubstitutedType.substituteTypeParameters(substitutions); + } + + return unsubstitutedType; } // Lambda passed as default value @@ -569,29 +572,16 @@ export class SafeDsTypeComputer { } private computeTypeOfLambda(node: SdsLambda): Type { - // Remember lambda parameters const parameters = getParameters(node); - parameters.forEach((it) => { - this.incompleteLambdaParameters.add(it); - }); - const parameterEntries = parameters.map((it) => new NamedTupleEntry(it, it.name, this.computeType(it))); const resultEntries = this.buildLambdaResultEntries(node); - const unsubstitutedType = this.factory.createCallableType( + return this.factory.createCallableType( node, undefined, this.factory.createNamedTupleType(...parameterEntries), this.factory.createNamedTupleType(...resultEntries), ); - const substitutions = this.computeSubstitutionsForLambda(node, unsubstitutedType); - - // Forget lambda parameters - parameters.forEach((it) => { - this.incompleteLambdaParameters.delete(it); - }); - - return unsubstitutedType.substituteTypeParameters(substitutions); } private buildLambdaResultEntries(node: SdsLambda): NamedTupleEntry[] { @@ -843,16 +833,17 @@ export class SafeDsTypeComputer { /** * Computes substitutions for the type parameters of a callable in the context of a call. * - * @param node The call to compute substitutions for. + * @param node + * The call to compute substitutions for. + * @param argumentEndIndex + * The index of the first argument that should not be considered for the computation. If not specified, all + * arguments are considered. + * * @returns The computed substitutions for the type parameters of the callable. */ - computeSubstitutionsForCall(node: SdsAbstractCall): TypeParameterSubstitutions { - return this.doComputeSubstitutionsForCall(node); - } - - private doComputeSubstitutionsForCall( + computeSubstitutionsForCall( node: SdsAbstractCall, - precomputedArgumentTypes?: Map, + argumentEndIndex: number | undefined = undefined, ): TypeParameterSubstitutions { // Compute substitutions for member access const substitutionsFromReceiver = @@ -860,6 +851,14 @@ export class SafeDsTypeComputer { ? this.computeSubstitutionsForMemberAccess(node.receiver) : NO_SUBSTITUTIONS; + // Check if the call is already being computed + if (this.incompleteCalls.has(node)) { + return substitutionsFromReceiver; + } + + // Remember call + this.incompleteCalls.add(node); + // Compute substitutions for arguments const callable = this.nodeMapper.callToCallable(node); const typeParameters = getTypeParameters(callable); @@ -868,17 +867,12 @@ export class SafeDsTypeComputer { } const parameters = getParameters(callable); - const args = getArguments(node); + const args = getArguments(node).slice(0, argumentEndIndex); const parametersToArguments = this.nodeMapper.parametersToArguments(parameters, args); const parameterTypesToArgumentTypes: [Type, Type][] = parameters.map((parameter) => { const argument = parametersToArguments.get(parameter); - return [ - this.computeType(parameter.type), - // Use precomputed argument types (lambdas) if available. This prevents infinite recursion. - precomputedArgumentTypes?.get(argument?.value) ?? - this.computeType(argument?.value ?? parameter.defaultValue), - ]; + return [this.computeType(parameter.type), this.computeType(argument?.value ?? parameter.defaultValue)]; }); const substitutionsFromArguments = this.computeSubstitutionsForArguments( @@ -886,6 +880,9 @@ export class SafeDsTypeComputer { parameterTypesToArgumentTypes, ); + // Forget call + this.incompleteCalls.delete(node); + return new Map([...substitutionsFromReceiver, ...substitutionsFromArguments]); } @@ -918,22 +915,6 @@ export class SafeDsTypeComputer { return this.computeSubstitutionsForArguments(ownTypeParameters, ownTypesToOverriddenTypes); } - private computeSubstitutionsForLambda(node: SdsLambda, unsubstitutedType: Type): TypeParameterSubstitutions { - const containerOfLambda = node.$container; - if (!isSdsArgument(containerOfLambda)) { - return NO_SUBSTITUTIONS; - } - - const containingCall = AstUtils.getContainerOfType(containerOfLambda, isSdsCall); - if (!containingCall) { - /* c8 ignore next 2 */ - return NO_SUBSTITUTIONS; - } - - const precomputedArgumentTypes = new Map([[node, unsubstitutedType]]); - return this.doComputeSubstitutionsForCall(containingCall, precomputedArgumentTypes); - } - private computeSubstitutionsForMemberAccess(node: SdsMemberAccess): TypeParameterSubstitutions { const receiverType = this.computeType(node.receiver); if (receiverType instanceof ClassType) { diff --git a/packages/safe-ds-lang/src/language/validation/types.ts b/packages/safe-ds-lang/src/language/validation/types.ts index caf9ee2a0..8aa31ed21 100644 --- a/packages/safe-ds-lang/src/language/validation/types.ts +++ b/packages/safe-ds-lang/src/language/validation/types.ts @@ -51,7 +51,7 @@ export const argumentTypesMustMatchParameterTypes = (services: SafeDsServices) = return; } - const argumentType = typeComputer.computeType(argument).substituteTypeParameters(substitutions); + const argumentType = typeComputer.computeType(argument); const parameterType = typeComputer.computeType(parameter).substituteTypeParameters(substitutions); if (!typeChecker.isSubtypeOf(argumentType, parameterType, { ignoreParameterNames: true })) { diff --git a/packages/safe-ds-lang/tests/helpers/nodeFinder.test.ts b/packages/safe-ds-lang/tests/helpers/nodeFinder.test.ts index 6b646a6ea..6208ae3b1 100644 --- a/packages/safe-ds-lang/tests/helpers/nodeFinder.test.ts +++ b/packages/safe-ds-lang/tests/helpers/nodeFinder.test.ts @@ -65,14 +65,14 @@ describe('getNodeOfType', async () => { it('should throw if no node is found', async () => { const code = ''; - expect(async () => { + await expect(async () => { await getNodeOfType(services, code, isSdsClass); }).rejects.toThrowErrorMatchingSnapshot(); }); it('should throw if not enough nodes are found', async () => { const code = `class C`; - expect(async () => { + await expect(async () => { await getNodeOfType(services, code, isSdsClass, 1); }).rejects.toThrowErrorMatchingSnapshot(); });