From 9b8f36d08a5bdffa83019f679a9c9d2ef5ca4302 Mon Sep 17 00:00:00 2001
From: Aymeric Augustin <aymeric.augustin@m4x.org>
Date: Sun, 15 Jul 2018 11:07:47 +0200
Subject: [PATCH 3/3] Support yield from connect/serve on Python 3.7.
Fix #435.
(cherry picked from commit 91a376685b1ab7103d3d861ff8b02a1c00f142b1)
---
websockets/client.py | 1 +
websockets/py35/_test_client_server.py | 3 ++
websockets/server.py | 1 +
websockets/test_client_server.py | 41 ++++++++++++++++++++++++++
4 files changed, 46 insertions(+)
diff --git a/websockets/client.py b/websockets/client.py
index a86b90f..bb3009b 100644
--- a/websockets/client.py
+++ b/websockets/client.py
@@ -385,6 +385,7 @@ class Connect:
self._creating_connection = loop.create_connection(
factory, host, port, **kwds)
+ @asyncio.coroutine
def __iter__(self): # pragma: no cover
transport, protocol = yield from self._creating_connection
diff --git a/websockets/py35/_test_client_server.py b/websockets/py35/_test_client_server.py
index 5360d8d..c656dd3 100644
--- a/websockets/py35/_test_client_server.py
+++ b/websockets/py35/_test_client_server.py
@@ -39,6 +39,7 @@ class AsyncAwaitTests(unittest.TestCase):
self.loop.run_until_complete(server.wait_closed())
def test_server(self):
+
async def run_server():
# Await serve.
server = await serve(handler, 'localhost', 0)
@@ -83,6 +84,7 @@ class ContextManagerTests(unittest.TestCase):
@unittest.skipIf(
sys.version_info[:3] <= (3, 5, 0), 'this test requires Python 3.5.1+')
def test_server(self):
+
async def run_server():
# Use serve as an asynchronous context manager.
async with serve(handler, 'localhost', 0) as server:
@@ -99,6 +101,7 @@ class ContextManagerTests(unittest.TestCase):
@unittest.skipUnless(
hasattr(socket, 'AF_UNIX'), 'this test requires Unix sockets')
def test_unix_server(self):
+
async def run_server(path):
async with unix_serve(handler, path) as server:
self.assertTrue(server.sockets)
diff --git a/websockets/server.py b/websockets/server.py
index 46c80dc..86fa700 100644
--- a/websockets/server.py
+++ b/websockets/server.py
@@ -729,6 +729,7 @@ class Serve:
self._creating_server = creating_server
self.ws_server = ws_server
+ @asyncio.coroutine
def __iter__(self): # pragma: no cover
server = yield from self._creating_server
self.ws_server.wrap(server)
diff --git a/websockets/test_client_server.py b/websockets/test_client_server.py
index a3e1e92..6c25784 100644
--- a/websockets/test_client_server.py
+++ b/websockets/test_client_server.py
@@ -24,6 +24,7 @@ from .extensions.permessage_deflate import (
)
from .handshake import build_response
from .http import USER_AGENT, read_response
+from .protocol import State
from .server import *
from .test_protocol import MS
@@ -1056,6 +1057,46 @@ class ClientServerOriginTests(unittest.TestCase):
self.loop.run_until_complete(server.wait_closed())
+class YieldFromTests(unittest.TestCase):
+
+ def setUp(self):
+ self.loop = asyncio.new_event_loop()
+ asyncio.set_event_loop(self.loop)
+
+ def tearDown(self):
+ self.loop.close()
+
+ def test_client(self):
+ start_server = serve(handler, 'localhost', 0)
+ server = self.loop.run_until_complete(start_server)
+
+ @asyncio.coroutine
+ def run_client():
+ # Yield from connect.
+ client = yield from connect(get_server_uri(server))
+ self.assertEqual(client.state, State.OPEN)
+ yield from client.close()
+ self.assertEqual(client.state, State.CLOSED)
+
+ self.loop.run_until_complete(run_client())
+
+ server.close()
+ self.loop.run_until_complete(server.wait_closed())
+
+ def test_server(self):
+
+ @asyncio.coroutine
+ def run_server():
+ # Yield from serve.
+ server = yield from serve(handler, 'localhost', 0)
+ self.assertTrue(server.sockets)
+ server.close()
+ yield from server.wait_closed()
+ self.assertFalse(server.sockets)
+
+ self.loop.run_until_complete(run_server())
+
+
if sys.version_info[:2] >= (3, 5): # pragma: no cover
from .py35._test_client_server import AsyncAwaitTests # noqa
from .py35._test_client_server import ContextManagerTests # noqa
--
2.18.0