# Adapted with permission from the EdgeDB project;
# license: PSFL.
import sys
import gc
import asyncio
import contextvars
import contextlib
from asyncio import taskgroups
import unittest
import warnings
from test.test_asyncio.utils import await_without_task
def tearDownModule():
    """Reset the global asyncio event loop policy after this module's tests.

    Without the reset, the test runner warns that the module
    "altered the execution environment".
    """
    asyncio.set_event_loop_policy(policy=None)
class MyExc(Exception):
    """Plain ``Exception`` subclass that tests raise to tag a specific failure."""
class MyBaseExc(BaseException):
    """``BaseException`` subclass that is deliberately *not* an ``Exception``."""
def get_error_types(eg):
return {type(exc) for exc in eg.exceptions}
def no_other_refs():
# due to gh-124392 coroutines now refer to their locals
coro = asyncio.current_task().get_coro()
frame = sys._getframe(1)
while coro.cr_frame != frame:
coro = coro.cr_await
return [coro]
class TestTaskGroup(unittest.IsolatedAsyncioTestCase):
async def test_taskgroup_01(self):
async def foo1():
await asyncio.sleep(0.1)
return 42
async def foo2():
await asyncio.sleep(0.2)
return 11
async with taskgroups.TaskGroup() as g:
t1 = g.create_task(foo1())
t2 = g.create_task(foo2())
self.assertEqual(t1.result(), 42)
self.assertEqual(t2.result(), 11)
async def test_taskgroup_02(self):
async def foo1():
await asyncio.sleep(0.1)
return 42
async def foo2():
await asyncio.sleep(0.2)
return 11
async with taskgroups.TaskGroup() as g:
t1 = g.create_task(foo1())
await asyncio.sleep(0.15)
t2 = g.create_task(foo2())
self.assertEqual(t1.result(), 42)
self.assertEqual(t2.result(), 11)
async def test_taskgroup_03(self):
async def foo1():
await asyncio.sleep(1)
return 42
async def foo2():
await asyncio.sleep(0.2)
return 11
async with taskgroups.TaskGroup() as g:
t1 = g.create_task(foo1())
await asyncio.sleep(0.15)
# cancel t1 explicitly, i.e. everything should continue
# working as expected.
t1.cancel()
t2 = g.create_task(foo2())
self.assertTrue(t1.cancelled())
self.assertEqual(t2.result(), 11)
async def test_taskgroup_04(self):
NUM = 0
t2_cancel = False
t2 = None
async def foo1():
await asyncio.sleep(0.1)
1 / 0
async def foo2():
nonlocal NUM, t2_cancel
try:
await asyncio.sleep(1)
except asyncio.CancelledError:
t2_cancel = True
raise
NUM += 1
async def runner():
nonlocal NUM, t2
async with taskgroups.TaskGroup() as g:
g.create_task(foo1())
t2 = g.create_task(foo2())
NUM += 10
with self.assertRaises(ExceptionGroup) as cm:
await asyncio.create_task(runner())
self.assertEqual(get_error_types(cm.exception), {ZeroDivisionError})
self.assertEqual(NUM, 0)
self.assertTrue(t2_cancel)
self.assertTrue(t2.cancelled())
async def test_cancel_children_on_child_error(self):
# When a child task raises an error, the rest of the children
# are cancelled and the errors are gathered into an EG.
NUM = 0
t2_cancel = False
runner_cancel = False
async def foo1():
await asyncio.sleep(0.1)
1 / 0
async def foo2():
nonlocal NUM, t2_cancel
try:
await asyncio.sleep(5)
except asyncio.CancelledError:
t2_cancel = True
raise
NUM += 1
async def runner():
nonlocal NUM, runner_cancel
async with taskgroups.TaskGroup() as g:
g.create_task(foo1())
g.create_task(foo1())
g.create_task(foo1())
g.create_task(foo2())
try:
await asyncio.sleep(10)
except asyncio.CancelledError:
runner_cancel = True
raise
NUM += 10
# The 3 foo1 sub tasks can be racy when the host is busy - if the
# cancellation happens in the middle, we'll see partial sub errors here
with self.assertRaises(ExceptionGroup) as cm:
await asyncio.create_task(runner())
self.assertEqual(get_error_types(cm.exception), {ZeroDivisionError})
self.assertEqual(NUM, 0)
self.assertTrue(t2_cancel)
self.assertTrue(runner_cancel)
async def test_cancellation(self):
NUM = 0
async def foo():
nonlocal NUM
try:
await asyncio.sleep(5)
except asyncio.CancelledError:
NUM += 1
raise
async def runner():
async with taskgroups.TaskGroup() as g:
for _ in range(5):
g.create_task(foo())
r = asyncio.create_task(runner())
await asyncio.sleep(0.1)
self.assertFalse(r.done())
r.cancel()
with self.assertRaises(asyncio.CancelledError) as cm:
await r
self.assertEqual(NUM, 5)
async def test_taskgroup_07(self):
NUM = 0
async def foo():
nonlocal NUM
try:
await asyncio.sleep(5)
except asyncio.CancelledError:
NUM += 1
raise
async def runner():
nonlocal NUM
async with taskgroups.TaskGroup() as g:
for _ in range(5):
g.create_task(foo())
try:
await asyncio.sleep(10)
except asyncio.CancelledError:
NUM += 10
raise
r = asyncio.create_task(runner())
await asyncio.sleep(0.1)
self.assertFalse(r.done())
r.cancel()
with self.assertRaises(asyncio.CancelledError):
await r
self.assertEqual(NUM, 15)
async def test_taskgroup_08(self):
async def foo():
try:
await asyncio.sleep(10)
finally:
1 / 0
async def runner():
async with taskgroups.TaskGroup() as g:
for _ in range(5):
g.create_task(foo())
await asyncio.sleep(10)
r = asyncio.create_task(runner())
await asyncio.sleep(0.1)
self.assertFalse(r.done())
r.cancel()
with self.assertRaises(ExceptionGroup) as cm:
await r
self.assertEqual(get_error_types(cm.exception), {ZeroDivisionError})
async def test_taskgroup_09(self):
t1 = t2 = None
async def foo1():
await asyncio.sleep(1)
return 42
async def foo2():
await asyncio.sleep(2)
return 11
async def runner():
nonlocal t1, t2
async with taskgroups.TaskGroup() as g:
t1 = g.create_task(foo1())
t2 = g.create_task(foo2())
await asyncio.sleep(0.1)
1 / 0
try:
await runner()
except ExceptionGroup as t:
self.assertEqual(get_error_types(t), {ZeroDivisionError})
else:
self.fail('ExceptionGroup was not raised')
self.assertTrue(t1.cancelled())
self.assertTrue(t2.cancelled())
async def test_taskgroup_10(self):
t1 = t2 = None
async def foo1():
await asyncio.sleep(1)
return 42
async def foo2():
await asyncio.sleep(2)
return 11
async def runner():
nonlocal t1, t2
async with taskgroups.TaskGroup() as g:
t1 = g.create_task(foo1())
t2 = g.create_task(foo2())
1 / 0
try:
await runner()
except ExceptionGroup as t:
self.assertEqual(get_error_types(t), {ZeroDivisionError})
else:
self.fail('ExceptionGroup was not raised')
self.assertTrue(t1.cancelled())
self.assertTrue(t2.cancelled())
async def test_taskgroup_11(self):
async def foo():
try:
await asyncio.sleep(10)
finally:
1 / 0
async def runner():
async with taskgroups.TaskGroup():
async with taskgroups.TaskGroup() as g2:
for _ in range(5):
g2.create_task(foo())
await asyncio.sleep(10)
r = asyncio.create_task(runner())
await asyncio.sleep(0.1)
self.assertFalse(r.done())
r.cancel()
with self.assertRaises(ExceptionGroup) as cm:
await r
self.assertEqual(get_error_types(cm.exception), {ExceptionGroup})
self.assertEqual(get_error_types(cm.exception.exceptions[0]), {ZeroDivisionError})
async def test_taskgroup_12(self):
async def foo():
try:
await asyncio.sleep(10)
finally:
1 / 0
async def runner():
async with taskgroups.TaskGroup() as g1:
g1.create_task(asyncio.sleep(10))
async with taskgroups.TaskGroup() as g2:
for _ in range(5):
g2.create_task(foo())
await asyncio.sleep(10)
r = asyncio.create_task(runner())
await asyncio.sleep(0.1)
self.assertFalse(r.done())
r.cancel()
with self.assertRaises(ExceptionGroup) as cm:
await r
self.assertEqual(get_error_types(cm.exception), {ExceptionGroup})
self.assertEqual(get_error_types(cm.exception.exceptions[0]), {ZeroDivisionError})
async def test_taskgroup_13(self):
async def crash_after(t):
await asyncio.sleep(t)
raise ValueError(t)
async def runner():
async with taskgroups.TaskGroup() as g1:
g1.create_task(crash_after(0.1))
async with taskgroups.TaskGroup() as g2:
g2.create_task(crash_after(10))
r = asyncio.create_task(runner())
with self.assertRaises(ExceptionGroup) as cm:
await r
self.assertEqual(get_error_types(cm.exception), {ValueError})
async def test_taskgroup_14(self):
async def crash_after(t):
await asyncio.sleep(t)
raise ValueError(t)
async def runner():
async with taskgroups.TaskGroup() as g1:
g1.create_task(crash_after(10))
async with taskgroups.TaskGroup() as g2:
g2.create_task(crash_after(0.1))
r = asyncio.create_task(runner())
with self.assertRaises(ExceptionGroup) as cm:
await r
self.assertEqual(get_error_types(cm.exception), {ExceptionGroup})
self.assertEqual(get_error_types(cm.exception.exceptions[0]), {ValueError})
async def test_taskgroup_15(self):
async def crash_soon():
await asyncio.sleep(0.3)
1 / 0
async def runner():
async with taskgroups.TaskGroup() as g1:
g1.create_task(crash_soon())
try:
await asyncio.sleep(10)
except asyncio.CancelledError:
await asyncio.sleep(0.5)
raise
r = asyncio.create_task(runner())
await asyncio.sleep(0.1)
self.assertFalse(r.done())
r.cancel()
with self.assertRaises(ExceptionGroup) as cm:
await r
self.assertEqual(get_error_types(cm.exception), {ZeroDivisionError})
async def test_taskgroup_16(self):
async def crash_soon():
await asyncio.sleep(0.3)
1 / 0
async def nested_runner():
async with taskgroups.TaskGroup() as g1:
g1.create_task(crash_soon())
try:
await asyncio.sleep(10)
except asyncio.CancelledError:
await asyncio.sleep(0.5)
raise
async def runner():
t = asyncio.create_task(nested_runner())
await t
r = asyncio.create_task(runner())
await asyncio.sleep(0.1)
self.assertFalse(r.done())
r.cancel()
with self.assertRaises(ExceptionGroup) as cm:
await r
self.assertEqual(get_error_types(cm.exception), {ZeroDivisionError})
async def test_taskgroup_17(self):
NUM = 0
async def runner():
nonlocal NUM
async with taskgroups.TaskGroup():
try:
await asyncio.sleep(10)
except asyncio.CancelledError:
NUM += 10
raise
r = asyncio.create_task(runner())
await asyncio.sleep(0.1)
self.assertFalse(r.done())
r.cancel()
with self.assertRaises(asyncio.CancelledError):
await r
self.assertEqual(NUM, 10)
async def test_taskgroup_18(self):
NUM = 0
async def runner():
nonlocal NUM
async with taskgroups.TaskGroup():
try:
await asyncio.sleep(10)
except asyncio.CancelledError:
NUM += 10
# This isn't a good idea, but we have to support
# this weird case.
raise MyExc
r = asyncio.create_task(runner())
await asyncio.sleep(0.1)
self.assertFalse(r.done())
r.cancel()
try:
await r
except ExceptionGroup as t:
self.assertEqual(get_error_types(t),{MyExc})
else:
self.fail('ExceptionGroup was not raised')
self.assertEqual(NUM, 10)
async def test_taskgroup_19(self):
async def crash_soon():
await asyncio.sleep(0.1)
1 / 0
async def nested():
try:
await asyncio.sleep(10)
finally:
raise MyExc
async def runner():
async with taskgroups.TaskGroup() as g:
g.create_task(crash_soon())
await nested()
r = asyncio.create_task(runner())
try:
await r
except ExceptionGroup as t:
self.assertEqual(get_error_types(t), {MyExc, ZeroDivisionError})
else:
self.fail('TasgGroupError was not raised')
async def test_taskgroup_20(self):
async def crash_soon():
await asyncio.sleep(0.1)
1 / 0
async def nested():
try:
await asyncio.sleep(10)
finally:
raise KeyboardInterrupt
async def runner():
async with taskgroups.TaskGroup() as g:
g.create_task(crash_soon())
await nested()
with self.assertRaises(KeyboardInterrupt):
await runner()
async def test_taskgroup_20a(self):
async def crash_soon():
await asyncio.sleep(0.1)
1 / 0
async def nested():
try:
await asyncio.sleep(10)
finally:
raise MyBaseExc
async def runner():
async with taskgroups.TaskGroup() as g:
g.create_task(crash_soon())
await nested()
with self.assertRaises(BaseExceptionGroup) as cm:
await runner()
self.assertEqual(
get_error_types(cm.exception), {MyBaseExc, ZeroDivisionError}
)
async def _test_taskgroup_21(self):
# This test doesn't work as asyncio, currently, doesn't
# correctly propagate KeyboardInterrupt (or SystemExit) --
# those cause the event loop itself to crash.
# (Compare to the previous (passing) test -- that one raises
# a plain exception but raises KeyboardInterrupt in nested();
# this test does it the other way around.)
async def crash_soon():
await asyncio.sleep(0.1)
raise KeyboardInterrupt
async def nested():
try:
await asyncio.sleep(10)
finally:
raise TypeError
async def runner():
async with taskgroups.TaskGroup() as g:
g.create_task(crash_soon())
await nested()
with self.assertRaises(KeyboardInterrupt):
await runner()
async def test_taskgroup_21a(self):
async def crash_soon():
await asyncio.sleep(0.1)
raise MyBaseExc
async def nested():
try:
await asyncio.sleep(10)
finally:
raise TypeError
async def runner():
async with taskgroups.TaskGroup() as g:
g.create_task(crash_soon())
await nested()
with self.assertRaises(BaseExceptionGroup) as cm:
await runner()
self.assertEqual(get_error_types(cm.exception), {MyBaseExc, TypeError})
async def test_taskgroup_22(self):
async def foo1():
await asyncio.sleep(1)
return 42
async def foo2():
await asyncio.sleep(2)
return 11
async def runner():
async with taskgroups.TaskGroup() as g:
g.create_task(foo1())
g.create_task(foo2())
r = asyncio.create_task(runner())
await asyncio.sleep(0.05)
r.cancel()
with self.assertRaises(asyncio.CancelledError):
await r
async def test_taskgroup_23(self):
async def do_job(delay):
await asyncio.sleep(delay)
async with taskgroups.TaskGroup() as g:
for count in range(10):
await asyncio.sleep(0.1)
g.create_task(do_job(0.3))
if count == 5:
self.assertLess(len(g._tasks), 5)
await asyncio.sleep(1.35)
self.assertEqual(len(g._tasks), 0)
async def test_taskgroup_24(self):
async def root(g):
await asyncio.sleep(0.1)
g.create_task(coro1(0.1))
g.create_task(coro1(0.2))
async def coro1(delay):
await asyncio.sleep(delay)
async def runner():
async with taskgroups.TaskGroup() as g:
g.create_task(root(g))
await runner()
async def test_taskgroup_25(self):
nhydras = 0
async def hydra(g):
nonlocal nhydras
nhydras += 1
await asyncio.sleep(0.01)
g.create_task(hydra(g))
g.create_task(hydra(g))
async def hercules():
while nhydras < 10:
await asyncio.sleep(0.015)
1 / 0
async def runner():
async with taskgroups.TaskGroup() as g:
g.create_task(hydra(g))
g.create_task(hercules())
with self.assertRaises(ExceptionGroup) as cm:
await runner()
self.assertEqual(get_error_types(cm.exception), {ZeroDivisionError})
self.assertGreaterEqual(nhydras, 10)
async def test_taskgroup_task_name(self):
async def coro():
await asyncio.sleep(0)
async with taskgroups.TaskGroup() as g:
t = g.create_task(coro(), name="yolo")
self.assertEqual(t.get_name(), "yolo")
async def test_taskgroup_task_context(self):
cvar = contextvars.ContextVar('cvar')
async def coro(val):
await asyncio.sleep(0)
cvar.set(val)
async with taskgroups.TaskGroup() as g:
ctx = contextvars.copy_context()
self.assertIsNone(ctx.get(cvar))
t1 = g.create_task(coro(1), context=ctx)
await t1
self.assertEqual(1, ctx.get(cvar))
t2 = g.create_task(coro(2), context=ctx)
await t2
self.assertEqual(2, ctx.get(cvar))
async def test_taskgroup_no_create_task_after_failure(self):
async def coro1():
await asyncio.sleep(0.001)
1 / 0
async def coro2(g):
try:
await asyncio.sleep(1)
except asyncio.CancelledError:
with self.assertRaises(RuntimeError):
g.create_task(coro1())
with self.assertRaises(ExceptionGroup) as cm:
async with taskgroups.TaskGroup() as g:
g.create_task(coro1())
g.create_task(coro2(g))
self.assertEqual(get_error_types(cm.exception), {ZeroDivisionError})
async def test_taskgroup_context_manager_exit_raises(self):
# See https://github.com/python/cpython/issues/95289
class CustomException(Exception):
pass
async def raise_exc():
raise CustomException
@contextlib.asynccontextmanager
async def database():
try:
yield
finally:
raise CustomException
async def main():
task = asyncio.current_task()
try:
async with taskgroups.TaskGroup() as tg:
async with database():
tg.create_task(raise_exc())
await asyncio.sleep(1)
except* CustomException as err:
self.assertEqual(task.cancelling(), 0)
self.assertEqual(len(err.exceptions), 2)
else:
self.fail('CustomException not raised')
await asyncio.create_task(main())
async def test_taskgroup_already_entered(self):
tg = taskgroups.TaskGroup()
async with tg:
with self.assertRaisesRegex(RuntimeError, "has already been entered"):
async with tg:
pass
async def test_taskgroup_double_enter(self):
tg = taskgroups.TaskGroup()
async with tg:
pass
with self.assertRaisesRegex(RuntimeError, "has already been entered"):
async with tg:
pass
async def test_taskgroup_finished(self):
async def create_task_after_tg_finish():
tg = taskgroups.TaskGroup()
async with tg:
pass
coro = asyncio.sleep(0)
with self.assertRaisesRegex(RuntimeError, "is finished"):
tg.create_task(coro)
# Make sure the coroutine was closed when submitted to the inactive tg
# (if not closed, a RuntimeWarning should have been raised)
with warnings.catch_warnings(record=True) as w:
await create_task_after_tg_finish()
self.assertEqual(len(w), 0)
async def test_taskgroup_not_entered(self):
tg = taskgroups.TaskGroup()
coro = asyncio.sleep(0)
with self.assertRaisesRegex(RuntimeError, "has not been entered"):
tg.create_task(coro)
async def test_taskgroup_without_parent_task(self):
tg = taskgroups.TaskGroup()
with self.assertRaisesRegex(RuntimeError, "parent task"):
await await_without_task(tg.__aenter__())
coro = asyncio.sleep(0)
with self.assertRaisesRegex(RuntimeError, "has not been entered"):
tg.create_task(coro)
def test_coro_closed_when_tg_closed(self):
async def run_coro_after_tg_closes():
async with taskgroups.TaskGroup() as tg:
pass
coro = asyncio.sleep(0)
with self.assertRaisesRegex(RuntimeError, "is finished"):
tg.create_task(coro)
loop = asyncio.get_event_loop()
loop.run_until_complete(run_coro_after_tg_closes())
async def test_cancelling_level_preserved(self):
async def raise_after(t, e):
await asyncio.sleep(t)
raise e()
try:
async with asyncio.TaskGroup() as tg:
tg.create_task(raise_after(0.0, RuntimeError))
except* RuntimeError:
pass
self.assertEqual(asyncio.current_task().cancelling(), 0)
async def test_nested_groups_both_cancelled(self):
async def raise_after(t, e):
await asyncio.sleep(t)
raise e()
try:
async with asyncio.TaskGroup() as outer_tg:
try:
async with asyncio.TaskGroup() as inner_tg:
inner_tg.create_task(raise_after(0, RuntimeError))
outer_tg.create_task(raise_after(0, ValueError))
except* RuntimeError:
pass
else:
self.fail("RuntimeError not raised")
self.assertEqual(asyncio.current_task().cancelling(), 1)
except* ValueError:
pass
else:
self.fail("ValueError not raised")
self.assertEqual(asyncio.current_task().cancelling(), 0)
async def test_error_and_cancel(self):
event = asyncio.Event()
async def raise_error():
event.set()
await asyncio.sleep(0)
raise RuntimeError()
async def inner():
try:
async with taskgroups.TaskGroup() as tg:
tg.create_task(raise_error())
await asyncio.sleep(1)
self.fail("Sleep in group should have been cancelled")
except* RuntimeError:
self.assertEqual(asyncio.current_task().cancelling(), 1)
self.assertEqual(asyncio.current_task().cancelling(), 1)
await asyncio.sleep(1)
self.fail("Sleep after group should have been cancelled")
async def outer():
t = asyncio.create_task(inner())
await event.wait()
self.assertEqual(t.cancelling(), 0)
t.cancel()
self.assertEqual(t.cancelling(), 1)
with self.assertRaises(asyncio.CancelledError):
await t
self.assertTrue(t.cancelled())
await outer()
async def test_exception_refcycles_direct(self):
"""Test that TaskGroup doesn't keep a reference to the raised ExceptionGroup"""
tg = asyncio.TaskGroup()
exc = None
class _Done(Exception):
pass
try:
async with tg:
raise _Done
except ExceptionGroup as e:
exc = e
self.assertIsNotNone(exc)
self.assertListEqual(gc.get_referrers(exc), no_other_refs())
async def test_exception_refcycles_errors(self):
"""Test that TaskGroup deletes self._errors, and __aexit__ args"""
tg = asyncio.TaskGroup()
exc = None
class _Done(Exception):
pass
try:
async with tg:
raise _Done
except* _Done as excs:
exc = excs.exceptions[0]
self.assertIsInstance(exc, _Done)
self.assertListEqual(gc.get_referrers(exc), no_other_refs())
async def test_exception_refcycles_parent_task(self):
"""Test that TaskGroup deletes self._parent_task"""
tg = asyncio.TaskGroup()
exc = None
class _Done(Exception):
pass
async def coro_fn():
async with tg:
raise _Done
try:
async with asyncio.TaskGroup() as tg2:
tg2.create_task(coro_fn())
except* _Done as excs:
exc = excs.exceptions[0].exceptions[0]
self.assertIsInstance(exc, _Done)
self.assertListEqual(gc.get_referrers(exc), no_other_refs())
async def test_exception_refcycles_propagate_cancellation_error(self):
"""Test that TaskGroup deletes propagate_cancellation_error"""
tg = asyncio.TaskGroup()
exc = None
try:
async with asyncio.timeout(-1):
async with tg:
await asyncio.sleep(0)
except TimeoutError as e:
exc = e.__cause__
self.assertIsInstance(exc, asyncio.CancelledError)
self.assertListEqual(gc.get_referrers(exc), no_other_refs())
async def test_exception_refcycles_base_error(self):
"""Test that TaskGroup deletes self._base_error"""
class MyKeyboardInterrupt(KeyboardInterrupt):
pass
tg = asyncio.TaskGroup()
exc = None
try:
async with tg:
raise MyKeyboardInterrupt
except MyKeyboardInterrupt as e:
exc = e
self.assertIsNotNone(exc)
self.assertListEqual(gc.get_referrers(exc), no_other_refs())
if __name__ == "__main__":
    # Allow running this test module directly as a script.
    unittest.main(module="__main__")