diff --git a/Doc/c-api/dict.rst b/Doc/c-api/dict.rst index 7bebea0c97de5af..e5f28b59a701e00 100644 --- a/Doc/c-api/dict.rst +++ b/Doc/c-api/dict.rst @@ -246,17 +246,32 @@ Dictionary Objects of error (e.g. no more watcher IDs available), return ``-1`` and set an exception. + .. versionadded:: 3.12 + .. c:function:: int PyDict_ClearWatcher(int watcher_id) Clear watcher identified by *watcher_id* previously returned from :c:func:`PyDict_AddWatcher`. Return ``0`` on success, ``-1`` on error (e.g. if the given *watcher_id* was never registered.) + .. versionadded:: 3.12 + .. c:function:: int PyDict_Watch(int watcher_id, PyObject *dict) Mark dictionary *dict* as watched. The callback granted *watcher_id* by :c:func:`PyDict_AddWatcher` will be called when *dict* is modified or - deallocated. + deallocated. Return ``0`` on success or ``-1`` on error. + + .. versionadded:: 3.12 + +.. c:function:: int PyDict_Unwatch(int watcher_id, PyObject *dict) + + Mark dictionary *dict* as no longer watched. The callback granted + *watcher_id* by :c:func:`PyDict_AddWatcher` will no longer be called when + *dict* is modified or deallocated. The dict must previously have been + watched by this watcher. Return ``0`` on success or ``-1`` on error. + + .. versionadded:: 3.12 .. c:type:: PyDict_WatchEvent @@ -264,6 +279,8 @@ Dictionary Objects ``PyDict_EVENT_MODIFIED``, ``PyDict_EVENT_DELETED``, ``PyDict_EVENT_CLONED``, ``PyDict_EVENT_CLEARED``, or ``PyDict_EVENT_DEALLOCATED``. + .. versionadded:: 3.12 + .. c:type:: int (*PyDict_WatchCallback)(PyDict_WatchEvent event, PyObject *dict, PyObject *key, PyObject *new_value) Type of a dict watcher callback function. @@ -289,3 +306,5 @@ Dictionary Objects If the callback returns with an exception set, it must return ``-1``; this exception will be printed as an unraisable exception using :c:func:`PyErr_WriteUnraisable`. Otherwise it should return ``0``. + + .. versionadded:: 3.12 diff --git a/Include/cpython/dictobject.h b/Include/cpython/dictobject.h index f8a74a597b0ea2a..2dff59ef0b8a6b8 100644 --- a/Include/cpython/dictobject.h +++ b/Include/cpython/dictobject.h @@ -106,3 +106,4 @@ PyAPI_FUNC(int) PyDict_ClearWatcher(int watcher_id); // Mark given dictionary as "watched" (callback will be called if it is modified) PyAPI_FUNC(int) PyDict_Watch(int watcher_id, PyObject* dict); +PyAPI_FUNC(int) PyDict_Unwatch(int watcher_id, PyObject* dict); diff --git a/Lib/test/test_capi.py b/Lib/test/test_capi.py index 19367dfcc1ccbba..17425050ce00c06 100644 --- a/Lib/test/test_capi.py +++ b/Lib/test/test_capi.py @@ -20,6 +20,7 @@ import weakref from test import support from test.support import MISSING_C_DOCSTRINGS +from test.support import catch_unraisable_exception from test.support import import_helper from test.support import threading_helper from test.support import warnings_helper @@ -1421,6 +1422,9 @@ def assert_events(self, expected): def watch(self, wid, d): _testcapi.watch_dict(wid, d) + def unwatch(self, wid, d): + _testcapi.unwatch_dict(wid, d) + def test_set_new_item(self): d = {} with self.watcher() as wid: @@ -1477,27 +1481,24 @@ def test_dealloc(self): del d self.assert_events(["dealloc"]) + def test_unwatch(self): + d = {} + with self.watcher() as wid: + self.watch(wid, d) + d["foo"] = "bar" + self.unwatch(wid, d) + d["hmm"] = "baz" + self.assert_events(["new:foo:bar"]) + def test_error(self): d = {} - unraisables = [] - def unraisable_hook(unraisable): - unraisables.append(unraisable) with self.watcher(kind=self.ERROR) as wid: self.watch(wid, d) - orig_unraisable_hook = sys.unraisablehook - sys.unraisablehook = unraisable_hook - try: + with catch_unraisable_exception() as cm: d["foo"] = "bar" - finally: - sys.unraisablehook = orig_unraisable_hook + self.assertIs(cm.unraisable.object, d) + self.assertEqual(str(cm.unraisable.exc_value), "boom!") self.assert_events([]) - self.assertEqual(len(unraisables), 1) - unraisable = unraisables[0] - self.assertIs(unraisable.object, d) - self.assertEqual(str(unraisable.exc_value), "boom!") - # avoid leaking reference cycles - del unraisable - del unraisables def test_two_watchers(self): d1 = {} @@ -1522,11 +1523,38 @@ def test_watch_out_of_range_watcher_id(self): with self.assertRaisesRegex(ValueError, r"Invalid dict watcher ID 8"): self.watch(8, d) # DICT_MAX_WATCHERS = 8 - def test_unassigned_watcher_id(self): + def test_watch_unassigned_watcher_id(self): d = {} with self.assertRaisesRegex(ValueError, r"No dict watcher set for ID 1"): self.watch(1, d) + def test_unwatch_non_dict(self): + with self.watcher() as wid: + with self.assertRaisesRegex(ValueError, r"Cannot watch non-dictionary"): + self.unwatch(wid, 1) + + def test_unwatch_out_of_range_watcher_id(self): + d = {} + with self.assertRaisesRegex(ValueError, r"Invalid dict watcher ID -1"): + self.unwatch(-1, d) + with self.assertRaisesRegex(ValueError, r"Invalid dict watcher ID 8"): + self.unwatch(8, d) # DICT_MAX_WATCHERS = 8 + + def test_unwatch_unassigned_watcher_id(self): + d = {} + with self.assertRaisesRegex(ValueError, r"No dict watcher set for ID 1"): + self.unwatch(1, d) + + def test_clear_out_of_range_watcher_id(self): + with self.assertRaisesRegex(ValueError, r"Invalid dict watcher ID -1"): + self.clear_watcher(-1) + with self.assertRaisesRegex(ValueError, r"Invalid dict watcher ID 8"): + self.clear_watcher(8) # DICT_MAX_WATCHERS = 8 + + def test_clear_unassigned_watcher_id(self): + with self.assertRaisesRegex(ValueError, r"No dict watcher set for ID 1"): + self.clear_watcher(1) + if __name__ == "__main__": unittest.main() diff --git a/Modules/_testcapimodule.c b/Modules/_testcapimodule.c index 28fb43dce4c6cbc..173d7c2cb805308 100644 --- a/Modules/_testcapimodule.c +++ b/Modules/_testcapimodule.c @@ -5296,6 +5296,20 @@ watch_dict(PyObject *self, PyObject *args) Py_RETURN_NONE; } +static PyObject * +unwatch_dict(PyObject *self, PyObject *args) +{ + PyObject *dict; + int watcher_id; + if (!PyArg_ParseTuple(args, "iO", &watcher_id, &dict)) { + return NULL; + } + if (PyDict_Unwatch(watcher_id, dict)) { + return NULL; + } + Py_RETURN_NONE; +} + static PyObject * get_dict_watcher_events(PyObject *self, PyObject *Py_UNUSED(args)) { @@ -5904,6 +5918,7 @@ static PyMethodDef TestMethods[] = { {"add_dict_watcher", add_dict_watcher, METH_O, NULL}, {"clear_dict_watcher", clear_dict_watcher, METH_O, NULL}, {"watch_dict", watch_dict, METH_VARARGS, NULL}, + {"unwatch_dict", unwatch_dict, METH_VARARGS, NULL}, {"get_dict_watcher_events", get_dict_watcher_events, METH_NOARGS, NULL}, {NULL, NULL} /* sentinel */ }; diff --git a/Objects/dictobject.c b/Objects/dictobject.c index 6542b1803ffa2e8..97007479b1be915 100644 --- a/Objects/dictobject.c +++ b/Objects/dictobject.c @@ -5720,6 +5720,20 @@ uint32_t _PyDictKeys_GetVersionForCurrentState(PyDictKeysObject *dictkeys) return v; } +static inline int +validate_watcher_id(PyInterpreterState *interp, int watcher_id) +{ + if (watcher_id < 0 || watcher_id >= DICT_MAX_WATCHERS) { + PyErr_Format(PyExc_ValueError, "Invalid dict watcher ID %d", watcher_id); + return -1; + } + if (!interp->dict_watchers[watcher_id]) { + PyErr_Format(PyExc_ValueError, "No dict watcher set for ID %d", watcher_id); + return -1; + } + return 0; +} + int PyDict_Watch(int watcher_id, PyObject* dict) { @@ -5727,16 +5741,26 @@ PyDict_Watch(int watcher_id, PyObject* dict) PyErr_SetString(PyExc_ValueError, "Cannot watch non-dictionary"); return -1; } - if (watcher_id < 0 || watcher_id >= DICT_MAX_WATCHERS) { - PyErr_Format(PyExc_ValueError, "Invalid dict watcher ID %d", watcher_id); + PyInterpreterState *interp = _PyInterpreterState_GET(); + if (validate_watcher_id(interp, watcher_id)) { + return -1; + } + ((PyDictObject*)dict)->ma_version_tag |= (1LL << watcher_id); + return 0; +} + +int +PyDict_Unwatch(int watcher_id, PyObject* dict) +{ + if (!PyDict_Check(dict)) { + PyErr_SetString(PyExc_ValueError, "Cannot watch non-dictionary"); return -1; } PyInterpreterState *interp = _PyInterpreterState_GET(); - if (!interp->dict_watchers[watcher_id]) { - PyErr_Format(PyExc_ValueError, "No dict watcher set for ID %d", watcher_id); + if (validate_watcher_id(interp, watcher_id)) { return -1; } - ((PyDictObject*)dict)->ma_version_tag |= (1LL << watcher_id); + ((PyDictObject*)dict)->ma_version_tag &= ~(1LL << watcher_id); return 0; } @@ -5759,13 +5783,8 @@ PyDict_AddWatcher(PyDict_WatchCallback callback) int PyDict_ClearWatcher(int watcher_id) { - if (watcher_id < 0 || watcher_id >= DICT_MAX_WATCHERS) { - PyErr_Format(PyExc_ValueError, "Invalid dict watcher ID %d", watcher_id); - return -1; - } PyInterpreterState *interp = _PyInterpreterState_GET(); - if (!interp->dict_watchers[watcher_id]) { - PyErr_Format(PyExc_ValueError, "No dict watcher set for ID %d", watcher_id); + if (validate_watcher_id(interp, watcher_id)) { return -1; } interp->dict_watchers[watcher_id] = NULL;