Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

bpo-44839: Raise more specific errors in sqlite3 #27613

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 44 additions & 1 deletion Lib/sqlite3/test/userfunctions.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,16 @@

import contextlib
import functools
import gc
import io
import sys
import unittest
import unittest.mock
import gc
import sqlite3 as sqlite

from test.support import bigmemtest


def with_tracebacks(strings):
"""Convenience decorator for testing callback tracebacks."""
strings.append('Traceback')
Expand Down Expand Up @@ -69,6 +73,10 @@ def func_returnlonglong():
return 1<<31
def func_raiseexception():
5/0
def func_memoryerror():
raise MemoryError
def func_overflowerror():
raise OverflowError

def func_isstring(v):
return type(v) is str
Expand Down Expand Up @@ -187,6 +195,8 @@ def setUp(self):
self.con.create_function("returnblob", 0, func_returnblob)
self.con.create_function("returnlonglong", 0, func_returnlonglong)
self.con.create_function("raiseexception", 0, func_raiseexception)
self.con.create_function("memoryerror", 0, func_memoryerror)
self.con.create_function("overflowerror", 0, func_overflowerror)

self.con.create_function("isstring", 1, func_isstring)
self.con.create_function("isint", 1, func_isint)
Expand Down Expand Up @@ -279,6 +289,20 @@ def test_func_exception(self):
cur.fetchone()
self.assertEqual(str(cm.exception), 'user-defined function raised exception')

@with_tracebacks(['func_memoryerror', 'MemoryError'])
def test_func_memory_error(self):
cur = self.con.cursor()
with self.assertRaises(MemoryError):
cur.execute("select memoryerror()")
cur.fetchone()

@with_tracebacks(['func_overflowerror', 'OverflowError'])
def test_func_overflow_error(self):
cur = self.con.cursor()
with self.assertRaises(sqlite.DataError):
cur.execute("select overflowerror()")
cur.fetchone()

def test_param_string(self):
cur = self.con.cursor()
for text in ["foo", str()]:
Expand Down Expand Up @@ -384,6 +408,25 @@ def md5sum(t):
del x,y
gc.collect()

@unittest.skipUnless(sys.maxsize > 2**32, 'requires 64bit platform')
@bigmemtest(size=2**31, memuse=3, dry_run=False)
def test_large_text(self, size):
cur = self.con.cursor()
for size in 2**31-1, 2**31:
self.con.create_function("largetext", 0, lambda size=size: "b" * size)
with self.assertRaises(sqlite.DataError):
cur.execute("select largetext()")

@unittest.skipUnless(sys.maxsize > 2**32, 'requires 64bit platform')
@bigmemtest(size=2**31, memuse=2, dry_run=False)
def test_large_blob(self, size):
cur = self.con.cursor()
for size in 2**31-1, 2**31:
self.con.create_function("largeblob", 0, lambda size=size: b"b" * size)
with self.assertRaises(sqlite.DataError):
cur.execute("select largeblob()")


class AggregateTests(unittest.TestCase):
def setUp(self):
self.con = sqlite.connect(":memory:")
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
:class:`MemoryError` raised in user-defined functions will now produce a
``MemoryError`` in :mod:`sqlite3`. :class:`OverflowError` will now be converted
to :class:`~sqlite3.DataError`. Previously
:class:`~sqlite3.OperationalError` was produced in these cases.
67 changes: 31 additions & 36 deletions Modules/_sqlite/connection.c
Original file line number Diff line number Diff line change
Expand Up @@ -619,6 +619,29 @@ _pysqlite_build_py_params(sqlite3_context *context, int argc,
return NULL;
}

// Checks the Python exception and sets the appropriate SQLite error code.
static void
set_sqlite_error(sqlite3_context *context, const char *msg)
{
assert(PyErr_Occurred());
if (PyErr_ExceptionMatches(PyExc_MemoryError)) {
sqlite3_result_error_nomem(context);
}
else if (PyErr_ExceptionMatches(PyExc_OverflowError)) {
sqlite3_result_error_toobig(context);
}
else {
sqlite3_result_error(context, msg, -1);
}
pysqlite_state *state = pysqlite_get_state(NULL);
if (state->enable_callback_tracebacks) {
PyErr_Print();
}
else {
PyErr_Clear();
}
}

static void
_pysqlite_func_callback(sqlite3_context *context, int argc, sqlite3_value **argv)
{
Expand All @@ -645,14 +668,7 @@ _pysqlite_func_callback(sqlite3_context *context, int argc, sqlite3_value **argv
Py_DECREF(py_retval);
}
if (!ok) {
pysqlite_state *state = pysqlite_get_state(NULL);
if (state->enable_callback_tracebacks) {
PyErr_Print();
}
else {
PyErr_Clear();
}
sqlite3_result_error(context, "user-defined function raised exception", -1);
set_sqlite_error(context, "user-defined function raised exception");
}

PyGILState_Release(threadstate);
Expand All @@ -676,18 +692,9 @@ static void _pysqlite_step_callback(sqlite3_context *context, int argc, sqlite3_

if (*aggregate_instance == NULL) {
*aggregate_instance = _PyObject_CallNoArg(aggregate_class);

if (PyErr_Occurred()) {
*aggregate_instance = 0;

pysqlite_state *state = pysqlite_get_state(NULL);
if (state->enable_callback_tracebacks) {
PyErr_Print();
}
else {
PyErr_Clear();
}
sqlite3_result_error(context, "user-defined aggregate's '__init__' method raised error", -1);
if (!*aggregate_instance) {
set_sqlite_error(context,
"user-defined aggregate's '__init__' method raised error");
goto error;
}
}
Expand All @@ -706,14 +713,8 @@ static void _pysqlite_step_callback(sqlite3_context *context, int argc, sqlite3_
Py_DECREF(args);

if (!function_result) {
pysqlite_state *state = pysqlite_get_state(NULL);
if (state->enable_callback_tracebacks) {
PyErr_Print();
}
else {
PyErr_Clear();
}
sqlite3_result_error(context, "user-defined aggregate's 'step' method raised error", -1);
set_sqlite_error(context,
"user-defined aggregate's 'step' method raised error");
}

error:
Expand Down Expand Up @@ -761,14 +762,8 @@ _pysqlite_final_callback(sqlite3_context *context)
Py_DECREF(function_result);
}
if (!ok) {
pysqlite_state *state = pysqlite_get_state(NULL);
if (state->enable_callback_tracebacks) {
PyErr_Print();
}
else {
PyErr_Clear();
}
sqlite3_result_error(context, "user-defined aggregate's 'finalize' method raised error", -1);
set_sqlite_error(context,
"user-defined aggregate's 'finalize' method raised error");
}

/* Restore the exception (if any) of the last call to step(),
Expand Down