diff --git a/immutables/_map.c b/immutables/_map.c index 7acf6c7e..5d961779 100644 --- a/immutables/_map.c +++ b/immutables/_map.c @@ -2810,9 +2810,23 @@ map_new_items_view(MapObject *o) /////////////////////////////////// _MapKeys_Type +static int +map_tp_contains(BaseMapObject *self, PyObject *key); + +static int +_map_keys_tp_contains(MapView *self, PyObject *key) +{ + return map_tp_contains((BaseMapObject *)self->mv_obj, key); +} + +static PySequenceMethods _MapKeys_as_sequence = { + .sq_contains = (objobjproc)_map_keys_tp_contains, +}; + PyTypeObject _MapKeys_Type = { PyVarObject_HEAD_INIT(NULL, 0) "keys", + .tp_as_sequence = &_MapKeys_as_sequence, VIEW_TYPE_SHARED_SLOTS }; diff --git a/immutables/_protocols.py b/immutables/_protocols.py index de87d230..49d8ee1e 100644 --- a/immutables/_protocols.py +++ b/immutables/_protocols.py @@ -32,6 +32,7 @@ class MapKeys(Protocol[KT_co]): def __len__(self) -> int: ... def __iter__(self) -> Iterator[KT_co]: ... + def __contains__(self, __key: object) -> bool: ... class MapValues(Protocol[VT_co]): diff --git a/tests/test_map.py b/tests/test_map.py index 4640029d..4caef50a 100644 --- a/tests/test_map.py +++ b/tests/test_map.py @@ -1395,6 +1395,10 @@ def test_kwarg_named_col(self): self.assertEqual(dict(self.Map(a=0, col=1)), {"a": 0, "col": 1}) self.assertEqual(dict(self.Map({"a": 0}, col=1)), {"a": 0, "col": 1}) + def test_map_keys_contains(self): + m = self.Map(foo="bar") + self.assertTrue("foo" in m.keys()) + class PyMapTest(BaseMapTest, unittest.TestCase):