import contextlib
import functools
import io
import re
import sqlite3
import test.support
# Helper for temporary memory databases
def memory_database(*args, **kwargs):
    """Return a closing context manager around an in-memory connection.

    Any positional and keyword arguments are forwarded to
    ``sqlite3.connect`` after the ``":memory:"`` database path. Used in a
    ``with`` statement, the connection is bound to the ``as`` target and
    closed when the block exits.
    """
    return contextlib.closing(sqlite3.connect(":memory:", *args, **kwargs))
# Temporarily limit a database connection parameter
@contextlib.contextmanager
def cx_limit(cx, category=sqlite3.SQLITE_LIMIT_SQL_LENGTH, limit=128):
try:
_prev = cx.setlimit(category, limit)
yield limit
finally:
cx.setlimit(category, _prev)
def with_tracebacks(exc, regex="", name=""):
    """Build a decorator that checks callback tracebacks for a test method.

    The decorated method is executed twice: once inside ``check_tracebacks``
    while unraisable exceptions are being captured, and then once more on
    its own. ``exc``, the compiled form of ``regex`` (or ``None`` when
    ``regex`` is empty) and ``name`` are handed to ``check_tracebacks``.
    """
    def decorator(func):
        # Compile once per decorated function; an empty pattern is passed
        # on as None.
        pattern = None
        if regex:
            pattern = re.compile(regex)

        @functools.wraps(func)
        def run_twice(self, *args, **kwargs):
            # First, run the test with traceback enabled.
            with test.support.catch_unraisable_exception() as caught, \
                    check_tracebacks(self, caught, exc, pattern, name):
                func(self, *args, **kwargs)

            # Then run the test with traceback disabled.
            func(self, *args, **kwargs)

        return run_twice

    return decorator
@contextlib.contextmanager
def check_tracebacks(self, cm, exc, regex, obj_name):
    """Context manager that verifies the unraisable exception held by *cm*.

    Callback tracebacks are switched on for the duration of the managed
    block and switched off again afterwards, even on error. While the block
    runs, ``sys.stderr`` is redirected into a throwaway buffer. Once the
    block finishes normally, ``cm.unraisable`` is checked through the
    assertion methods of ``self``:

    * its ``exc_type`` must equal ``exc``;
    * if ``regex`` is truthy, it must match ``str(exc_value)`` via ``search``;
    * if ``obj_name`` is truthy, ``object.__name__`` must equal it.
    """
    sqlite3.enable_callback_tracebacks(True)
    try:
        # The captured stderr text is never inspected; it is only kept away
        # from the real stream.
        with contextlib.redirect_stderr(io.StringIO()):
            yield

        info = cm.unraisable
        self.assertEqual(info.exc_type, exc)
        if regex:
            found = regex.search(str(info.exc_value))
            self.assertIsNotNone(found)
        if obj_name:
            self.assertEqual(info.object.__name__, obj_name)
    finally:
        sqlite3.enable_callback_tracebacks(False)
class MemoryDatabaseMixin:
    """Mixin giving each test its own in-memory connection and cursor.

    ``setUp`` stores the connection as ``self.con`` and a cursor created
    from it as ``self.cur``; ``cx`` and ``cu`` are read-only aliases for
    those two attributes.
    """

    def setUp(self):
        """Open a new ``":memory:"`` connection and a cursor on it."""
        self.con = sqlite3.connect(":memory:")
        self.cur = self.con.cursor()

    def tearDown(self):
        """Close the cursor and then the connection."""
        # Close the connection even if closing the cursor raises, so a
        # failing cursor cannot leak the connection. The cursor's exception
        # still propagates to the caller.
        try:
            self.cur.close()
        finally:
            self.con.close()

    @property
    def cx(self):
        """Alias for ``self.con``."""
        return self.con

    @property
    def cu(self):
        """Alias for ``self.cur``."""
        return self.cur