Skip to content

Commit 25f186b

Browse files
committed
stubtest: basic support for unpack kwargs
Fixes #21023
1 parent 537740b commit 25f186b

File tree

2 files changed

+40
-1
lines changed

2 files changed

+40
-1
lines changed

mypy/stubtest.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -941,7 +941,24 @@ def from_funcitem(stub: nodes.FuncItem) -> Signature[nodes.Argument]:
941941
elif stub_arg.kind == nodes.ARG_STAR:
942942
stub_sig.varpos = stub_arg
943943
elif stub_arg.kind == nodes.ARG_STAR2:
944-
stub_sig.varkw = stub_arg
944+
if stub_arg.variable.type is not None and isinstance(
945+
(typed_dict_arg := mypy.types.get_proper_type(stub_arg.variable.type)),
946+
mypy.types.TypedDictType,
947+
):
948+
for key_name, key_type in typed_dict_arg.items.items():
949+
stub_sig.kwonly[key_name] = nodes.Argument(
950+
nodes.Var(key_name, key_type),
951+
type_annotation=key_type,
952+
initializer=(
953+
nodes.EllipsisExpr()
954+
if key_name not in typed_dict_arg.required_keys
955+
else None
956+
),
957+
kind=nodes.ARG_NAMED,
958+
pos_only=False,
959+
)
960+
else:
961+
stub_sig.varkw = stub_arg
945962
else:
946963
raise AssertionError
947964
return stub_sig

mypy/test/teststubtest.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ def __getitem__(self, typeargs: Any) -> object: ...
6060
Final = 0
6161
Literal = 0
6262
TypedDict = 0
63+
Unpack = 0
6364
6465
class TypeVar:
6566
def __init__(self, name, covariant: bool = ..., contravariant: bool = ...) -> None: ...
@@ -765,6 +766,27 @@ def test_varargs_varkwargs(self) -> Iterator[Case]:
765766
error="k6",
766767
)
767768

769+
@collect_cases
770+
def test_kwargs_unpack_typeddict(self) -> Iterator[Case]:
771+
yield Case(
772+
stub="""
773+
from typing import TypedDict, Unpack
774+
775+
class _Args(TypedDict):
776+
a: int
777+
b: int
778+
779+
def f1(**kwargs: Unpack[_Args]) -> None: ...
780+
""",
781+
runtime="def f1(*, a, b): pass",
782+
error=None,
783+
)
784+
yield Case(
785+
stub="def f2(**kwargs: Unpack[_Args]) -> None: ...",
786+
runtime="def f2(*, a, c): pass",
787+
error="f2",
788+
)
789+
768790
@collect_cases
769791
def test_overload(self) -> Iterator[Case]:
770792
yield Case(

0 commit comments

Comments
 (0)