# cpython/Lib/test/test_asyncio/test_pep492.py

"""Tests support for new syntax introduced by PEP 492."""

import sys
import types
import unittest

from unittest import mock

import asyncio
from test.test_asyncio import utils as test_utils


def tearDownModule():
    """Module-level unittest hook: hand ``None`` to the policy setter.

    Runs once after the tests of this module, so that whatever event loop
    policy state the tests left behind does not leak into other modules.
    """
    policy = None
    asyncio.set_event_loop_policy(policy)


# Test that asyncio.iscoroutine() uses collections.abc.Coroutine
class FakeCoro:
    """Minimal object exposing the coroutine protocol methods.

    It is not created by ``async def``; it only provides ``send``,
    ``throw``, ``close`` and ``__await__`` so that protocol-based
    coroutine checks can be exercised against it.
    """

    def send(self, value):
        """Accept *value* and do nothing."""

    def throw(self, typ, val=None, tb=None):
        """Accept the exception arguments and do nothing."""

    def close(self):
        """Do nothing."""

    def __await__(self):
        """Return a generator that yields ``None`` once."""
        yield None


class BaseTest(test_utils.TestCase):
    """Common fixture: a ``BaseEventLoop`` whose selector machinery is mocked."""

    def setUp(self):
        super().setUp()

        # Build and configure the loop first, then publish it as self.loop.
        loop = asyncio.BaseEventLoop()
        loop._process_events = mock.Mock()

        # The mocked selector always reports that no events are ready.
        selector = mock.Mock()
        selector.select.return_value = ()
        loop._selector = selector

        self.loop = loop
        self.set_event_loop(loop)


class LockTests(BaseTest):
    """Context-manager behaviour of the asyncio synchronization primitives."""

    def test_context_manager_async_with(self):
        # Each primitive must be usable with ``async with``: locked inside
        # the block, unlocked before and after it.
        factories = (
            asyncio.Lock,
            asyncio.Condition,
            asyncio.Semaphore,
            asyncio.BoundedSemaphore,
        )

        async def check(primitive):
            await asyncio.sleep(0.01)
            self.assertFalse(primitive.locked())
            async with primitive as bound:
                # ``async with ... as`` is expected to bind None.
                self.assertIsNone(bound)
                self.assertTrue(primitive.locked())
                await asyncio.sleep(0.01)
                self.assertTrue(primitive.locked())
            self.assertFalse(primitive.locked())

        # Instantiate every primitive up front, then exercise them in turn.
        for primitive in [make() for make in factories]:
            self.loop.run_until_complete(check(primitive))
            self.assertFalse(primitive.locked())

    def test_context_manager_with_await(self):
        # The ``with await lock`` form must be rejected with a TypeError
        # and must leave the primitive unlocked.
        factories = (
            asyncio.Lock,
            asyncio.Condition,
            asyncio.Semaphore,
            asyncio.BoundedSemaphore,
        )

        async def check(primitive):
            await asyncio.sleep(0.01)
            self.assertFalse(primitive.locked())
            with self.assertRaisesRegex(TypeError, "can't be awaited"):
                with await primitive:
                    pass

        for primitive in [make() for make in factories]:
            self.loop.run_until_complete(check(primitive))
            self.assertFalse(primitive.locked())


class StreamReaderTests(BaseTest):
    """``async for`` iteration over an ``asyncio.StreamReader``."""

    def test_readline(self):
        # Three lines, the last one without a trailing newline.
        payload = b'line1\nline2\nline3'

        reader = asyncio.StreamReader(loop=self.loop)
        reader.feed_data(payload)
        reader.feed_eof()

        async def collect_lines():
            # Drain the reader line by line until it is exhausted.
            return [line async for line in reader]

        lines = self.loop.run_until_complete(collect_lines())
        self.assertEqual(lines, [b'line1\n', b'line2\n', b'line3'])


class CoroutineTests(BaseTest):
    """Coroutine detection helpers and running ``async def`` code on the loop."""

    def test_iscoroutine(self):
        async def foo(): pass

        f = foo()
        try:
            self.assertTrue(asyncio.iscoroutine(f))
        finally:
            f.close() # silence warning

        # An object that merely implements the coroutine protocol methods
        # must be recognized as well.
        self.assertTrue(asyncio.iscoroutine(FakeCoro()))

    def test_iscoroutine_generator(self):
        # A plain generator object is not a coroutine.
        def foo(): yield

        self.assertFalse(asyncio.iscoroutine(foo()))

    def test_iscoroutinefunction(self):
        # The call is expected to emit a DeprecationWarning while still
        # returning a true value for an ``async def`` function.
        async def foo(): pass
        with self.assertWarns(DeprecationWarning):
            self.assertTrue(asyncio.iscoroutinefunction(foo))

    def test_async_def_coroutines(self):
        async def bar():
            return 'spam'
        async def foo():
            return await bar()

        # production mode
        data = self.loop.run_until_complete(foo())
        self.assertEqual(data, 'spam')

        # debug mode
        self.loop.set_debug(True)
        data = self.loop.run_until_complete(foo())
        self.assertEqual(data, 'spam')

    def test_debug_mode_manages_coroutine_origin_tracking(self):
        # Origin tracking must be enabled only while the debug-mode loop
        # is running, and be back at depth 0 before and after.
        async def start():
            # assertGreater reports the actual depth on failure.
            self.assertGreater(sys.get_coroutine_origin_tracking_depth(), 0)

        self.assertEqual(sys.get_coroutine_origin_tracking_depth(), 0)
        self.loop.set_debug(True)
        self.loop.run_until_complete(start())
        self.assertEqual(sys.get_coroutine_origin_tracking_depth(), 0)

    def test_types_coroutine(self):
        # A @types.coroutine function returning a generator object must
        # produce an awaitable wrapper whose result is the generator's
        # return value.
        def gen():
            yield from ()
            return 'spam'

        @types.coroutine
        def func():
            return gen()

        async def coro():
            wrapper = func()
            self.assertIsInstance(wrapper, types._GeneratorWrapper)
            return await wrapper

        data = self.loop.run_until_complete(coro())
        self.assertEqual(data, 'spam')

    def test_task_print_stack(self):
        # T is bound by runner() before foo() starts executing, so foo()
        # can inspect the stack of the task that is running it.
        T = None

        async def foo():
            f = T.get_stack(limit=1)
            try:
                self.assertEqual(f[0].f_code.co_name, 'foo')
            finally:
                # Drop the reference to the frame list.
                f = None

        async def runner():
            nonlocal T
            T = asyncio.ensure_future(foo(), loop=self.loop)
            await T

        self.loop.run_until_complete(runner())

    def test_double_await(self):
        # Awaiting a coroutine object that a task is already driving must
        # fail with a RuntimeError carrying the expected message.
        async def afunc():
            await asyncio.sleep(0.1)

        async def runner():
            coro = afunc()
            t = self.loop.create_task(coro)
            try:
                # Yield once so the task starts running ``coro``.
                await asyncio.sleep(0)
                await coro
            finally:
                t.cancel()

        self.loop.set_debug(True)
        # assertRaises(..., msg=...) only sets the failure message and never
        # checks the exception text; assertRaisesRegex actually verifies it.
        with self.assertRaisesRegex(
                RuntimeError,
                'coroutine is being awaited already'):

            self.loop.run_until_complete(runner())


# Run the tests in this module when the file is executed as a script.
if __name__ == '__main__':
    unittest.main()