Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: type inference for lambdas and their parameters #1304

Merged
merged 3 commits into from
Jan 4, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -125,11 +125,11 @@ export class SafeDsTypeComputer {
private readonly typeChecker: SafeDsTypeChecker;

/**
* Contains all lambda parameters that are currently being computed. When computing the types of lambda parameters,
* they must only access the type of the containing lambda, if they are not contained in this set themselves.
* Otherwise, this would cause endless recursion.
* Contains all calls for which we currently compute substitutions. This prevents endless recursion, since the
* substitutions of a call depend on the inferred types of their arguments, which may be lambdas. The inferred type
* of a lambda in turn depends on the substitutions of the call it is passed to.
*/
private readonly incompleteLambdaParameters = new Set<SdsParameter>();
private readonly incompleteCalls = new Set<SdsAbstractCall>();
private readonly nodeTypeCache: WorkspaceCache<string, Type>;

constructor(services: SafeDsServices) {
Expand Down Expand Up @@ -301,18 +301,21 @@ export class SafeDsTypeComputer {

// Lambda passed as argument
if (isSdsArgument(containerOfLambda)) {
// Lookup parameter type in lambda unless the lambda is being computed. These contain the correct
// substitutions for type parameters.
if (!this.incompleteLambdaParameters.has(node)) {
return this.computeType(containingCallable);
}

const parameter = this.nodeMapper.argumentToParameter(containerOfLambda);
if (!parameter) {
return UnknownType;
}

return this.computeType(parameter);
let result = this.computeType(parameter);

// Substitute type parameters
const call = AstUtils.getContainerOfType(containerOfLambda, isSdsCall);
if (call) {
const substitutions = this.computeSubstitutionsForCall(call, containerOfLambda.$containerIndex);
result = result.substituteTypeParameters(substitutions);
}

return result;
}

// Lambda passed as default value
Expand Down Expand Up @@ -569,29 +572,16 @@ export class SafeDsTypeComputer {
}

private computeTypeOfLambda(node: SdsLambda): Type {
// Remember lambda parameters
const parameters = getParameters(node);
parameters.forEach((it) => {
this.incompleteLambdaParameters.add(it);
});

const parameterEntries = parameters.map((it) => new NamedTupleEntry(it, it.name, this.computeType(it)));
const resultEntries = this.buildLambdaResultEntries(node);

const unsubstitutedType = this.factory.createCallableType(
return this.factory.createCallableType(
node,
undefined,
this.factory.createNamedTupleType(...parameterEntries),
this.factory.createNamedTupleType(...resultEntries),
);
const substitutions = this.computeSubstitutionsForLambda(node, unsubstitutedType);

// Forget lambda parameters
parameters.forEach((it) => {
this.incompleteLambdaParameters.delete(it);
});

return unsubstitutedType.substituteTypeParameters(substitutions);
}

private buildLambdaResultEntries(node: SdsLambda): NamedTupleEntry<SdsAbstractResult>[] {
Expand Down Expand Up @@ -843,23 +833,32 @@ export class SafeDsTypeComputer {
/**
* Computes substitutions for the type parameters of a callable in the context of a call.
*
* @param node The call to compute substitutions for.
* @param node
* The call to compute substitutions for.
* @param argumentEndIndex
* The index of the first argument that should not be considered for the computation. If not specified, all
* arguments are considered.
*
* @returns The computed substitutions for the type parameters of the callable.
*/
computeSubstitutionsForCall(node: SdsAbstractCall): TypeParameterSubstitutions {
return this.doComputeSubstitutionsForCall(node);
}

private doComputeSubstitutionsForCall(
computeSubstitutionsForCall(
node: SdsAbstractCall,
precomputedArgumentTypes?: Map<AstNode | undefined, Type>,
argumentEndIndex: number | undefined = undefined,
): TypeParameterSubstitutions {
// Compute substitutions for member access
const substitutionsFromReceiver =
isSdsCall(node) && isSdsMemberAccess(node.receiver)
? this.computeSubstitutionsForMemberAccess(node.receiver)
: NO_SUBSTITUTIONS;

// Check if the call is already being computed
if (this.incompleteCalls.has(node)) {
return substitutionsFromReceiver;
}

// Remember call
this.incompleteCalls.add(node);

// Compute substitutions for arguments
const callable = this.nodeMapper.callToCallable(node);
const typeParameters = getTypeParameters(callable);
Expand All @@ -868,24 +867,22 @@ export class SafeDsTypeComputer {
}

const parameters = getParameters(callable);
const args = getArguments(node);
const args = getArguments(node).slice(0, argumentEndIndex);

const parametersToArguments = this.nodeMapper.parametersToArguments(parameters, args);
const parameterTypesToArgumentTypes: [Type, Type][] = parameters.map((parameter) => {
const argument = parametersToArguments.get(parameter);
return [
this.computeType(parameter.type),
// Use precomputed argument types (lambdas) if available. This prevents infinite recursion.
precomputedArgumentTypes?.get(argument?.value) ??
this.computeType(argument?.value ?? parameter.defaultValue),
];
return [this.computeType(parameter.type), this.computeType(argument?.value ?? parameter.defaultValue)];
});

const substitutionsFromArguments = this.computeSubstitutionsForArguments(
typeParameters,
parameterTypesToArgumentTypes,
);

// Forget call
this.incompleteCalls.delete(node);

return new Map([...substitutionsFromReceiver, ...substitutionsFromArguments]);
}

Expand Down Expand Up @@ -918,22 +915,6 @@ export class SafeDsTypeComputer {
return this.computeSubstitutionsForArguments(ownTypeParameters, ownTypesToOverriddenTypes);
}

private computeSubstitutionsForLambda(node: SdsLambda, unsubstitutedType: Type): TypeParameterSubstitutions {
const containerOfLambda = node.$container;
if (!isSdsArgument(containerOfLambda)) {
return NO_SUBSTITUTIONS;
}

const containingCall = AstUtils.getContainerOfType(containerOfLambda, isSdsCall);
if (!containingCall) {
/* c8 ignore next 2 */
return NO_SUBSTITUTIONS;
}

const precomputedArgumentTypes = new Map([[node, unsubstitutedType]]);
return this.doComputeSubstitutionsForCall(containingCall, precomputedArgumentTypes);
}

private computeSubstitutionsForMemberAccess(node: SdsMemberAccess): TypeParameterSubstitutions {
const receiverType = this.computeType(node.receiver);
if (receiverType instanceof ClassType) {
Expand Down
2 changes: 1 addition & 1 deletion packages/safe-ds-lang/src/language/validation/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ export const argumentTypesMustMatchParameterTypes = (services: SafeDsServices) =
return;
}

const argumentType = typeComputer.computeType(argument).substituteTypeParameters(substitutions);
const argumentType = typeComputer.computeType(argument);
const parameterType = typeComputer.computeType(parameter).substituteTypeParameters(substitutions);

if (!typeChecker.isSubtypeOf(argumentType, parameterType, { ignoreParameterNames: true })) {
Expand Down
4 changes: 2 additions & 2 deletions packages/safe-ds-lang/tests/helpers/nodeFinder.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -65,14 +65,14 @@ describe('getNodeOfType', async () => {

it('should throw if no node is found', async () => {
const code = '';
expect(async () => {
await expect(async () => {
await getNodeOfType(services, code, isSdsClass);
}).rejects.toThrowErrorMatchingSnapshot();
});

it('should throw if not enough nodes are found', async () => {
const code = `class C`;
expect(async () => {
await expect(async () => {
await getNodeOfType(services, code, isSdsClass, 1);
}).rejects.toThrowErrorMatchingSnapshot();
});
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,6 @@ segment mySegment() {
// $TEST$ serialization literal<1>
myFunction(1, (»p«) {});

// $TEST$ serialization literal<"">
// $TEST$ serialization Nothing
myFunction2((»p«) -> "");
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,6 @@ segment mySegment() {
// $TEST$ serialization literal<1>
myFunction(1, (»p«) -> "");

// $TEST$ serialization literal<"">
// $TEST$ serialization Nothing
myFunction2((»p«) -> "");
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@ class MyClass<T>(param: T) sub MySuperclass<T> {
@Pure fun myMethod(callback: (p: T) -> ())
}

@Pure fun myFunction<T>(p: T, id: (p: T) -> (r: T))
@Pure fun myFunction1<T>(p: T, id: (p: T) -> (r: T))
@Pure fun myFunction2<T>(id: (p: T) -> (r: T))
@Pure fun myFunction3<T>(producer: () -> (r: T), consumer: (p: T) -> ())

segment mySegment() {
// $TEST$ serialization (p: literal<1>) -> (r: literal<1>)
Expand All @@ -22,7 +24,19 @@ segment mySegment() {
}«);

// $TEST$ serialization (p: literal<1>) -> (r: literal<1>)
myFunction(1, »(p) {
myFunction1(1, »(p) {
yield r = p;
}«);

// $TEST$ serialization (p: Nothing) -> (r: literal<1>)
myFunction2(»(p) {
yield r = 1;
}«);

// $TEST$ serialization () -> (r: literal<1>)
// $TEST$ serialization (p: literal<1>) -> ()
myFunction3(
»() { yield r = 1; }«,
»(p) {}«,
);
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,16 @@ class MyClass<T>(param: T) sub MySuperclass<T> {
@Pure fun myMethod(callback: (p: T) -> ())
}

@Pure fun myFunction<T>(p: T, id: (p: T) -> (r: T))

segment mySegment() {
// $TEST$ serialization (p: literal<1>) -> (result: literal<1>)
MyClass(1).myMethod(»(p) -> p«);

// $TEST$ serialization (p: literal<1>) -> (result: literal<1>)
MyClass(1).myInheritedMethod(»(p) -> p«);
@Pure fun myFunction1<T>(p: T, id: (p: T) -> (r: T))
@Pure fun myFunction2<T>(id: (p: T) -> (r: T))
@Pure fun myFunction3<T>(producer: () -> (r: T), consumer: (p: T) -> ())

segment mySegment() {
// $TEST$ serialization () -> (result: literal<1>)
// $TEST$ serialization (p: literal<1>) -> (result: literal<1>)
myFunction(1, »(p) -> p«);
myFunction3(
»() -> 1«,
»(p) -> 1«,
);
}
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ segment mySegment(
// $TEST$ no error r"Expected type .* but got .*\."
f(»(p) -> p«);

// $TEST$ no error r"Expected type .* but got .*\."
// $TEST$ error "Expected type '(p: literal<1>) -> (r: literal<1>)' but got '(p: Nothing) -> (result: literal<1>)'."
f(»(p) -> 1«);

// $TEST$ no error r"Expected type .* but got .*\."
Expand Down
Loading