From 9b8f36d08a5bdffa83019f679a9c9d2ef5ca4302 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 15 Jul 2018 11:07:47 +0200 Subject: [PATCH 3/3] Support yield from connect/serve on Python 3.7. Fix #435. (cherry picked from commit 91a376685b1ab7103d3d861ff8b02a1c00f142b1) --- websockets/client.py | 1 + websockets/py35/_test_client_server.py | 3 ++ websockets/server.py | 1 + websockets/test_client_server.py | 41 ++++++++++++++++++++++++++ 4 files changed, 46 insertions(+) diff --git a/websockets/client.py b/websockets/client.py index a86b90f..bb3009b 100644 --- a/websockets/client.py +++ b/websockets/client.py @@ -385,6 +385,7 @@ class Connect: self._creating_connection = loop.create_connection( factory, host, port, **kwds) + @asyncio.coroutine def __iter__(self): # pragma: no cover transport, protocol = yield from self._creating_connection diff --git a/websockets/py35/_test_client_server.py b/websockets/py35/_test_client_server.py index 5360d8d..c656dd3 100644 --- a/websockets/py35/_test_client_server.py +++ b/websockets/py35/_test_client_server.py @@ -39,6 +39,7 @@ class AsyncAwaitTests(unittest.TestCase): self.loop.run_until_complete(server.wait_closed()) def test_server(self): + async def run_server(): # Await serve. server = await serve(handler, 'localhost', 0) @@ -83,6 +84,7 @@ class ContextManagerTests(unittest.TestCase): @unittest.skipIf( sys.version_info[:3] <= (3, 5, 0), 'this test requires Python 3.5.1+') def test_server(self): + async def run_server(): # Use serve as an asynchronous context manager. async with serve(handler, 'localhost', 0) as server: @@ -99,6 +101,7 @@ class ContextManagerTests(unittest.TestCase): @unittest.skipUnless( hasattr(socket, 'AF_UNIX'), 'this test requires Unix sockets') def test_unix_server(self): + async def run_server(path): async with unix_serve(handler, path) as server: self.assertTrue(server.sockets) diff --git a/websockets/server.py b/websockets/server.py index 46c80dc..86fa700 100644 --- a/websockets/server.py +++ b/websockets/server.py @@ -729,6 +729,7 @@ class Serve: self._creating_server = creating_server self.ws_server = ws_server + @asyncio.coroutine def __iter__(self): # pragma: no cover server = yield from self._creating_server self.ws_server.wrap(server) diff --git a/websockets/test_client_server.py b/websockets/test_client_server.py index a3e1e92..6c25784 100644 --- a/websockets/test_client_server.py +++ b/websockets/test_client_server.py @@ -24,6 +24,7 @@ from .extensions.permessage_deflate import ( ) from .handshake import build_response from .http import USER_AGENT, read_response +from .protocol import State from .server import * from .test_protocol import MS @@ -1056,6 +1057,46 @@ class ClientServerOriginTests(unittest.TestCase): self.loop.run_until_complete(server.wait_closed()) +class YieldFromTests(unittest.TestCase): + + def setUp(self): + self.loop = asyncio.new_event_loop() + asyncio.set_event_loop(self.loop) + + def tearDown(self): + self.loop.close() + + def test_client(self): + start_server = serve(handler, 'localhost', 0) + server = self.loop.run_until_complete(start_server) + + @asyncio.coroutine + def run_client(): + # Yield from connect. + client = yield from connect(get_server_uri(server)) + self.assertEqual(client.state, State.OPEN) + yield from client.close() + self.assertEqual(client.state, State.CLOSED) + + self.loop.run_until_complete(run_client()) + + server.close() + self.loop.run_until_complete(server.wait_closed()) + + def test_server(self): + + @asyncio.coroutine + def run_server(): + # Yield from serve. + server = yield from serve(handler, 'localhost', 0) + self.assertTrue(server.sockets) + server.close() + yield from server.wait_closed() + self.assertFalse(server.sockets) + + self.loop.run_until_complete(run_server()) + + if sys.version_info[:2] >= (3, 5): # pragma: no cover from .py35._test_client_server import AsyncAwaitTests # noqa from .py35._test_client_server import ContextManagerTests # noqa -- 2.18.0