Skip to content

Commit

Permalink
Add __contains__ to MapKeys
Browse files Browse the repository at this point in the history
I ran into this when I tried to pass `.keys()` to something declared
to take a `Collection`, and it failed because it lacked
`__contains__`. (I didn't actually care about `__contains__`, but mypy
did.)

So I came to add it to the stubs and discovered that it *was* actually
missing, and since the fallback is woefully slow I went and
implemented it.
  • Loading branch information
msullivan authored and elprans committed Jul 21, 2023
1 parent f797822 commit dc68581
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 0 deletions.
14 changes: 14 additions & 0 deletions immutables/_map.c
Original file line number Diff line number Diff line change
Expand Up @@ -2810,9 +2810,23 @@ map_new_items_view(MapObject *o)
/////////////////////////////////// _MapKeys_Type


static int
map_tp_contains(BaseMapObject *self, PyObject *key);

static int
_map_keys_tp_contains(MapView *self, PyObject *key)
{
return map_tp_contains((BaseMapObject *)self->mv_obj, key);
}

static PySequenceMethods _MapKeys_as_sequence = {
.sq_contains = (objobjproc)_map_keys_tp_contains,
};

PyTypeObject _MapKeys_Type = {
PyVarObject_HEAD_INIT(NULL, 0)
"keys",
.tp_as_sequence = &_MapKeys_as_sequence,
VIEW_TYPE_SHARED_SLOTS
};

Expand Down
1 change: 1 addition & 0 deletions immutables/_protocols.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
class MapKeys(Protocol[KT_co]):
def __len__(self) -> int: ...
def __iter__(self) -> Iterator[KT_co]: ...
def __contains__(self, __key: object) -> bool: ...


class MapValues(Protocol[VT_co]):
Expand Down
4 changes: 4 additions & 0 deletions tests/test_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -1395,6 +1395,10 @@ def test_kwarg_named_col(self):
self.assertEqual(dict(self.Map(a=0, col=1)), {"a": 0, "col": 1})
self.assertEqual(dict(self.Map({"a": 0}, col=1)), {"a": 0, "col": 1})

def test_map_keys_contains(self):
m = self.Map(foo="bar")
self.assertTrue("foo" in m.keys())


class PyMapTest(BaseMapTest, unittest.TestCase):

Expand Down

0 comments on commit dc68581

Please sign in to comment.