diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index e7c2cba3fc55..fc60cf1befb1 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -2013,7 +2013,12 @@ def infer_function_type_arguments_using_context( # variables in an expression are inferred at the same time. # (And this is hard, also we need to be careful with lambdas that require # two passes.) - if isinstance(ret_type, TypeVarType): + proper_ret = get_proper_type(ret_type) + if ( + isinstance(proper_ret, TypeVarType) + or isinstance(proper_ret, UnionType) + and all(isinstance(get_proper_type(u), TypeVarType) for u in proper_ret.items) + ): # Another special case: the return type is a type variable. If it's unrestricted, # we could infer a too general type for the type variable if we use context, # and this could result in confusing and spurious type errors elsewhere. diff --git a/test-data/unit/check-inference-context.test b/test-data/unit/check-inference-context.test index 17ae6d9934b7..0c068f551e11 100644 --- a/test-data/unit/check-inference-context.test +++ b/test-data/unit/check-inference-context.test @@ -1495,3 +1495,30 @@ def g(b: Optional[str]) -> None: z: Callable[[], str] = lambda: reveal_type(b) # N: Revealed type is "builtins.str" f2(lambda: reveal_type(b)) # N: Revealed type is "builtins.str" lambda: reveal_type(b) # N: Revealed type is "builtins.str" + +[case testInferenceContextReturningTypeVarUnion] +from collections.abc import Callable, Iterable +from typing import TypeVar, Union + +_T1 = TypeVar("_T1") +_T2 = TypeVar("_T2") + +def mymin( + iterable: Iterable[_T1], /, *, key: Callable[[_T1], int], default: _T2 +) -> Union[_T1, _T2]: ... + +def check(paths: Iterable[str], key: Callable[[str], int]) -> Union[str, None]: + return mymin(paths, key=key, default=None) +[builtins fixtures/tuple.pyi] + +[case testInferenceContextMappingGet-xfail] +from collections.abc import Mapping +from typing import TypeVar, Union + +_T1 = TypeVar("_T1") + +def check(mapping: Mapping[str, _T1]) -> None: + fail1 = mapping.get("", "") + fail2: Union[_T1, str] = mapping.get("", "") +[builtins fixtures/tuple.pyi] +[typing fixtures/typing-full.pyi]