diff --git a/packages/pyright-internal/src/analyzer/binder.ts b/packages/pyright-internal/src/analyzer/binder.ts index 5784af2b194a..06f948675455 100644 --- a/packages/pyright-internal/src/analyzer/binder.ts +++ b/packages/pyright-internal/src/analyzer/binder.ts @@ -1456,14 +1456,11 @@ export class Binder extends ParseTreeWalker { this._targetFunctionDeclaration.raiseStatements.push(node); } - if (node.d.typeExpression) { - this.walk(node.d.typeExpression); - } - if (node.d.valueExpression) { - this.walk(node.d.valueExpression); + if (node.d.expr) { + this.walk(node.d.expr); } - if (node.d.tracebackExpression) { - this.walk(node.d.tracebackExpression); + if (node.d.fromExpr) { + this.walk(node.d.fromExpr); } this._finallyTargets.forEach((target) => { diff --git a/packages/pyright-internal/src/analyzer/checker.ts b/packages/pyright-internal/src/analyzer/checker.ts index 998f9eee570f..dc2201e9c231 100644 --- a/packages/pyright-internal/src/analyzer/checker.ts +++ b/packages/pyright-internal/src/analyzer/checker.ts @@ -1093,46 +1093,12 @@ export class Checker extends ParseTreeWalker { } override visitRaise(node: RaiseNode): boolean { - this._evaluator.verifyRaiseExceptionType(node); - - if (node.d.valueExpression) { - const baseExceptionType = this._evaluator.getBuiltInType(node, 'BaseException') as ClassType; - const exceptionType = this._evaluator.getType(node.d.valueExpression); - - // Validate that the argument of "raise" is an exception object or None. - if (exceptionType && baseExceptionType && isInstantiableClass(baseExceptionType)) { - const diagAddendum = new DiagnosticAddendum(); - - doForEachSubtype(exceptionType, (subtype) => { - subtype = this._evaluator.makeTopLevelTypeVarsConcrete(subtype); - - if (!isAnyOrUnknown(subtype) && !isNoneInstance(subtype)) { - if (isClass(subtype)) { - if (!derivesFromClassRecursive(subtype, baseExceptionType, /* ignoreUnknown */ false)) { - diagAddendum.addMessage( - LocMessage.exceptionTypeIncorrect().format({ - type: this._evaluator.printType(subtype), - }) - ); - } - } else { - diagAddendum.addMessage( - LocMessage.exceptionTypeIncorrect().format({ - type: this._evaluator.printType(subtype), - }) - ); - } - } - }); + if (node.d.expr) { + this._evaluator.verifyRaiseExceptionType(node.d.expr); + } - if (!diagAddendum.isEmpty()) { - this._evaluator.addDiagnostic( - DiagnosticRule.reportGeneralTypeIssues, - LocMessage.expectedExceptionObj() + diagAddendum.getString(), - node.d.valueExpression - ); - } - } + if (node.d.fromExpr) { + this._evaluator.verifyRaiseExceptionType(node.d.fromExpr); } return true; diff --git a/packages/pyright-internal/src/analyzer/codeFlowEngine.ts b/packages/pyright-internal/src/analyzer/codeFlowEngine.ts index 1af635076fca..d34573578b80 100644 --- a/packages/pyright-internal/src/analyzer/codeFlowEngine.ts +++ b/packages/pyright-internal/src/analyzer/codeFlowEngine.ts @@ -1862,9 +1862,9 @@ export function getCodeFlowEngine( continue; } - if (simpleStatement.nodeType === ParseNodeType.Raise && simpleStatement.d.typeExpression) { + if (simpleStatement.nodeType === ParseNodeType.Raise && simpleStatement.d.expr) { // Check for a raising about 'NotImplementedError' or a subtype thereof. - const exceptionType = evaluator.getType(simpleStatement.d.typeExpression); + const exceptionType = evaluator.getType(simpleStatement.d.expr); if ( exceptionType && diff --git a/packages/pyright-internal/src/analyzer/parseTreeWalker.ts b/packages/pyright-internal/src/analyzer/parseTreeWalker.ts index 73cc7593b1f9..f5a83ed0888e 100644 --- a/packages/pyright-internal/src/analyzer/parseTreeWalker.ts +++ b/packages/pyright-internal/src/analyzer/parseTreeWalker.ts @@ -272,7 +272,7 @@ export function getChildNodes(node: ParseNode): (ParseNode | undefined)[] { return [node.d.expr]; case ParseNodeType.Raise: - return [node.d.typeExpression, node.d.valueExpression, node.d.tracebackExpression]; + return [node.d.expr, node.d.fromExpr]; case ParseNodeType.Return: return [node.d.expr]; diff --git a/packages/pyright-internal/src/analyzer/typeEvaluator.ts b/packages/pyright-internal/src/analyzer/typeEvaluator.ts index 871ac4de1987..965dda1bc589 100644 --- a/packages/pyright-internal/src/analyzer/typeEvaluator.ts +++ b/packages/pyright-internal/src/analyzer/typeEvaluator.ts @@ -68,7 +68,6 @@ import { ParameterNode, ParseNode, ParseNodeType, - RaiseNode, SetNode, SliceNode, StringListNode, @@ -4343,87 +4342,80 @@ export function createTypeEvaluator( } } - function verifyRaiseExceptionType(node: RaiseNode) { + function verifyRaiseExceptionType(node: ExpressionNode) { const baseExceptionType = getBuiltInType(node, 'BaseException'); + const exceptionType = getTypeOfExpression(node).type; - if (node.d.typeExpression) { - const exceptionType = getTypeOfExpression(node.d.typeExpression).type; + // Validate that the argument of "raise" is an exception object or class. + // If it is a class, validate that the class's constructor accepts zero + // arguments. + if (exceptionType && baseExceptionType && isInstantiableClass(baseExceptionType)) { + const diag = new DiagnosticAddendum(); - // Validate that the argument of "raise" is an exception object or class. - // If it is a class, validate that the class's constructor accepts zero - // arguments. - if (exceptionType && baseExceptionType && isInstantiableClass(baseExceptionType)) { - const diagAddendum = new DiagnosticAddendum(); + doForEachSubtype(exceptionType, (subtype) => { + const concreteSubtype = makeTopLevelTypeVarsConcrete(subtype); - doForEachSubtype(exceptionType, (subtype) => { - const concreteSubtype = makeTopLevelTypeVarsConcrete(subtype); + if (isAnyOrUnknown(concreteSubtype) || isNever(concreteSubtype) || isNoneInstance(concreteSubtype)) { + return; + } - if (!isAnyOrUnknown(concreteSubtype)) { - if (isInstantiableClass(concreteSubtype) && concreteSubtype.priv.literalValue === undefined) { - if ( - !derivesFromClassRecursive( - concreteSubtype, - baseExceptionType, - /* ignoreUnknown */ false - ) - ) { - diagAddendum.addMessage( - LocMessage.exceptionTypeIncorrect().format({ - type: printType(subtype), - }) - ); - } else { - let callResult: CallResult | undefined; - suppressDiagnostics(node.d.typeExpression!, () => { - callResult = validateConstructorArgs( - evaluatorInterface, - node.d.typeExpression!, - [], - concreteSubtype, - /* skipUnknownArgCheck */ false, - /* inferenceContext */ undefined - ); - }); + if (isInstantiableClass(concreteSubtype) && concreteSubtype.priv.literalValue === undefined) { + if (!derivesFromClassRecursive(concreteSubtype, baseExceptionType, /* ignoreUnknown */ false)) { + diag.addMessage( + LocMessage.exceptionTypeIncorrect().format({ + type: printType(subtype), + }) + ); + } else { + let callResult: CallResult | undefined; + suppressDiagnostics(node, () => { + callResult = validateConstructorArgs( + evaluatorInterface, + node, + [], + concreteSubtype, + /* skipUnknownArgCheck */ false, + /* inferenceContext */ undefined + ); + }); - if (callResult && callResult.argumentErrors) { - diagAddendum.addMessage( - LocMessage.exceptionTypeNotInstantiable().format({ - type: printType(subtype), - }) - ); - } - } - } else if (isClassInstance(concreteSubtype)) { - if ( - !derivesFromClassRecursive( - ClassType.cloneAsInstantiable(concreteSubtype), - baseExceptionType, - /* ignoreUnknown */ false - ) - ) { - diagAddendum.addMessage( - LocMessage.exceptionTypeIncorrect().format({ - type: printType(subtype), - }) - ); - } - } else { - diagAddendum.addMessage( - LocMessage.exceptionTypeIncorrect().format({ + if (callResult && callResult.argumentErrors) { + diag.addMessage( + LocMessage.exceptionTypeNotInstantiable().format({ type: printType(subtype), }) ); } } - }); - - if (!diagAddendum.isEmpty()) { - addDiagnostic( - DiagnosticRule.reportGeneralTypeIssues, - LocMessage.expectedExceptionClass() + diagAddendum.getString(), - node.d.typeExpression + } else if (isClassInstance(concreteSubtype)) { + if ( + !derivesFromClassRecursive( + ClassType.cloneAsInstantiable(concreteSubtype), + baseExceptionType, + /* ignoreUnknown */ false + ) + ) { + diag.addMessage( + LocMessage.exceptionTypeIncorrect().format({ + type: printType(subtype), + }) + ); + } + } else { + diag.addMessage( + LocMessage.exceptionTypeIncorrect().format({ + type: printType(subtype), + }) ); } + }); + + if (!diag.isEmpty()) { + addDiagnostic( + DiagnosticRule.reportGeneralTypeIssues, + LocMessage.expectedExceptionClass() + diag.getString(), + node + ); } } } @@ -19242,10 +19234,10 @@ export function createTypeEvaluator( } for (const raiseStatement of functionDecl.raiseStatements) { - if (!raiseStatement.d.typeExpression || raiseStatement.d.valueExpression) { + if (!raiseStatement.d.expr || raiseStatement.d.fromExpr) { return false; } - const raiseType = getTypeOfExpression(raiseStatement.d.typeExpression).type; + const raiseType = getTypeOfExpression(raiseStatement.d.expr).type; const classType = isInstantiableClass(raiseType) ? raiseType : isClassInstance(raiseType) diff --git a/packages/pyright-internal/src/analyzer/typeEvaluatorTypes.ts b/packages/pyright-internal/src/analyzer/typeEvaluatorTypes.ts index f2479c5b6e13..2d16b2d9eadf 100644 --- a/packages/pyright-internal/src/analyzer/typeEvaluatorTypes.ts +++ b/packages/pyright-internal/src/analyzer/typeEvaluatorTypes.ts @@ -27,7 +27,6 @@ import { ParamCategory, ParameterNode, ParseNode, - RaiseNode, StringNode, } from '../parser/parseNodes'; import { AnalyzerFileInfo } from './analyzerFileInfo'; @@ -551,7 +550,7 @@ export interface TypeEvaluator { ) => Type; getExpectedType: (node: ExpressionNode) => ExpectedTypeResult | undefined; - verifyRaiseExceptionType: (node: RaiseNode) => void; + verifyRaiseExceptionType: (node: ExpressionNode) => void; verifyDeleteExpression: (node: ExpressionNode) => void; validateOverloadedArgTypes: ( errorNode: ExpressionNode, diff --git a/packages/pyright-internal/src/parser/parseNodes.ts b/packages/pyright-internal/src/parser/parseNodes.ts index 894db6faf4a7..6c5d429e07e1 100644 --- a/packages/pyright-internal/src/parser/parseNodes.ts +++ b/packages/pyright-internal/src/parser/parseNodes.ts @@ -2314,9 +2314,8 @@ export namespace ReturnNode { export interface RaiseNode extends ParseNodeBase { d: { - typeExpression?: ExpressionNode | undefined; - valueExpression?: ExpressionNode | undefined; - tracebackExpression?: ExpressionNode | undefined; + expr?: ExpressionNode | undefined; + fromExpr?: ExpressionNode | undefined; }; } diff --git a/packages/pyright-internal/src/parser/parser.ts b/packages/pyright-internal/src/parser/parser.ts index 4a5c8a636417..63f81993785d 100644 --- a/packages/pyright-internal/src/parser/parser.ts +++ b/packages/pyright-internal/src/parser/parser.ts @@ -2821,29 +2821,14 @@ export class Parser { const raiseNode = RaiseNode.create(raiseToken); if (!this._isNextTokenNeverExpression()) { - raiseNode.d.typeExpression = this._parseTestExpression(/* allowAssignmentExpression */ true); - raiseNode.d.typeExpression.parent = raiseNode; - extendRange(raiseNode, raiseNode.d.typeExpression); + raiseNode.d.expr = this._parseTestExpression(/* allowAssignmentExpression */ true); + raiseNode.d.expr.parent = raiseNode; + extendRange(raiseNode, raiseNode.d.expr); if (this._consumeTokenIfKeyword(KeywordType.From)) { - raiseNode.d.valueExpression = this._parseTestExpression(/* allowAssignmentExpression */ true); - raiseNode.d.valueExpression.parent = raiseNode; - extendRange(raiseNode, raiseNode.d.valueExpression); - } else { - if (this._consumeTokenIfType(TokenType.Comma)) { - // Handle the Python 2.x variant - raiseNode.d.valueExpression = this._parseTestExpression(/* allowAssignmentExpression */ true); - raiseNode.d.valueExpression.parent = raiseNode; - extendRange(raiseNode, raiseNode.d.valueExpression); - - if (this._consumeTokenIfType(TokenType.Comma)) { - raiseNode.d.tracebackExpression = this._parseTestExpression( - /* allowAssignmentExpression */ true - ); - raiseNode.d.tracebackExpression.parent = raiseNode; - extendRange(raiseNode, raiseNode.d.tracebackExpression); - } - } + raiseNode.d.fromExpr = this._parseTestExpression(/* allowAssignmentExpression */ true); + raiseNode.d.fromExpr.parent = raiseNode; + extendRange(raiseNode, raiseNode.d.fromExpr); } } diff --git a/packages/pyright-internal/src/tests/checker.test.ts b/packages/pyright-internal/src/tests/checker.test.ts index ef526d1365d6..d7c328fd47d3 100644 --- a/packages/pyright-internal/src/tests/checker.test.ts +++ b/packages/pyright-internal/src/tests/checker.test.ts @@ -396,7 +396,7 @@ test('ParamType1', () => { test('Python2', () => { const analysisResults = TestUtils.typeAnalyzeSampleFiles(['python2.py']); - TestUtils.validateResults(analysisResults, 6); + TestUtils.validateResults(analysisResults, 7); }); test('InconsistentSpaceTab1', () => { diff --git a/packages/pyright-internal/src/tests/samples/python2.py b/packages/pyright-internal/src/tests/samples/python2.py index ab331c065486..6703225bd405 100644 --- a/packages/pyright-internal/src/tests/samples/python2.py +++ b/packages/pyright-internal/src/tests/samples/python2.py @@ -26,6 +26,6 @@ def foo(a, (b, c), d): pass -# This should generate an error. +# This should generate two errors. raise NameError, a > 4, a < 4