diff --git a/python/qemu/aqmp/legacy.py b/python/qemu/aqmp/legacy.py index 6baa5f3409..46026e9fdc 100644 --- a/python/qemu/aqmp/legacy.py +++ b/python/qemu/aqmp/legacy.py @@ -57,7 +57,7 @@ class QEMUMonitorProtocol(qemu.qmp.QEMUMonitorProtocol): self._timeout: Optional[float] = None if server: - self._aqmp._bind_hack(address) # pylint: disable=protected-access + self._sync(self._aqmp.start_server(self._address)) _T = TypeVar('_T') @@ -90,10 +90,7 @@ class QEMUMonitorProtocol(qemu.qmp.QEMUMonitorProtocol): self._aqmp.await_greeting = True self._aqmp.negotiate = True - self._sync( - self._aqmp.accept(self._address), - timeout - ) + self._sync(self._aqmp.accept(), timeout) ret = self._get_greeting() assert ret is not None diff --git a/python/qemu/aqmp/protocol.py b/python/qemu/aqmp/protocol.py index 33358f5cd7..36fae57f27 100644 --- a/python/qemu/aqmp/protocol.py +++ b/python/qemu/aqmp/protocol.py @@ -10,12 +10,14 @@ In this package, it is used as the implementation for the `QMPClient` class. """ +# It's all the docstrings ... ! It's long for a good reason ^_^; +# pylint: disable=too-many-lines + import asyncio from asyncio import StreamReader, StreamWriter from enum import Enum from functools import wraps import logging -import socket from ssl import SSLContext from typing import ( Any, @@ -239,8 +241,9 @@ class AsyncProtocol(Generic[T]): self._runstate = Runstate.IDLE self._runstate_changed: Optional[asyncio.Event] = None - # Workaround for bind() - self._sock: Optional[socket.socket] = None + # Server state for start_server() and _incoming() + self._server: Optional[asyncio.AbstractServer] = None + self._accepted: Optional[asyncio.Event] = None def __repr__(self) -> str: cls_name = type(self).__name__ @@ -265,21 +268,90 @@ class AsyncProtocol(Generic[T]): @upper_half @require(Runstate.IDLE) - async def accept(self, address: SocketAddrT, - ssl: Optional[SSLContext] = None) -> None: + async def start_server_and_accept( + self, address: SocketAddrT, + ssl: Optional[SSLContext] = None + ) -> None: """ Accept a connection and begin processing message queues. If this call fails, `runstate` is guaranteed to be set back to `IDLE`. + This method is precisely equivalent to calling `start_server()` + followed by `accept()`. :param address: - Address to listen to; UNIX socket path or TCP address/port. + Address to listen on; UNIX socket path or TCP address/port. :param ssl: SSL context to use, if any. :raise StateError: When the `Runstate` is not `IDLE`. - :raise ConnectError: If a connection could not be accepted. + :raise ConnectError: + When a connection or session cannot be established. + + This exception will wrap a more concrete one. In most cases, + the wrapped exception will be `OSError` or `EOFError`. If a + protocol-level failure occurs while establishing a new + session, the wrapped error may also be an `QMPError`. """ - await self._new_session(address, ssl, accept=True) + await self.start_server(address, ssl) + await self.accept() + assert self.runstate == Runstate.RUNNING + + @upper_half + @require(Runstate.IDLE) + async def start_server(self, address: SocketAddrT, + ssl: Optional[SSLContext] = None) -> None: + """ + Start listening for an incoming connection, but do not wait for a peer. + + This method starts listening for an incoming connection, but + does not block waiting for a peer. This call will return + immediately after binding and listening on a socket. A later + call to `accept()` must be made in order to finalize the + incoming connection. + + :param address: + Address to listen on; UNIX socket path or TCP address/port. + :param ssl: SSL context to use, if any. + + :raise StateError: When the `Runstate` is not `IDLE`. + :raise ConnectError: + When the server could not start listening on this address. + + This exception will wrap a more concrete one. In most cases, + the wrapped exception will be `OSError`. + """ + await self._session_guard( + self._do_start_server(address, ssl), + 'Failed to establish connection') + assert self.runstate == Runstate.CONNECTING + + @upper_half + @require(Runstate.CONNECTING) + async def accept(self) -> None: + """ + Accept an incoming connection and begin processing message queues. + + If this call fails, `runstate` is guaranteed to be set back to `IDLE`. + + :raise StateError: When the `Runstate` is not `CONNECTING`. + :raise QMPError: When `start_server()` was not called yet. + :raise ConnectError: + When a connection or session cannot be established. + + This exception will wrap a more concrete one. In most cases, + the wrapped exception will be `OSError` or `EOFError`. If a + protocol-level failure occurs while establishing a new + session, the wrapped error may also be an `QMPError`. + """ + if self._accepted is None: + raise QMPError("Cannot call accept() before start_server().") + await self._session_guard( + self._do_accept(), + 'Failed to establish connection') + await self._session_guard( + self._establish_session(), + 'Failed to establish session') + assert self.runstate == Runstate.RUNNING @upper_half @require(Runstate.IDLE) @@ -295,9 +367,21 @@ class AsyncProtocol(Generic[T]): :param ssl: SSL context to use, if any. :raise StateError: When the `Runstate` is not `IDLE`. - :raise ConnectError: If a connection cannot be made to the server. + :raise ConnectError: + When a connection or session cannot be established. + + This exception will wrap a more concrete one. In most cases, + the wrapped exception will be `OSError` or `EOFError`. If a + protocol-level failure occurs while establishing a new + session, the wrapped error may also be an `QMPError`. """ - await self._new_session(address, ssl) + await self._session_guard( + self._do_connect(address, ssl), + 'Failed to establish connection') + await self._session_guard( + self._establish_session(), + 'Failed to establish session') + assert self.runstate == Runstate.RUNNING @upper_half async def disconnect(self) -> None: @@ -317,6 +401,62 @@ class AsyncProtocol(Generic[T]): # Section: Session machinery # -------------------------- + async def _session_guard(self, coro: Awaitable[None], emsg: str) -> None: + """ + Async guard function used to roll back to `IDLE` on any error. + + On any Exception, the state machine will be reset back to + `IDLE`. Most Exceptions will be wrapped with `ConnectError`, but + `BaseException` events will be left alone (This includes + asyncio.CancelledError, even prior to Python 3.8). + + :param error_message: + Human-readable string describing what connection phase failed. + + :raise BaseException: + When `BaseException` occurs in the guarded block. + :raise ConnectError: + When any other error is encountered in the guarded block. + """ + # Note: After Python 3.6 support is removed, this should be an + # @asynccontextmanager instead of accepting a callback. + try: + await coro + except BaseException as err: + self.logger.error("%s: %s", emsg, exception_summary(err)) + self.logger.debug("%s:\n%s\n", emsg, pretty_traceback()) + try: + # Reset the runstate back to IDLE. + await self.disconnect() + except: + # We don't expect any Exceptions from the disconnect function + # here, because we failed to connect in the first place. + # The disconnect() function is intended to perform + # only cannot-fail cleanup here, but you never know. + emsg = ( + "Unexpected bottom half exception. " + "This is a bug in the QMP library. " + "Please report it to and " + "CC: John Snow ." + ) + self.logger.critical("%s:\n%s\n", emsg, pretty_traceback()) + raise + + # CancelledError is an Exception with special semantic meaning; + # We do NOT want to wrap it up under ConnectError. + # NB: CancelledError is not a BaseException before Python 3.8 + if isinstance(err, asyncio.CancelledError): + raise + + # Any other kind of error can be treated as some kind of connection + # failure broadly. Inspect the 'exc' field to explore the root + # cause in greater detail. + if isinstance(err, Exception): + raise ConnectError(emsg, err) from err + + # Raise BaseExceptions un-wrapped, they're more important. + raise + @property def _runstate_event(self) -> asyncio.Event: # asyncio.Event() objects should not be created prior to entrance into @@ -343,127 +483,64 @@ class AsyncProtocol(Generic[T]): self._runstate_event.set() self._runstate_event.clear() - @upper_half - async def _new_session(self, - address: SocketAddrT, - ssl: Optional[SSLContext] = None, - accept: bool = False) -> None: + @bottom_half + async def _stop_server(self) -> None: """ - Establish a new connection and initialize the session. - - Connect or accept a new connection, then begin the protocol - session machinery. If this call fails, `runstate` is guaranteed - to be set back to `IDLE`. - - :param address: - Address to connect to/listen on; - UNIX socket path or TCP address/port. - :param ssl: SSL context to use, if any. - :param accept: Accept a connection instead of connecting when `True`. - - :raise ConnectError: - When a connection or session cannot be established. - - This exception will wrap a more concrete one. In most cases, - the wrapped exception will be `OSError` or `EOFError`. If a - protocol-level failure occurs while establishing a new - session, the wrapped error may also be an `QMPError`. + Stop listening for / accepting new incoming connections. """ - assert self.runstate == Runstate.IDLE + if self._server is None: + return try: - phase = "connection" - await self._establish_connection(address, ssl, accept) + self.logger.debug("Stopping server.") + self._server.close() + await self._server.wait_closed() + self.logger.debug("Server stopped.") + finally: + self._server = None - phase = "session" - await self._establish_session() + @bottom_half # However, it does not run from the R/W tasks. + async def _incoming(self, + reader: asyncio.StreamReader, + writer: asyncio.StreamWriter) -> None: + """ + Accept an incoming connection and signal the upper_half. - except BaseException as err: - emsg = f"Failed to establish {phase}" - self.logger.error("%s: %s", emsg, exception_summary(err)) - self.logger.debug("%s:\n%s\n", emsg, pretty_traceback()) - try: - # Reset from CONNECTING back to IDLE. - await self.disconnect() - except: - emsg = "Unexpected bottom half exception" - self.logger.critical("%s:\n%s\n", emsg, pretty_traceback()) - raise + This method does the minimum necessary to accept a single + incoming connection. It signals back to the upper_half ASAP so + that any errors during session initialization can occur + naturally in the caller's stack. - # NB: CancelledError is not a BaseException before Python 3.8 - if isinstance(err, asyncio.CancelledError): - raise + :param reader: Incoming `asyncio.StreamReader` + :param writer: Incoming `asyncio.StreamWriter` + """ + peer = writer.get_extra_info('peername', 'Unknown peer') + self.logger.debug("Incoming connection from %s", peer) - if isinstance(err, Exception): - raise ConnectError(emsg, err) from err + if self._reader or self._writer: + # Sadly, we can have more than one pending connection + # because of https://bugs.python.org/issue46715 + # Close any extra connections we don't actually want. + self.logger.warning("Extraneous connection inadvertently accepted") + writer.close() + return - # Raise BaseExceptions un-wrapped, they're more important. - raise - - assert self.runstate == Runstate.RUNNING + # A connection has been accepted; stop listening for new ones. + assert self._accepted is not None + await self._stop_server() + self._reader, self._writer = (reader, writer) + self._accepted.set() @upper_half - async def _establish_connection( - self, - address: SocketAddrT, - ssl: Optional[SSLContext] = None, - accept: bool = False - ) -> None: + async def _do_start_server(self, address: SocketAddrT, + ssl: Optional[SSLContext] = None) -> None: """ - Establish a new connection. + Start listening for an incoming connection, but do not wait for a peer. - :param address: - Address to connect to/listen on; - UNIX socket path or TCP address/port. - :param ssl: SSL context to use, if any. - :param accept: Accept a connection instead of connecting when `True`. - """ - assert self.runstate == Runstate.IDLE - self._set_state(Runstate.CONNECTING) - - # Allow runstate watchers to witness 'CONNECTING' state; some - # failures in the streaming layer are synchronous and will not - # otherwise yield. - await asyncio.sleep(0) - - if accept: - await self._do_accept(address, ssl) - else: - await self._do_connect(address, ssl) - - def _bind_hack(self, address: Union[str, Tuple[str, int]]) -> None: - """ - Used to create a socket in advance of accept(). - - This is a workaround to ensure that we can guarantee timing of - precisely when a socket exists to avoid a connection attempt - bouncing off of nothing. - - Python 3.7+ adds a feature to separate the server creation and - listening phases instead, and should be used instead of this - hack. - """ - if isinstance(address, tuple): - family = socket.AF_INET - else: - family = socket.AF_UNIX - - sock = socket.socket(family, socket.SOCK_STREAM) - sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - - try: - sock.bind(address) - except: - sock.close() - raise - - self._sock = sock - - @upper_half - async def _do_accept(self, address: SocketAddrT, - ssl: Optional[SSLContext] = None) -> None: - """ - Acting as the transport server, accept a single connection. + This method starts listening for an incoming connection, but does not + block waiting for a peer. This call will return immediately after + binding and listening to a socket. A later call to accept() must be + made in order to finalize the incoming connection. :param address: Address to listen on; UNIX socket path or TCP address/port. @@ -471,52 +548,54 @@ class AsyncProtocol(Generic[T]): :raise OSError: For stream-related errors. """ + assert self.runstate == Runstate.IDLE + self._set_state(Runstate.CONNECTING) + self.logger.debug("Awaiting connection on %s ...", address) - connected = asyncio.Event() - server: Optional[asyncio.AbstractServer] = None - - async def _client_connected_cb(reader: asyncio.StreamReader, - writer: asyncio.StreamWriter) -> None: - """Used to accept a single incoming connection, see below.""" - nonlocal server - nonlocal connected - - # A connection has been accepted; stop listening for new ones. - assert server is not None - server.close() - await server.wait_closed() - server = None - - # Register this client as being connected - self._reader, self._writer = (reader, writer) - - # Signal back: We've accepted a client! - connected.set() + self._accepted = asyncio.Event() if isinstance(address, tuple): coro = asyncio.start_server( - _client_connected_cb, - host=None if self._sock else address[0], - port=None if self._sock else address[1], + self._incoming, + host=address[0], + port=address[1], ssl=ssl, backlog=1, limit=self._limit, - sock=self._sock, ) else: coro = asyncio.start_unix_server( - _client_connected_cb, - path=None if self._sock else address, + self._incoming, + path=address, ssl=ssl, backlog=1, limit=self._limit, - sock=self._sock, ) - server = await coro # Starts listening - await connected.wait() # Waits for the callback to fire (and finish) - assert server is None - self._sock = None + # Allow runstate watchers to witness 'CONNECTING' state; some + # failures in the streaming layer are synchronous and will not + # otherwise yield. + await asyncio.sleep(0) + + # This will start the server (bind(2), listen(2)). It will also + # call accept(2) if we yield, but we don't block on that here. + self._server = await coro + self.logger.debug("Server listening on %s", address) + + @upper_half + async def _do_accept(self) -> None: + """ + Wait for and accept an incoming connection. + + Requires that we have not yet accepted an incoming connection + from the upper_half, but it's OK if the server is no longer + running because the bottom_half has already accepted the + connection. + """ + assert self._accepted is not None + await self._accepted.wait() + assert self._server is None + self._accepted = None self.logger.debug("Connection accepted.") @@ -532,6 +611,14 @@ class AsyncProtocol(Generic[T]): :raise OSError: For stream-related errors. """ + assert self.runstate == Runstate.IDLE + self._set_state(Runstate.CONNECTING) + + # Allow runstate watchers to witness 'CONNECTING' state; some + # failures in the streaming layer are synchronous and will not + # otherwise yield. + await asyncio.sleep(0) + self.logger.debug("Connecting to %s ...", address) if isinstance(address, tuple): @@ -644,6 +731,7 @@ class AsyncProtocol(Generic[T]): self._reader = None self._writer = None + self._accepted = None # NB: _runstate_changed cannot be cleared because we still need it to # send the final runstate changed event ...! @@ -667,6 +755,9 @@ class AsyncProtocol(Generic[T]): def _done(task: Optional['asyncio.Future[Any]']) -> bool: return task is not None and task.done() + # If the server is running, stop it. + await self._stop_server() + # Are we already in an error pathway? If either of the tasks are # already done, or if we have no tasks but a reader/writer; we # must be. diff --git a/python/tests/protocol.py b/python/tests/protocol.py index 5cd7938be3..d6849ad306 100644 --- a/python/tests/protocol.py +++ b/python/tests/protocol.py @@ -41,12 +41,25 @@ class NullProtocol(AsyncProtocol[None]): self.trigger_input = asyncio.Event() await super()._establish_session() - async def _do_accept(self, address, ssl=None): - if not self.fake_session: - await super()._do_accept(address, ssl) + async def _do_start_server(self, address, ssl=None): + if self.fake_session: + self._accepted = asyncio.Event() + self._set_state(Runstate.CONNECTING) + await asyncio.sleep(0) + else: + await super()._do_start_server(address, ssl) + + async def _do_accept(self): + if self.fake_session: + self._accepted = None + else: + await super()._do_accept() async def _do_connect(self, address, ssl=None): - if not self.fake_session: + if self.fake_session: + self._set_state(Runstate.CONNECTING) + await asyncio.sleep(0) + else: await super()._do_connect(address, ssl) async def _do_recv(self) -> None: @@ -413,14 +426,14 @@ class Accept(Connect): assert family in ('INET', 'UNIX') if family == 'INET': - await self.proto.accept(('example.com', 1)) + await self.proto.start_server_and_accept(('example.com', 1)) elif family == 'UNIX': - await self.proto.accept('/dev/null') + await self.proto.start_server_and_accept('/dev/null') async def _hanging_connection(self): with TemporaryDirectory(suffix='.aqmp') as tmpdir: sock = os.path.join(tmpdir, type(self.proto).__name__ + ".sock") - await self.proto.accept(sock) + await self.proto.start_server_and_accept(sock) class FakeSession(TestBase): @@ -449,13 +462,13 @@ class FakeSession(TestBase): @TestBase.async_test async def testFakeAccept(self): """Test the full state lifecycle (via accept) with a no-op session.""" - await self.proto.accept('/not/a/real/path') + await self.proto.start_server_and_accept('/not/a/real/path') self.assertEqual(self.proto.runstate, Runstate.RUNNING) @TestBase.async_test async def testFakeRecv(self): """Test receiving a fake/null message.""" - await self.proto.accept('/not/a/real/path') + await self.proto.start_server_and_accept('/not/a/real/path') logname = self.proto.logger.name with self.assertLogs(logname, level='DEBUG') as context: @@ -471,7 +484,7 @@ class FakeSession(TestBase): @TestBase.async_test async def testFakeSend(self): """Test sending a fake/null message.""" - await self.proto.accept('/not/a/real/path') + await self.proto.start_server_and_accept('/not/a/real/path') logname = self.proto.logger.name with self.assertLogs(logname, level='DEBUG') as context: @@ -493,7 +506,7 @@ class FakeSession(TestBase): ): with self.assertRaises(StateError) as context: if accept: - await self.proto.accept('/not/a/real/path') + await self.proto.start_server_and_accept('/not/a/real/path') else: await self.proto.connect('/not/a/real/path') @@ -504,7 +517,7 @@ class FakeSession(TestBase): @TestBase.async_test async def testAcceptRequireRunning(self): """Test that accept() cannot be called when Runstate=RUNNING""" - await self.proto.accept('/not/a/real/path') + await self.proto.start_server_and_accept('/not/a/real/path') await self._prod_session_api( Runstate.RUNNING, @@ -515,7 +528,7 @@ class FakeSession(TestBase): @TestBase.async_test async def testConnectRequireRunning(self): """Test that connect() cannot be called when Runstate=RUNNING""" - await self.proto.accept('/not/a/real/path') + await self.proto.start_server_and_accept('/not/a/real/path') await self._prod_session_api( Runstate.RUNNING, @@ -526,7 +539,7 @@ class FakeSession(TestBase): @TestBase.async_test async def testAcceptRequireDisconnecting(self): """Test that accept() cannot be called when Runstate=DISCONNECTING""" - await self.proto.accept('/not/a/real/path') + await self.proto.start_server_and_accept('/not/a/real/path') # Cheat: force a disconnect. await self.proto.simulate_disconnect() @@ -541,7 +554,7 @@ class FakeSession(TestBase): @TestBase.async_test async def testConnectRequireDisconnecting(self): """Test that connect() cannot be called when Runstate=DISCONNECTING""" - await self.proto.accept('/not/a/real/path') + await self.proto.start_server_and_accept('/not/a/real/path') # Cheat: force a disconnect. await self.proto.simulate_disconnect() @@ -576,7 +589,7 @@ class SimpleSession(TestBase): async def testSmoke(self): with TemporaryDirectory(suffix='.aqmp') as tmpdir: sock = os.path.join(tmpdir, type(self.proto).__name__ + ".sock") - server_task = create_task(self.server.accept(sock)) + server_task = create_task(self.server.start_server_and_accept(sock)) # give the server a chance to start listening [...] await asyncio.sleep(0) diff --git a/scripts/qmp/qmp-shell-wrap b/scripts/qmp/qmp-shell-wrap index 9e94da114f..66846e36d1 100755 --- a/scripts/qmp/qmp-shell-wrap +++ b/scripts/qmp/qmp-shell-wrap @@ -4,7 +4,7 @@ import os import sys sys.path.append(os.path.join(os.path.dirname(__file__), '..', '..', 'python')) -from qemu.qmp import qmp_shell +from qemu.aqmp import qmp_shell if __name__ == '__main__':