From 2ad433104978729e9b733681869feb902fd7e591 Mon Sep 17 00:00:00 2001 From: Hajime Senuma Date: Tue, 17 Sep 2024 11:48:59 +0900 Subject: [PATCH] Add value range check for seed --- src/mmh3/mmh3module.c | 12 ++++++++++-- tests/test_invalid_inputs.py | 24 ++++++++++++++++++++++++ tox.ini | 2 +- 3 files changed, 35 insertions(+), 3 deletions(-) create mode 100644 tests/test_invalid_inputs.py diff --git a/src/mmh3/mmh3module.c b/src/mmh3/mmh3module.c index 53ecda8..7c7dc39 100644 --- a/src/mmh3/mmh3module.c +++ b/src/mmh3/mmh3module.c @@ -51,10 +51,18 @@ typedef unsigned __int64 uint64_t; Py_TYPE(args[1])->tp_name); \ return NULL; \ } \ - seed = (uint32_t)PyLong_AsUnsignedLong(args[1]); \ - if (seed == (unsigned long)-1 && PyErr_Occurred()) { \ + const unsigned long seed_tmp = PyLong_AsUnsignedLong(args[1]); \ + if (seed_tmp == -1 && PyErr_Occurred()) { \ + if (PyErr_ExceptionMatches(PyExc_OverflowError)) { \ + PyErr_SetString(PyExc_ValueError, "seed is out of range"); \ + return NULL; \ + } \ + } \ + if (seed_tmp > 0xFFFFFFFF) { \ + PyErr_SetString(PyExc_ValueError, "seed is out of range"); \ return NULL; \ } \ + seed = (uint32_t)seed_tmp; \ } //----------------------------------------------------------------------------- diff --git a/tests/test_invalid_inputs.py b/tests/test_invalid_inputs.py new file mode 100644 index 0000000..9e802d9 --- /dev/null +++ b/tests/test_invalid_inputs.py @@ -0,0 +1,24 @@ +# pylint: disable=missing-module-docstring, missing-function-docstring +# pylint: disable=no-value-for-parameter, too-many-function-args +import mmh3 +import pytest + + +def test_mmh3_32_digest_raises_typeerror() -> None: + with pytest.raises(TypeError): + mmh3.mmh3_32_digest() + with pytest.raises(TypeError): + mmh3.mmh3_32_digest(b"hello, world", 42, 1234) + with pytest.raises(TypeError): + mmh3.mmh3_32_digest("hello, world") + with pytest.raises(TypeError): + mmh3.mmh3_32_digest(b"hello, world", "42") + with pytest.raises(TypeError): + mmh3.mmh3_32_digest([1, 2, 3], 42) + + +def test_mmh3_32_digest_raises_valueerror() -> None: + with pytest.raises(ValueError): + mmh3.mmh3_32_digest(b"hello, world", -1) + with pytest.raises(ValueError): + mmh3.mmh3_32_digest(b"hello, world", 2**32) diff --git a/tox.ini b/tox.ini index 13066fd..2f64ec3 100644 --- a/tox.ini +++ b/tox.ini @@ -8,7 +8,7 @@ description = run unit tests commands_pre = pip install ".[test]" commands = - pytest + pytest {posargs} [testenv:lint] description = run linters with formatting