Skip to content

Commit

Permalink
Changed the behavior of the dictionary expansion operator (**) when…
Browse files Browse the repository at this point in the history
… used in an argument expression for a call and the operand is a TypedDict. It now takes into account the fact that a (non-closed) TypedDict can contain additional keys with `object` values. This new behavior is consistent with mypy. This addresses #8894. (#8908)
  • Loading branch information
erictraut authored Sep 5, 2024
1 parent 97f02df commit d841fb9
Show file tree
Hide file tree
Showing 5 changed files with 90 additions and 11 deletions.
27 changes: 23 additions & 4 deletions packages/pyright-internal/src/analyzer/typeEvaluator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10988,10 +10988,10 @@ export function createTypeEvaluator(
} else if (isClassInstance(argType) && ClassType.isTypedDictClass(argType)) {
// Handle the special case where it is a TypedDict and we know which
// keys are present.
const typedDictEntries = getTypedDictMembersForClass(evaluatorInterface, argType);
const tdEntries = getTypedDictMembersForClass(evaluatorInterface, argType);
const diag = new DiagnosticAddendum();

typedDictEntries.knownItems.forEach((entry, name) => {
tdEntries.knownItems.forEach((entry, name) => {
const paramEntry = paramMap.get(name);
if (paramEntry && !paramEntry.isPositionalOnly) {
if (paramEntry.argsReceived > 0) {
Expand All @@ -11013,7 +11013,7 @@ export function createTypeEvaluator(
argCategory: ArgCategory.Simple,
typeResult: { type: entry.valueType },
},
errorNode: argList[argIndex].valueExpression || errorNode,
errorNode: argList[argIndex].valueExpression ?? errorNode,
paramName: name,
});
}
Expand All @@ -11027,7 +11027,7 @@ export function createTypeEvaluator(
argCategory: ArgCategory.Simple,
typeResult: { type: entry.valueType },
},
errorNode: argList[argIndex].valueExpression || errorNode,
errorNode: argList[argIndex].valueExpression ?? errorNode,
paramName: name,
});

Expand All @@ -11048,6 +11048,25 @@ export function createTypeEvaluator(
}
});

const extraItemsType = tdEntries.extraItems?.valueType ?? getObjectType();
if (!isNever(extraItemsType)) {
if (paramDetails.kwargsIndex !== undefined) {
const kwargsParam = paramDetails.params[paramDetails.kwargsIndex];

validateArgTypeParams.push({
paramCategory: ParamCategory.KwargsDict,
paramType: kwargsParam.declaredType,
requiresTypeVarMatching: requiresSpecialization(kwargsParam.declaredType),
argument: {
argCategory: ArgCategory.UnpackedDictionary,
typeResult: { type: extraItemsType },
},
errorNode: argList[argIndex].valueExpression ?? errorNode,
paramName: kwargsParam.param.name,
});
}
}

if (!diag.isEmpty()) {
if (!canSkipDiagnosticForNode(errorNode) && !isTypeIncomplete) {
addDiagnostic(
Expand Down
26 changes: 20 additions & 6 deletions packages/pyright-internal/src/tests/samples/call7.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,10 @@ def func4(arg1: int, arg2: str, **kwargs: int):
pass


def func5(arg1: int, arg2: str, **kwargs: object):
pass


td1: TD1 = {"arg1": 10, "arg2": "something"}
td2: TD2 = {"arg1": 10, "arg2": "something", "arg3": 3.4}

Expand All @@ -46,14 +50,24 @@ def func4(arg1: int, arg2: str, **kwargs: int):
func2(**td2)


# This should generate an error because the extra entries
# in the TD are of type object.
func3(**td1)

# This should generate an error because the extra entries
# in the TD are of type object.
func3(**td2)

# This should generate an error because the extra entries
# in the TD are of type object.
func4(**td1)

# This should generate an error because "arg3" cannot be matched
# due to the type of the **kwargs parameter.
func5(**td1)
func5(**td2)

# This should generate two errors because "arg3" cannot be matched
# due to the type of the **kwargs parameter. Also, the extra entries
# in the TD are of type object.
func4(**td2)


Expand All @@ -62,10 +76,10 @@ class Options(TypedDict, total=False):
opt2: str


def func5(code: str | None = None, **options: Unpack[Options]):
def func6(code: str | None = None, **options: Unpack[Options]):
pass


func5(**{})
func5(**{"opt1": True})
func5(**{"opt2": "hi"})
func6(**{})
func6(**{"opt1": True})
func6(**{"opt2": "hi"})
38 changes: 38 additions & 0 deletions packages/pyright-internal/src/tests/samples/typedDictClosed9.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
# This sample tests the handling of calls with unpacked TypedDicts.


from typing_extensions import TypedDict # pyright: ignore[reportMissingModuleSource]


class ClosedTD1(TypedDict, closed=True):
arg1: str


class IntDict1(TypedDict, closed=True):
arg1: str
__extra_items__: int


td1 = ClosedTD1(arg1="hello")
td2 = IntDict1(arg1="hello", arg2=3)


def func1(arg1: str):
pass


func1(**td1)

# This should arguably generate an error because there
# could be extra items, but we'll match mypy's behavior here.
func1(**td2)


def func2(arg1: str, **kwargs: str):
pass


func2(**td1)

# This should result in an error.
func2(**td2)
2 changes: 1 addition & 1 deletion packages/pyright-internal/src/tests/typeEvaluator1.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -815,7 +815,7 @@ test('Call6', () => {
test('Call7', () => {
const analysisResults = TestUtils.typeAnalyzeSampleFiles(['call7.py']);

TestUtils.validateResults(analysisResults, 4);
TestUtils.validateResults(analysisResults, 8);
});

test('Call8', () => {
Expand Down
8 changes: 8 additions & 0 deletions packages/pyright-internal/src/tests/typeEvaluator5.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -400,6 +400,14 @@ test('TypedDictClosed8', () => {
TestUtils.validateResults(analysisResults, 0);
});

test('TypedDictClosed9', () => {
const configOptions = new ConfigOptions(Uri.empty());
configOptions.diagnosticRuleSet.enableExperimentalFeatures = true;

const analysisResults = TestUtils.typeAnalyzeSampleFiles(['typedDictClosed9.py'], configOptions);
TestUtils.validateResults(analysisResults, 1);
});

test('DataclassTransform1', () => {
const analysisResults = TestUtils.typeAnalyzeSampleFiles(['dataclassTransform1.py']);

Expand Down

0 comments on commit d841fb9

Please sign in to comment.