diff --git a/packages/pyright-internal/src/analyzer/codeFlowEngine.ts b/packages/pyright-internal/src/analyzer/codeFlowEngine.ts index f9440c4ff815..b183f9f94258 100644 --- a/packages/pyright-internal/src/analyzer/codeFlowEngine.ts +++ b/packages/pyright-internal/src/analyzer/codeFlowEngine.ts @@ -48,6 +48,7 @@ import { ClassType, combineTypes, FunctionType, + isAnyOrUnknown, isClass, isClassInstance, isFunction, @@ -59,6 +60,7 @@ import { isTypeSame, isTypeVar, isTypeVarTuple, + isUnion, maxTypeRecursionCount, NeverType, OverloadedType, @@ -1772,17 +1774,24 @@ export function getCodeFlowEngine( const callTypeResult = evaluator.getTypeOfExpression(node.d.leftExpr, EvalFlags.CallBaseDefaults); const callType = callTypeResult.type; - doForEachSubtype(callType, (callSubtype) => { + const callSubtypes = isUnion(callType) ? callType.priv.subtypes : [callType]; + for (let callSubtype of callSubtypes) { // Track the number of subtypes we've examined. subtypeCount++; + // Any or Unknown can never establish a guaranteed NoReturn result, so any + // union containing either is immediately a negative answer. + if (isAnyOrUnknown(callSubtype)) { + return false; + } + if (isInstantiableClass(callSubtype)) { // Does the class have a custom metaclass that implements a `__call__` method? // If so, it will be called instead of `__init__` or `__new__`. We'll assume // in this case that the __call__ method is not a NoReturn type. const metaclassCallResult = getBoundCallMethod(evaluator, node, callSubtype); if (metaclassCallResult) { - return; + continue; } const newMethodResult = getBoundNewMethod(evaluator, node, callSubtype); @@ -1839,7 +1848,7 @@ export function getCodeFlowEngine( } } } - }); + } // The call is considered NoReturn if all subtypes evaluate to NoReturn. const callIsNoReturn = subtypeCount > 0 && noReturnTypeCount === subtypeCount; diff --git a/packages/pyright-internal/src/analyzer/typeUtils.ts b/packages/pyright-internal/src/analyzer/typeUtils.ts index 2b93dbb9618c..46e485a4c071 100644 --- a/packages/pyright-internal/src/analyzer/typeUtils.ts +++ b/packages/pyright-internal/src/analyzer/typeUtils.ts @@ -335,12 +335,7 @@ export function isTypeVarSame(type1: TypeVarType, type2: Type) { return false; } - let isCompatible = true; - doForEachSubtype(type2, (subtype) => { - if (!isCompatible) { - return; - } - + return allSubtypes(type2, (subtype) => { if (!isTypeSame(type1, subtype)) { const conditions = getTypeCondition(subtype); @@ -348,12 +343,12 @@ export function isTypeVarSame(type1: TypeVarType, type2: Type) { !conditions || !conditions.some((condition) => condition.typeVar.priv.nameWithScope === type1.priv.nameWithScope) ) { - isCompatible = false; + return false; } } - }); - return isCompatible; + return true; + }); } export function makeInferenceContext( @@ -796,11 +791,11 @@ export function someSubtypes(type: Type, callback: (type: Type) => boolean): boo export function allSubtypes(type: Type, callback: (type: Type) => boolean): boolean { if (isUnion(type)) { return type.priv.subtypes.every((subtype) => { - callback(subtype); + return callback(subtype); }); - } else { - return callback(type); } + + return callback(type); } export function doForEachSignature( @@ -869,23 +864,17 @@ export function isUnionableType(subtypes: Type[]): boolean { } export function derivesFromAnyOrUnknown(type: Type): boolean { - let anyOrUnknown = false; + return someSubtypes(type, (subtype) => { + if (isAnyOrUnknown(subtype)) { + return true; + } - doForEachSubtype(type, (subtype) => { - if (isAnyOrUnknown(type)) { - anyOrUnknown = true; - } else if (isInstantiableClass(subtype)) { - if (ClassType.derivesFromAnyOrUnknown(subtype)) { - anyOrUnknown = true; - } - } else if (isClassInstance(subtype)) { - if (ClassType.derivesFromAnyOrUnknown(subtype)) { - anyOrUnknown = true; - } + if (isInstantiableClass(subtype) || isClassInstance(subtype)) { + return ClassType.derivesFromAnyOrUnknown(subtype); } - }); - return anyOrUnknown; + return false; + }); } export function getFullNameOfType(type: Type): string | undefined { @@ -1301,18 +1290,19 @@ export function getLiteralTypeClassName(type: Type): string | undefined { if (isUnion(type)) { let className: string | undefined; - let foundMismatch = false; - doForEachSubtype(type, (subtype) => { + for (const subtype of type.priv.subtypes) { const subtypeLiteralTypeName = getLiteralTypeClassName(subtype); if (!subtypeLiteralTypeName) { - foundMismatch = true; + return undefined; } else if (!className) { className = subtypeLiteralTypeName; + } else if (className !== subtypeLiteralTypeName) { + return undefined; } - }); + } - return foundMismatch ? undefined : className; + return className; } return undefined; @@ -2744,9 +2734,8 @@ export function combineSameSizedTuples(type: Type, tupleType: Type | undefined): } let tupleEntries: Type[][] | undefined; - let isValid = true; - - doForEachSubtype(type, (subtype) => { + const subtypes = isUnion(type) ? type.priv.subtypes : [type]; + for (const subtype of subtypes) { if (isClassInstance(subtype)) { let tupleClass: ClassType | undefined; if (isClass(subtype) && isTupleClass(subtype) && !isUnboundedTupleClass(subtype)) { @@ -2768,20 +2757,20 @@ export function combineSameSizedTuples(type: Type, tupleType: Type | undefined): tupleEntries![index].push(entry.type); }); } else { - isValid = false; + return type; } } else { tupleEntries = tupleClass.priv.tupleTypeArgs.map((entry) => [entry.type]); } } else { - isValid = false; + return type; } } else { - isValid = false; + return type; } - }); + } - if (!isValid || !tupleEntries) { + if (!tupleEntries) { return type; } diff --git a/packages/pyright-internal/src/tests/samples/unreachable1.py b/packages/pyright-internal/src/tests/samples/unreachable1.py index 43d94896ff1c..3f43e364336a 100644 --- a/packages/pyright-internal/src/tests/samples/unreachable1.py +++ b/packages/pyright-internal/src/tests/samples/unreachable1.py @@ -3,7 +3,7 @@ import os import sys from abc import abstractmethod -from typing import NoReturn +from typing import Any, Callable, NoReturn def func1(): @@ -92,6 +92,22 @@ def func9(): return 3 +def func9_1(noreturn_func: Callable[[], NoReturn], unknown_func, flag: bool): + callback = noreturn_func if flag else unknown_func + callback() + + # This should not be marked unreachable because the call target includes Unknown. + return 3 + + +def func9_2(noreturn_func: Callable[[], NoReturn], any_func: Any, flag: bool): + callback = noreturn_func if flag else any_func + callback() + + # This should not be marked unreachable because the call target includes Any. + return 3 + + def func10(): e = OSError() a1 = os.name == "nt" and None == e.errno diff --git a/packages/pyright-internal/src/tests/typeUtils.test.ts b/packages/pyright-internal/src/tests/typeUtils.test.ts new file mode 100644 index 000000000000..ce459fc9b370 --- /dev/null +++ b/packages/pyright-internal/src/tests/typeUtils.test.ts @@ -0,0 +1,172 @@ +/* + * typeUtils.test.ts + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + * + * Unit tests for typeUtils module. + */ + +import * as assert from 'assert'; + +import { + allSubtypes, + combineSameSizedTuples, + derivesFromAnyOrUnknown, + getLiteralTypeClassName, + someSubtypes, +} from '../analyzer/typeUtils'; +import { + AnyType, + ClassType, + ClassTypeFlags, + isClassInstance, + Type, + UnionableType, + UnionType, + UnknownType, +} from '../analyzer/types'; +import { Uri } from '../common/uri/uri'; + +test('AllSubtypes', () => { + const unionType = createUnion(createClassType('A'), createClassType('B'), createClassType('C')); + const visitedSubtypes: Type[] = []; + + const result = allSubtypes(unionType, (subtype) => { + visitedSubtypes.push(subtype); + return visitedSubtypes.length < 2; + }); + + assert.strictEqual(result, false); + assert.strictEqual(visitedSubtypes.length, 2); + + assert.strictEqual( + allSubtypes(unionType, () => { + return true; + }), + true + ); + + const singleType = createClassType('D'); + assert.strictEqual( + allSubtypes(singleType, (subtype) => { + assert.strictEqual(subtype, singleType); + return true; + }), + true + ); +}); + +test('SomeSubtypes', () => { + const unionType = createUnion(createClassType('A'), createClassType('B'), createClassType('C')); + const visitedSubtypes: Type[] = []; + + const result = someSubtypes(unionType, (subtype) => { + visitedSubtypes.push(subtype); + return visitedSubtypes.length === 2; + }); + + assert.strictEqual(result, true); + assert.strictEqual(visitedSubtypes.length, 2); + + assert.strictEqual( + someSubtypes(unionType, () => { + return false; + }), + false + ); + + const singleType = createClassType('D'); + assert.strictEqual( + someSubtypes(singleType, (subtype) => { + assert.strictEqual(subtype, singleType); + return false; + }), + false + ); +}); + +test('DerivesFromAnyOrUnknownUnion', () => { + const classType = createClassType('A'); + + assert.strictEqual(derivesFromAnyOrUnknown(createUnion(classType, UnknownType.create())), true); + assert.strictEqual(derivesFromAnyOrUnknown(createUnion(classType, AnyType.create())), true); + assert.strictEqual(derivesFromAnyOrUnknown(createUnion(classType, createClassType('B'))), false); +}); + +test('GetLiteralTypeClassName', () => { + const intLiteral1 = createLiteralInstance('int', 1); + const intLiteral2 = createLiteralInstance('int', 2); + const strLiteral = createLiteralInstance('str', ''); + const nonLiteralInt = ClassType.cloneAsInstance(createClassType('int', ClassTypeFlags.BuiltIn)); + + assert.strictEqual(getLiteralTypeClassName(intLiteral1), 'int'); + assert.strictEqual(getLiteralTypeClassName(createUnion(intLiteral1, intLiteral2)), 'int'); + assert.strictEqual(getLiteralTypeClassName(createUnion(intLiteral1, strLiteral)), undefined); + assert.strictEqual(getLiteralTypeClassName(createUnion(intLiteral1, nonLiteralInt)), undefined); +}); + +test('CombineSameSizedTuples', () => { + const tupleClass = createClassType('tuple', ClassTypeFlags.BuiltIn); + const intType = ClassType.cloneAsInstance(createClassType('int', ClassTypeFlags.BuiltIn)); + const strType = ClassType.cloneAsInstance(createClassType('str', ClassTypeFlags.BuiltIn)); + const boolType = ClassType.cloneAsInstance(createClassType('bool', ClassTypeFlags.BuiltIn)); + + const tuple1 = createTupleInstance(tupleClass, [intType, strType]); + const tuple2 = createTupleInstance(tupleClass, [strType, boolType]); + const tupleUnion = createUnion(tuple1, tuple2); + + const combinedTuple = combineSameSizedTuples(tupleUnion, tupleClass); + assert.notStrictEqual(combinedTuple, tupleUnion); + assert.strictEqual(isClassInstance(combinedTuple), true); + assert.strictEqual((combinedTuple as ClassType).priv.tupleTypeArgs?.length, 2); + + const mismatchedTuple = createTupleInstance(tupleClass, [intType]); + const mismatchedTupleUnion = createUnion(tuple1, mismatchedTuple); + assert.strictEqual(combineSameSizedTuples(mismatchedTupleUnion, tupleClass), mismatchedTupleUnion); + + const nonTupleUnion = createUnion(tuple1, intType); + assert.strictEqual(combineSameSizedTuples(nonTupleUnion, tupleClass), nonTupleUnion); +}); + +function createLiteralInstance(name: string, literalValue: string | number | boolean) { + return ClassType.cloneAsInstance( + ClassType.cloneWithLiteral(createClassType(name, ClassTypeFlags.BuiltIn), literalValue) + ); +} + +function createTupleInstance(tupleClass: ClassType, entries: UnionableType[]) { + return ClassType.cloneAsInstance( + ClassType.specialize( + tupleClass, + [createUnion(...entries)], + /* isTypeArgExplicit */ true, + /* includeSubclasses */ false, + entries.map((type) => { + return { type, isUnbounded: false }; + }) + ) + ); +} + +function createClassType(name: string, flags = ClassTypeFlags.None) { + const classType = ClassType.createInstantiable( + name, + name, + '', + Uri.empty(), + flags, + 0, + /* declaredMetaclass*/ undefined, + /* effectiveMetaclass */ undefined + ); + classType.shared.mro.push(classType); + return classType; +} + +function createUnion(...subtypes: UnionableType[]) { + const unionType = UnionType.create(); + subtypes.forEach((subtype) => { + UnionType.addType(unionType, subtype); + }); + return unionType; +}