diff --git a/packages/safe-ds-lang/src/language/builtins/safe-ds-annotations.ts b/packages/safe-ds-lang/src/language/builtins/safe-ds-annotations.ts index 195b8fd62..6bc20b10f 100644 --- a/packages/safe-ds-lang/src/language/builtins/safe-ds-annotations.ts +++ b/packages/safe-ds-lang/src/language/builtins/safe-ds-annotations.ts @@ -10,13 +10,7 @@ import { SdsModule, SdsParameter, } from '../generated/ast.js'; -import { - findFirstAnnotationCallOf, - getArguments, - getEnumVariants, - getParameters, - hasAnnotationCallOf, -} from '../helpers/nodeProperties.js'; +import { findFirstAnnotationCallOf, getEnumVariants, hasAnnotationCallOf } from '../helpers/nodeProperties.js'; import { SafeDsNodeMapper } from '../helpers/safe-ds-node-mapper.js'; import { EvaluatedEnumVariant, @@ -79,7 +73,7 @@ export class SafeDsAnnotations extends SafeDsModuleMembers { streamImpurityReasons(node: SdsFunction | undefined): Stream { // If allReasons are specified, but we could not evaluate them to a list, no reasons apply - const value = this.getArgumentValue(node, this.Impure, 'allReasons'); + const value = this.getParameterValue(node, this.Impure, 'allReasons'); if (!(value instanceof EvaluatedList)) { return EMPTY_STREAM; } @@ -107,7 +101,7 @@ export class SafeDsAnnotations extends SafeDsModuleMembers { } getPythonCall(node: SdsFunction | undefined): string | undefined { - const value = this.getArgumentValue(node, this.PythonCall, 'callSpecification'); + const value = this.getParameterValue(node, this.PythonCall, 'callSpecification'); if (value instanceof StringConstant) { return value.value; } else { @@ -120,7 +114,7 @@ export class SafeDsAnnotations extends SafeDsModuleMembers { } getPythonModule(node: SdsModule | undefined): string | undefined { - const value = this.getArgumentValue(node, this.PythonModule, 'qualifiedName'); + const value = this.getParameterValue(node, this.PythonModule, 'qualifiedName'); if (value instanceof StringConstant) { return value.value; } else { @@ -133,7 +127,7 @@ export class SafeDsAnnotations extends SafeDsModuleMembers { } getPythonName(node: SdsAnnotatedObject | undefined): string | undefined { - const value = this.getArgumentValue(node, this.PythonName, 'name'); + const value = this.getParameterValue(node, this.PythonName, 'name'); if (value instanceof StringConstant) { return value.value; } else { @@ -160,7 +154,7 @@ export class SafeDsAnnotations extends SafeDsModuleMembers { } // If targets are specified, but we could not evaluate them to a list, no target is valid - const value = this.getArgumentValue(node, this.Target, 'targets'); + const value = this.getParameterValue(node, this.Target, 'targets'); if (!(value instanceof EvaluatedList)) { return EMPTY_STREAM; } @@ -187,7 +181,7 @@ export class SafeDsAnnotations extends SafeDsModuleMembers { * Finds the first call of the given annotation on the given node and returns the value that is assigned to the * parameter with the given name. */ - private getArgumentValue( + private getParameterValue( node: SdsAnnotatedObject | undefined, annotation: SdsAnnotation | undefined, parameterName: string, @@ -197,16 +191,7 @@ export class SafeDsAnnotations extends SafeDsModuleMembers { return UnknownEvaluatedNode; } - // Parameter is set explicitly - const argument = getArguments(annotationCall).find( - (it) => this.nodeMapper.argumentToParameter(it)?.name === parameterName, - ); - if (argument) { - return this.partialEvaluator.evaluate(argument.value); - } - - // Parameter is not set explicitly, so we use the default value - const parameter = getParameters(annotation).find((it) => it.name === parameterName); - return this.partialEvaluator.evaluate(parameter?.defaultValue); + const parameterValue = this.nodeMapper.callToParameterValue(annotationCall, parameterName); + return this.partialEvaluator.evaluate(parameterValue); } } diff --git a/packages/safe-ds-lang/src/language/helpers/safe-ds-node-mapper.ts b/packages/safe-ds-lang/src/language/helpers/safe-ds-node-mapper.ts index f69e11cd7..c65e72e8a 100644 --- a/packages/safe-ds-lang/src/language/helpers/safe-ds-node-mapper.ts +++ b/packages/safe-ds-lang/src/language/helpers/safe-ds-node-mapper.ts @@ -9,6 +9,7 @@ import { isSdsClass, isSdsEnumVariant, isSdsNamedType, + isSdsParameter, isSdsReference, isSdsSegment, isSdsType, @@ -48,7 +49,7 @@ export class SafeDsNodeMapper { } /** - * Returns the parameter that the argument is assigned to. If there is no matching parameter, returns undefined. + * Returns the parameter that the argument is assigned to. If there is no matching parameter, returns `undefined`. */ argumentToParameter(node: SdsArgument | undefined): SdsParameter | undefined { if (!node) { @@ -126,7 +127,7 @@ export class SafeDsNodeMapper { } /** - * Returns the callable that is called by the given call. If no callable can be found, returns undefined. + * Returns the callable that is called by the given call. If no callable can be found, returns `undefined`. */ callToCallable(node: SdsAbstractCall | undefined): SdsCallable | undefined { if (!node) { @@ -150,6 +151,46 @@ export class SafeDsNodeMapper { return undefined; } + /** + * Returns the value that is assigned to the given parameter in the given call. This can be either the argument + * value, or the parameter's default value if no argument is provided. If no value can be found, returns + * `undefined`. + * + * @param call The call whose parameter value to return. + * @param parameter The parameter whose value to return. Can be either a parameter itself or its name. + */ + callToParameterValue( + call: SdsAbstractCall | undefined, + parameter: SdsParameter | string | undefined, + ): SdsExpression | undefined { + if (!call || !parameter) { + return undefined; + } + + // Parameter is set explicitly + const argument = getArguments(call).find((it) => { + if (isSdsParameter(parameter)) { + return this.argumentToParameter(it) === parameter; + } else { + return this.argumentToParameter(it)?.name === parameter; + } + }); + if (argument) { + return argument.value; + } + + // Parameter is not set but might have a default value + // We must ensure the parameter belongs to the called callable, so we cannot directly get the defaultValue + const callable = this.callToCallable(call); + return getParameters(callable).find((it) => { + if (isSdsParameter(parameter)) { + return it === parameter; + } else { + return it.name === parameter; + } + })?.defaultValue; + } + /** * Returns all references that target the given parameter. */ @@ -210,7 +251,7 @@ export class SafeDsNodeMapper { /** * Returns the type parameter that the type argument is assigned to. If there is no matching type parameter, returns - * undefined. + * `undefined`. */ typeArgumentToTypeParameter(node: SdsTypeArgument | undefined): SdsTypeParameter | undefined { if (!node) { diff --git a/packages/safe-ds-lang/tests/language/helpers/safe-ds-node-mapper/callToParameterValue.test.ts b/packages/safe-ds-lang/tests/language/helpers/safe-ds-node-mapper/callToParameterValue.test.ts new file mode 100644 index 000000000..736cc8fd5 --- /dev/null +++ b/packages/safe-ds-lang/tests/language/helpers/safe-ds-node-mapper/callToParameterValue.test.ts @@ -0,0 +1,169 @@ +import { EmptyFileSystem } from 'langium'; +import { describe, expect, it } from 'vitest'; +import { + isSdsModule, + SdsAbstractCall, + SdsFunction, + SdsParameter, + SdsPipeline, +} from '../../../../src/language/generated/ast.js'; +import { createSafeDsServices } from '../../../../src/language/index.js'; +import { Constant, IntConstant } from '../../../../src/language/partialEvaluation/model.js'; +import { getNodeOfType } from '../../../helpers/nodeFinder.js'; +import { getModuleMembers, getParameters } from '../../../../src/language/helpers/nodeProperties.js'; + +const services = createSafeDsServices(EmptyFileSystem).SafeDs; +const callGraphComputer = services.flow.CallGraphComputer; +const nodeMapper = services.helpers.NodeMapper; +const partialEvaluator = services.evaluation.PartialEvaluator; + +const code = ` + fun myFunction(p1: String, p2: String = 0) + + pipeline myPipeline { + myFunction(1, 2); + myFunction(); + unresolved(); + } +`; +const module = await getNodeOfType(services, code, isSdsModule); +const myFunction = getModuleMembers(module)[0] as SdsFunction; +const p1 = getParameters(myFunction)[0]!; +const p2 = getParameters(myFunction)[1]!; +const myPipeline = module?.members[1] as SdsPipeline; +const call1 = callGraphComputer.getCalls(myPipeline)[0]!; +const call2 = callGraphComputer.getCalls(myPipeline)[1]!; +const call3 = callGraphComputer.getCalls(myPipeline)[2]!; + +describe('SafeDsNodeMapper', () => { + const testCases: CallToParameterValueTest[] = [ + { + testName: 'undefined call, undefined parameter', + call: undefined, + parameter: undefined, + expectedResult: undefined, + }, + { + testName: 'undefined call, defined parameter', + call: undefined, + parameter: p1, + expectedResult: undefined, + }, + { + testName: 'defined call, undefined parameter', + call: call1, + parameter: undefined, + expectedResult: undefined, + }, + { + testName: 'parameter is object, required parameter, value provided', + call: call1, + parameter: p1, + expectedResult: new IntConstant(1n), + }, + { + testName: 'parameter is object, optional parameter, value provided', + call: call1, + parameter: p2, + expectedResult: new IntConstant(2n), + }, + { + testName: 'parameter is object, required parameter, no value provided', + call: call2, + parameter: p1, + expectedResult: undefined, + }, + { + testName: 'parameter is object, optional parameter, no value provided', + call: call2, + parameter: p2, + expectedResult: new IntConstant(0n), + }, + { + testName: 'parameter is string, required parameter, value provided', + call: call1, + parameter: 'p1', + expectedResult: new IntConstant(1n), + }, + { + testName: 'parameter is string, optional parameter, value provided', + call: call1, + parameter: 'p2', + expectedResult: new IntConstant(2n), + }, + { + testName: 'parameter is string, required parameter, no value provided', + call: call2, + parameter: 'p1', + expectedResult: undefined, + }, + { + testName: 'parameter is string, optional parameter, no value provided', + call: call2, + parameter: 'p2', + expectedResult: new IntConstant(0n), + }, + { + testName: 'parameter is object, required parameter, unresolved callable', + call: call3, + parameter: p1, + expectedResult: undefined, + }, + { + testName: 'parameter is object, optional parameter, unresolved callable', + call: call3, + parameter: p2, + expectedResult: undefined, + }, + { + testName: 'parameter is string, required parameter, unresolved callable', + call: call3, + parameter: 'p1', + expectedResult: undefined, + }, + { + testName: 'parameter is string, optional parameter, unresolved callable', + call: call3, + parameter: 'p2', + expectedResult: undefined, + }, + ]; + + describe.each(testCases)('callToParameterValue', ({ testName, call, parameter, expectedResult }) => { + it(testName, () => { + const parameterValue = nodeMapper.callToParameterValue(call, parameter); + if (expectedResult === undefined) { + expect(parameterValue).toBeUndefined(); + return; + } + + const evaluatedParameterValue = partialEvaluator.evaluate(parameterValue); + expect(evaluatedParameterValue).toStrictEqual(expectedResult); + }); + }); +}); + +/** + * A test case for {@link SafeDsNodeMapper.callToParameterValue}. + */ +interface CallToParameterValueTest { + /** + * A short description of the test case. + */ + testName: string; + + /** + * The abstract call to test. + */ + call: SdsAbstractCall | undefined; + + /** + * The parameter to test. + */ + parameter: SdsParameter | string | undefined; + + /** + * The expected result. + */ + expectedResult: Constant | undefined; +}