Skip to content

Commit

Permalink
fix: type inference for lambdas and their parameters
Browse files Browse the repository at this point in the history
  • Loading branch information
lars-reimann committed Jan 4, 2025
1 parent 9c4e70c commit 8b885e1
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 58 deletions.
91 changes: 36 additions & 55 deletions packages/safe-ds-lang/src/language/typing/safe-ds-type-computer.ts
Original file line number Diff line number Diff line change
Expand Up @@ -125,11 +125,11 @@ export class SafeDsTypeComputer {
private readonly typeChecker: SafeDsTypeChecker;

/**
* Contains all lambda parameters that are currently being computed. When computing the types of lambda parameters,
* they must only access the type of the containing lambda, if they are not contained in this set themselves.
* Otherwise, this would cause endless recursion.
* Contains all calls for which we currently compute substitutions. This prevents endless recursion, since the
* substitutions of a call depend on the inferred types of their arguments, which may be lambdas. The inferred type
* of a lambda in turn depends on the substitutions of the call it is passed to.
*/
private readonly incompleteLambdaParameters = new Set<SdsParameter>();
private readonly incompleteCalls = new Set<SdsAbstractCall>();
private readonly nodeTypeCache: WorkspaceCache<string, Type>;

constructor(services: SafeDsServices) {
Expand Down Expand Up @@ -301,18 +301,21 @@ export class SafeDsTypeComputer {

// Lambda passed as argument
if (isSdsArgument(containerOfLambda)) {
// Lookup parameter type in lambda unless the lambda is being computed. These contain the correct
// substitutions for type parameters.
if (!this.incompleteLambdaParameters.has(node)) {
return this.computeType(containingCallable);
}

const parameter = this.nodeMapper.argumentToParameter(containerOfLambda);
if (!parameter) {
return UnknownType;
}

return this.computeType(parameter);
const unsubstitutedType = this.computeType(parameter);

// Substitute type parameters
const call = AstUtils.getContainerOfType(containerOfLambda, isSdsCall);
if (call) {
const substitutions = this.computeSubstitutionsForCall(call, containerOfLambda.$containerIndex);
return unsubstitutedType.substituteTypeParameters(substitutions);
}

return unsubstitutedType;

Check warning on line 318 in packages/safe-ds-lang/src/language/typing/safe-ds-type-computer.ts

View check run for this annotation

Codecov / codecov/patch

packages/safe-ds-lang/src/language/typing/safe-ds-type-computer.ts#L318

Added line #L318 was not covered by tests
}

// Lambda passed as default value
Expand Down Expand Up @@ -569,29 +572,16 @@ export class SafeDsTypeComputer {
}

private computeTypeOfLambda(node: SdsLambda): Type {
// Remember lambda parameters
const parameters = getParameters(node);
parameters.forEach((it) => {
this.incompleteLambdaParameters.add(it);
});

const parameterEntries = parameters.map((it) => new NamedTupleEntry(it, it.name, this.computeType(it)));
const resultEntries = this.buildLambdaResultEntries(node);

const unsubstitutedType = this.factory.createCallableType(
return this.factory.createCallableType(
node,
undefined,
this.factory.createNamedTupleType(...parameterEntries),
this.factory.createNamedTupleType(...resultEntries),
);
const substitutions = this.computeSubstitutionsForLambda(node, unsubstitutedType);

// Forget lambda parameters
parameters.forEach((it) => {
this.incompleteLambdaParameters.delete(it);
});

return unsubstitutedType.substituteTypeParameters(substitutions);
}

private buildLambdaResultEntries(node: SdsLambda): NamedTupleEntry<SdsAbstractResult>[] {
Expand Down Expand Up @@ -843,23 +833,32 @@ export class SafeDsTypeComputer {
/**
* Computes substitutions for the type parameters of a callable in the context of a call.
*
* @param node The call to compute substitutions for.
* @param node
* The call to compute substitutions for.
* @param argumentEndIndex
* The index of the first argument that should not be considered for the computation. If not specified, all
* arguments are considered.
*
* @returns The computed substitutions for the type parameters of the callable.
*/
computeSubstitutionsForCall(node: SdsAbstractCall): TypeParameterSubstitutions {
return this.doComputeSubstitutionsForCall(node);
}

private doComputeSubstitutionsForCall(
computeSubstitutionsForCall(
node: SdsAbstractCall,
precomputedArgumentTypes?: Map<AstNode | undefined, Type>,
argumentEndIndex: number | undefined = undefined,
): TypeParameterSubstitutions {
// Compute substitutions for member access
const substitutionsFromReceiver =
isSdsCall(node) && isSdsMemberAccess(node.receiver)
? this.computeSubstitutionsForMemberAccess(node.receiver)
: NO_SUBSTITUTIONS;

// Check if the call is already being computed
if (this.incompleteCalls.has(node)) {
return substitutionsFromReceiver;
}

// Remember call
this.incompleteCalls.add(node);

// Compute substitutions for arguments
const callable = this.nodeMapper.callToCallable(node);
const typeParameters = getTypeParameters(callable);
Expand All @@ -868,24 +867,22 @@ export class SafeDsTypeComputer {
}

const parameters = getParameters(callable);
const args = getArguments(node);
const args = getArguments(node).slice(0, argumentEndIndex);

const parametersToArguments = this.nodeMapper.parametersToArguments(parameters, args);
const parameterTypesToArgumentTypes: [Type, Type][] = parameters.map((parameter) => {
const argument = parametersToArguments.get(parameter);
return [
this.computeType(parameter.type),
// Use precomputed argument types (lambdas) if available. This prevents infinite recursion.
precomputedArgumentTypes?.get(argument?.value) ??
this.computeType(argument?.value ?? parameter.defaultValue),
];
return [this.computeType(parameter.type), this.computeType(argument?.value ?? parameter.defaultValue)];
});

const substitutionsFromArguments = this.computeSubstitutionsForArguments(
typeParameters,
parameterTypesToArgumentTypes,
);

// Forget call
this.incompleteCalls.delete(node);

return new Map([...substitutionsFromReceiver, ...substitutionsFromArguments]);
}

Expand Down Expand Up @@ -918,22 +915,6 @@ export class SafeDsTypeComputer {
return this.computeSubstitutionsForArguments(ownTypeParameters, ownTypesToOverriddenTypes);
}

private computeSubstitutionsForLambda(node: SdsLambda, unsubstitutedType: Type): TypeParameterSubstitutions {
const containerOfLambda = node.$container;
if (!isSdsArgument(containerOfLambda)) {
return NO_SUBSTITUTIONS;
}

const containingCall = AstUtils.getContainerOfType(containerOfLambda, isSdsCall);
if (!containingCall) {
/* c8 ignore next 2 */
return NO_SUBSTITUTIONS;
}

const precomputedArgumentTypes = new Map([[node, unsubstitutedType]]);
return this.doComputeSubstitutionsForCall(containingCall, precomputedArgumentTypes);
}

private computeSubstitutionsForMemberAccess(node: SdsMemberAccess): TypeParameterSubstitutions {
const receiverType = this.computeType(node.receiver);
if (receiverType instanceof ClassType) {
Expand Down
2 changes: 1 addition & 1 deletion packages/safe-ds-lang/src/language/validation/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ export const argumentTypesMustMatchParameterTypes = (services: SafeDsServices) =
return;
}

const argumentType = typeComputer.computeType(argument).substituteTypeParameters(substitutions);
const argumentType = typeComputer.computeType(argument);
const parameterType = typeComputer.computeType(parameter).substituteTypeParameters(substitutions);

if (!typeChecker.isSubtypeOf(argumentType, parameterType, { ignoreParameterNames: true })) {
Expand Down
4 changes: 2 additions & 2 deletions packages/safe-ds-lang/tests/helpers/nodeFinder.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -65,14 +65,14 @@ describe('getNodeOfType', async () => {

it('should throw if no node is found', async () => {
const code = '';
expect(async () => {
await expect(async () => {
await getNodeOfType(services, code, isSdsClass);
}).rejects.toThrowErrorMatchingSnapshot();
});

it('should throw if not enough nodes are found', async () => {
const code = `class C`;
expect(async () => {
await expect(async () => {
await getNodeOfType(services, code, isSdsClass, 1);
}).rejects.toThrowErrorMatchingSnapshot();
});
Expand Down

0 comments on commit 8b885e1

Please sign in to comment.