Python patches
Hopefully, fixes the race conditions witnessed through the NetBSD vm tests. -----BEGIN PGP SIGNATURE----- iQIzBAABCAAdFiEE+ber27ys35W+dsvQfe+BBqr8OQ4FAmImg9IACgkQfe+BBqr8 OQ6pMxAAgilUH8OIJzJfV2C/1qWM2Hzrl/jwTUEuYxmMYacdL9kJvR3NJ4CMv5Nn 996TyJROK+QDQoVsUuoEjkdrezbI4UDoixM9ku7KWAUMEsxXmRR5kcclSkCWX4HX o+My1UR+6LxPgH894JMTcnKzH9gDHkU0Aww/nu5LumJoVB12Gu1iLif/2JneQKFB rWaQu+8DHGH7Jv9s0ShrmkDYwtwq5XXGtefR6DEdo5xGGCjzYrYr80Frg7R1OYVU xlGV0MbLjTmePM5F4ZxiQGohFSOY6QsraxDMiqVOc+gBjz2J8l+7i8AA3Zirwotz V9BYPDRZ9pZV3ERDPqh0L3homsmk2wepkXi6YAz9/DMn0pDHizmvntPCCdhzBXyH cA63+QayvCYADDoHkUbMT5jc7X6ayfauj7ZkJPzfr7YtzYKs6k0bDmtgJBMyNRj1 pHILnv5oGnnVz4kO5W98oV2jijAdqi9or3+4B2woeUmaROoQJA0ObU35ke961KNE n66kTOibgMj/TQmDE1veBgNvCxY0cRE+ZB7SYL7ZaqvavEwfeYQRz851sDxTdiFF v5b/Ls8IDKPbU8qPLDzTQrAy19CWtOkJTD4b4/6WAv9K0SAxghQEyoCUCZbk+PLt xGeCyxImTC7XaqFlops9WzBTK3jz/7m9EvgfJNRKj8QZ49yxCBo= =0ieN -----END PGP SIGNATURE----- Merge remote-tracking branch 'remotes/jsnow-gitlab/tags/python-pull-request' into staging Python patches Hopefully, fixes the race conditions witnessed through the NetBSD vm tests. # gpg: Signature made Mon 07 Mar 2022 22:14:42 GMT # gpg: using RSA key F9B7ABDBBCACDF95BE76CBD07DEF8106AAFC390E # gpg: Good signature from "John Snow (John Huston) <jsnow@redhat.com>" [full] # Primary key fingerprint: FAEB 9711 A12C F475 812F 18F2 88A9 064D 1835 61EB # Subkey fingerprint: F9B7 ABDB BCAC DF95 BE76 CBD0 7DEF 8106 AAFC 390E * remotes/jsnow-gitlab/tags/python-pull-request: scripts/qmp-shell-wrap: Fix import path python/aqmp: drop _bind_hack() python/aqmp: fix race condition in legacy.py python/aqmp: add start_server() and accept() methods python/aqmp: stop the server during disconnect() python/aqmp: refactor _do_accept() into two distinct steps python/aqmp: squelch pylint warning for too many lines python/aqmp: split _client_connected_cb() out as _incoming() python/aqmp: remove _new_session and _establish_connection python/aqmp: rename 'accept()' to 'start_server_and_accept()' python/aqmp: add _session_guard() Signed-off-by: Peter Maydell <peter.maydell@linaro.org>
This commit is contained in:
commit
2ad7624900
@ -57,7 +57,7 @@ class QEMUMonitorProtocol(qemu.qmp.QEMUMonitorProtocol):
|
||||
self._timeout: Optional[float] = None
|
||||
|
||||
if server:
|
||||
self._aqmp._bind_hack(address) # pylint: disable=protected-access
|
||||
self._sync(self._aqmp.start_server(self._address))
|
||||
|
||||
_T = TypeVar('_T')
|
||||
|
||||
@ -90,10 +90,7 @@ class QEMUMonitorProtocol(qemu.qmp.QEMUMonitorProtocol):
|
||||
self._aqmp.await_greeting = True
|
||||
self._aqmp.negotiate = True
|
||||
|
||||
self._sync(
|
||||
self._aqmp.accept(self._address),
|
||||
timeout
|
||||
)
|
||||
self._sync(self._aqmp.accept(), timeout)
|
||||
|
||||
ret = self._get_greeting()
|
||||
assert ret is not None
|
||||
|
@ -10,12 +10,14 @@ In this package, it is used as the implementation for the `QMPClient`
|
||||
class.
|
||||
"""
|
||||
|
||||
# It's all the docstrings ... ! It's long for a good reason ^_^;
|
||||
# pylint: disable=too-many-lines
|
||||
|
||||
import asyncio
|
||||
from asyncio import StreamReader, StreamWriter
|
||||
from enum import Enum
|
||||
from functools import wraps
|
||||
import logging
|
||||
import socket
|
||||
from ssl import SSLContext
|
||||
from typing import (
|
||||
Any,
|
||||
@ -239,8 +241,9 @@ class AsyncProtocol(Generic[T]):
|
||||
self._runstate = Runstate.IDLE
|
||||
self._runstate_changed: Optional[asyncio.Event] = None
|
||||
|
||||
# Workaround for bind()
|
||||
self._sock: Optional[socket.socket] = None
|
||||
# Server state for start_server() and _incoming()
|
||||
self._server: Optional[asyncio.AbstractServer] = None
|
||||
self._accepted: Optional[asyncio.Event] = None
|
||||
|
||||
def __repr__(self) -> str:
|
||||
cls_name = type(self).__name__
|
||||
@ -265,21 +268,90 @@ class AsyncProtocol(Generic[T]):
|
||||
|
||||
@upper_half
|
||||
@require(Runstate.IDLE)
|
||||
async def accept(self, address: SocketAddrT,
|
||||
ssl: Optional[SSLContext] = None) -> None:
|
||||
async def start_server_and_accept(
|
||||
self, address: SocketAddrT,
|
||||
ssl: Optional[SSLContext] = None
|
||||
) -> None:
|
||||
"""
|
||||
Accept a connection and begin processing message queues.
|
||||
|
||||
If this call fails, `runstate` is guaranteed to be set back to `IDLE`.
|
||||
This method is precisely equivalent to calling `start_server()`
|
||||
followed by `accept()`.
|
||||
|
||||
:param address:
|
||||
Address to listen to; UNIX socket path or TCP address/port.
|
||||
Address to listen on; UNIX socket path or TCP address/port.
|
||||
:param ssl: SSL context to use, if any.
|
||||
|
||||
:raise StateError: When the `Runstate` is not `IDLE`.
|
||||
:raise ConnectError: If a connection could not be accepted.
|
||||
:raise ConnectError:
|
||||
When a connection or session cannot be established.
|
||||
|
||||
This exception will wrap a more concrete one. In most cases,
|
||||
the wrapped exception will be `OSError` or `EOFError`. If a
|
||||
protocol-level failure occurs while establishing a new
|
||||
session, the wrapped error may also be an `QMPError`.
|
||||
"""
|
||||
await self._new_session(address, ssl, accept=True)
|
||||
await self.start_server(address, ssl)
|
||||
await self.accept()
|
||||
assert self.runstate == Runstate.RUNNING
|
||||
|
||||
@upper_half
|
||||
@require(Runstate.IDLE)
|
||||
async def start_server(self, address: SocketAddrT,
|
||||
ssl: Optional[SSLContext] = None) -> None:
|
||||
"""
|
||||
Start listening for an incoming connection, but do not wait for a peer.
|
||||
|
||||
This method starts listening for an incoming connection, but
|
||||
does not block waiting for a peer. This call will return
|
||||
immediately after binding and listening on a socket. A later
|
||||
call to `accept()` must be made in order to finalize the
|
||||
incoming connection.
|
||||
|
||||
:param address:
|
||||
Address to listen on; UNIX socket path or TCP address/port.
|
||||
:param ssl: SSL context to use, if any.
|
||||
|
||||
:raise StateError: When the `Runstate` is not `IDLE`.
|
||||
:raise ConnectError:
|
||||
When the server could not start listening on this address.
|
||||
|
||||
This exception will wrap a more concrete one. In most cases,
|
||||
the wrapped exception will be `OSError`.
|
||||
"""
|
||||
await self._session_guard(
|
||||
self._do_start_server(address, ssl),
|
||||
'Failed to establish connection')
|
||||
assert self.runstate == Runstate.CONNECTING
|
||||
|
||||
@upper_half
|
||||
@require(Runstate.CONNECTING)
|
||||
async def accept(self) -> None:
|
||||
"""
|
||||
Accept an incoming connection and begin processing message queues.
|
||||
|
||||
If this call fails, `runstate` is guaranteed to be set back to `IDLE`.
|
||||
|
||||
:raise StateError: When the `Runstate` is not `CONNECTING`.
|
||||
:raise QMPError: When `start_server()` was not called yet.
|
||||
:raise ConnectError:
|
||||
When a connection or session cannot be established.
|
||||
|
||||
This exception will wrap a more concrete one. In most cases,
|
||||
the wrapped exception will be `OSError` or `EOFError`. If a
|
||||
protocol-level failure occurs while establishing a new
|
||||
session, the wrapped error may also be an `QMPError`.
|
||||
"""
|
||||
if self._accepted is None:
|
||||
raise QMPError("Cannot call accept() before start_server().")
|
||||
await self._session_guard(
|
||||
self._do_accept(),
|
||||
'Failed to establish connection')
|
||||
await self._session_guard(
|
||||
self._establish_session(),
|
||||
'Failed to establish session')
|
||||
assert self.runstate == Runstate.RUNNING
|
||||
|
||||
@upper_half
|
||||
@require(Runstate.IDLE)
|
||||
@ -295,9 +367,21 @@ class AsyncProtocol(Generic[T]):
|
||||
:param ssl: SSL context to use, if any.
|
||||
|
||||
:raise StateError: When the `Runstate` is not `IDLE`.
|
||||
:raise ConnectError: If a connection cannot be made to the server.
|
||||
:raise ConnectError:
|
||||
When a connection or session cannot be established.
|
||||
|
||||
This exception will wrap a more concrete one. In most cases,
|
||||
the wrapped exception will be `OSError` or `EOFError`. If a
|
||||
protocol-level failure occurs while establishing a new
|
||||
session, the wrapped error may also be an `QMPError`.
|
||||
"""
|
||||
await self._new_session(address, ssl)
|
||||
await self._session_guard(
|
||||
self._do_connect(address, ssl),
|
||||
'Failed to establish connection')
|
||||
await self._session_guard(
|
||||
self._establish_session(),
|
||||
'Failed to establish session')
|
||||
assert self.runstate == Runstate.RUNNING
|
||||
|
||||
@upper_half
|
||||
async def disconnect(self) -> None:
|
||||
@ -317,6 +401,62 @@ class AsyncProtocol(Generic[T]):
|
||||
# Section: Session machinery
|
||||
# --------------------------
|
||||
|
||||
async def _session_guard(self, coro: Awaitable[None], emsg: str) -> None:
|
||||
"""
|
||||
Async guard function used to roll back to `IDLE` on any error.
|
||||
|
||||
On any Exception, the state machine will be reset back to
|
||||
`IDLE`. Most Exceptions will be wrapped with `ConnectError`, but
|
||||
`BaseException` events will be left alone (This includes
|
||||
asyncio.CancelledError, even prior to Python 3.8).
|
||||
|
||||
:param error_message:
|
||||
Human-readable string describing what connection phase failed.
|
||||
|
||||
:raise BaseException:
|
||||
When `BaseException` occurs in the guarded block.
|
||||
:raise ConnectError:
|
||||
When any other error is encountered in the guarded block.
|
||||
"""
|
||||
# Note: After Python 3.6 support is removed, this should be an
|
||||
# @asynccontextmanager instead of accepting a callback.
|
||||
try:
|
||||
await coro
|
||||
except BaseException as err:
|
||||
self.logger.error("%s: %s", emsg, exception_summary(err))
|
||||
self.logger.debug("%s:\n%s\n", emsg, pretty_traceback())
|
||||
try:
|
||||
# Reset the runstate back to IDLE.
|
||||
await self.disconnect()
|
||||
except:
|
||||
# We don't expect any Exceptions from the disconnect function
|
||||
# here, because we failed to connect in the first place.
|
||||
# The disconnect() function is intended to perform
|
||||
# only cannot-fail cleanup here, but you never know.
|
||||
emsg = (
|
||||
"Unexpected bottom half exception. "
|
||||
"This is a bug in the QMP library. "
|
||||
"Please report it to <qemu-devel@nongnu.org> and "
|
||||
"CC: John Snow <jsnow@redhat.com>."
|
||||
)
|
||||
self.logger.critical("%s:\n%s\n", emsg, pretty_traceback())
|
||||
raise
|
||||
|
||||
# CancelledError is an Exception with special semantic meaning;
|
||||
# We do NOT want to wrap it up under ConnectError.
|
||||
# NB: CancelledError is not a BaseException before Python 3.8
|
||||
if isinstance(err, asyncio.CancelledError):
|
||||
raise
|
||||
|
||||
# Any other kind of error can be treated as some kind of connection
|
||||
# failure broadly. Inspect the 'exc' field to explore the root
|
||||
# cause in greater detail.
|
||||
if isinstance(err, Exception):
|
||||
raise ConnectError(emsg, err) from err
|
||||
|
||||
# Raise BaseExceptions un-wrapped, they're more important.
|
||||
raise
|
||||
|
||||
@property
|
||||
def _runstate_event(self) -> asyncio.Event:
|
||||
# asyncio.Event() objects should not be created prior to entrance into
|
||||
@ -343,127 +483,64 @@ class AsyncProtocol(Generic[T]):
|
||||
self._runstate_event.set()
|
||||
self._runstate_event.clear()
|
||||
|
||||
@upper_half
|
||||
async def _new_session(self,
|
||||
address: SocketAddrT,
|
||||
ssl: Optional[SSLContext] = None,
|
||||
accept: bool = False) -> None:
|
||||
@bottom_half
|
||||
async def _stop_server(self) -> None:
|
||||
"""
|
||||
Establish a new connection and initialize the session.
|
||||
|
||||
Connect or accept a new connection, then begin the protocol
|
||||
session machinery. If this call fails, `runstate` is guaranteed
|
||||
to be set back to `IDLE`.
|
||||
|
||||
:param address:
|
||||
Address to connect to/listen on;
|
||||
UNIX socket path or TCP address/port.
|
||||
:param ssl: SSL context to use, if any.
|
||||
:param accept: Accept a connection instead of connecting when `True`.
|
||||
|
||||
:raise ConnectError:
|
||||
When a connection or session cannot be established.
|
||||
|
||||
This exception will wrap a more concrete one. In most cases,
|
||||
the wrapped exception will be `OSError` or `EOFError`. If a
|
||||
protocol-level failure occurs while establishing a new
|
||||
session, the wrapped error may also be an `QMPError`.
|
||||
Stop listening for / accepting new incoming connections.
|
||||
"""
|
||||
assert self.runstate == Runstate.IDLE
|
||||
if self._server is None:
|
||||
return
|
||||
|
||||
try:
|
||||
phase = "connection"
|
||||
await self._establish_connection(address, ssl, accept)
|
||||
self.logger.debug("Stopping server.")
|
||||
self._server.close()
|
||||
await self._server.wait_closed()
|
||||
self.logger.debug("Server stopped.")
|
||||
finally:
|
||||
self._server = None
|
||||
|
||||
phase = "session"
|
||||
await self._establish_session()
|
||||
@bottom_half # However, it does not run from the R/W tasks.
|
||||
async def _incoming(self,
|
||||
reader: asyncio.StreamReader,
|
||||
writer: asyncio.StreamWriter) -> None:
|
||||
"""
|
||||
Accept an incoming connection and signal the upper_half.
|
||||
|
||||
except BaseException as err:
|
||||
emsg = f"Failed to establish {phase}"
|
||||
self.logger.error("%s: %s", emsg, exception_summary(err))
|
||||
self.logger.debug("%s:\n%s\n", emsg, pretty_traceback())
|
||||
try:
|
||||
# Reset from CONNECTING back to IDLE.
|
||||
await self.disconnect()
|
||||
except:
|
||||
emsg = "Unexpected bottom half exception"
|
||||
self.logger.critical("%s:\n%s\n", emsg, pretty_traceback())
|
||||
raise
|
||||
This method does the minimum necessary to accept a single
|
||||
incoming connection. It signals back to the upper_half ASAP so
|
||||
that any errors during session initialization can occur
|
||||
naturally in the caller's stack.
|
||||
|
||||
# NB: CancelledError is not a BaseException before Python 3.8
|
||||
if isinstance(err, asyncio.CancelledError):
|
||||
raise
|
||||
:param reader: Incoming `asyncio.StreamReader`
|
||||
:param writer: Incoming `asyncio.StreamWriter`
|
||||
"""
|
||||
peer = writer.get_extra_info('peername', 'Unknown peer')
|
||||
self.logger.debug("Incoming connection from %s", peer)
|
||||
|
||||
if isinstance(err, Exception):
|
||||
raise ConnectError(emsg, err) from err
|
||||
if self._reader or self._writer:
|
||||
# Sadly, we can have more than one pending connection
|
||||
# because of https://bugs.python.org/issue46715
|
||||
# Close any extra connections we don't actually want.
|
||||
self.logger.warning("Extraneous connection inadvertently accepted")
|
||||
writer.close()
|
||||
return
|
||||
|
||||
# Raise BaseExceptions un-wrapped, they're more important.
|
||||
raise
|
||||
|
||||
assert self.runstate == Runstate.RUNNING
|
||||
# A connection has been accepted; stop listening for new ones.
|
||||
assert self._accepted is not None
|
||||
await self._stop_server()
|
||||
self._reader, self._writer = (reader, writer)
|
||||
self._accepted.set()
|
||||
|
||||
@upper_half
|
||||
async def _establish_connection(
|
||||
self,
|
||||
address: SocketAddrT,
|
||||
ssl: Optional[SSLContext] = None,
|
||||
accept: bool = False
|
||||
) -> None:
|
||||
async def _do_start_server(self, address: SocketAddrT,
|
||||
ssl: Optional[SSLContext] = None) -> None:
|
||||
"""
|
||||
Establish a new connection.
|
||||
Start listening for an incoming connection, but do not wait for a peer.
|
||||
|
||||
:param address:
|
||||
Address to connect to/listen on;
|
||||
UNIX socket path or TCP address/port.
|
||||
:param ssl: SSL context to use, if any.
|
||||
:param accept: Accept a connection instead of connecting when `True`.
|
||||
"""
|
||||
assert self.runstate == Runstate.IDLE
|
||||
self._set_state(Runstate.CONNECTING)
|
||||
|
||||
# Allow runstate watchers to witness 'CONNECTING' state; some
|
||||
# failures in the streaming layer are synchronous and will not
|
||||
# otherwise yield.
|
||||
await asyncio.sleep(0)
|
||||
|
||||
if accept:
|
||||
await self._do_accept(address, ssl)
|
||||
else:
|
||||
await self._do_connect(address, ssl)
|
||||
|
||||
def _bind_hack(self, address: Union[str, Tuple[str, int]]) -> None:
|
||||
"""
|
||||
Used to create a socket in advance of accept().
|
||||
|
||||
This is a workaround to ensure that we can guarantee timing of
|
||||
precisely when a socket exists to avoid a connection attempt
|
||||
bouncing off of nothing.
|
||||
|
||||
Python 3.7+ adds a feature to separate the server creation and
|
||||
listening phases instead, and should be used instead of this
|
||||
hack.
|
||||
"""
|
||||
if isinstance(address, tuple):
|
||||
family = socket.AF_INET
|
||||
else:
|
||||
family = socket.AF_UNIX
|
||||
|
||||
sock = socket.socket(family, socket.SOCK_STREAM)
|
||||
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
||||
|
||||
try:
|
||||
sock.bind(address)
|
||||
except:
|
||||
sock.close()
|
||||
raise
|
||||
|
||||
self._sock = sock
|
||||
|
||||
@upper_half
|
||||
async def _do_accept(self, address: SocketAddrT,
|
||||
ssl: Optional[SSLContext] = None) -> None:
|
||||
"""
|
||||
Acting as the transport server, accept a single connection.
|
||||
This method starts listening for an incoming connection, but does not
|
||||
block waiting for a peer. This call will return immediately after
|
||||
binding and listening to a socket. A later call to accept() must be
|
||||
made in order to finalize the incoming connection.
|
||||
|
||||
:param address:
|
||||
Address to listen on; UNIX socket path or TCP address/port.
|
||||
@ -471,52 +548,54 @@ class AsyncProtocol(Generic[T]):
|
||||
|
||||
:raise OSError: For stream-related errors.
|
||||
"""
|
||||
assert self.runstate == Runstate.IDLE
|
||||
self._set_state(Runstate.CONNECTING)
|
||||
|
||||
self.logger.debug("Awaiting connection on %s ...", address)
|
||||
connected = asyncio.Event()
|
||||
server: Optional[asyncio.AbstractServer] = None
|
||||
|
||||
async def _client_connected_cb(reader: asyncio.StreamReader,
|
||||
writer: asyncio.StreamWriter) -> None:
|
||||
"""Used to accept a single incoming connection, see below."""
|
||||
nonlocal server
|
||||
nonlocal connected
|
||||
|
||||
# A connection has been accepted; stop listening for new ones.
|
||||
assert server is not None
|
||||
server.close()
|
||||
await server.wait_closed()
|
||||
server = None
|
||||
|
||||
# Register this client as being connected
|
||||
self._reader, self._writer = (reader, writer)
|
||||
|
||||
# Signal back: We've accepted a client!
|
||||
connected.set()
|
||||
self._accepted = asyncio.Event()
|
||||
|
||||
if isinstance(address, tuple):
|
||||
coro = asyncio.start_server(
|
||||
_client_connected_cb,
|
||||
host=None if self._sock else address[0],
|
||||
port=None if self._sock else address[1],
|
||||
self._incoming,
|
||||
host=address[0],
|
||||
port=address[1],
|
||||
ssl=ssl,
|
||||
backlog=1,
|
||||
limit=self._limit,
|
||||
sock=self._sock,
|
||||
)
|
||||
else:
|
||||
coro = asyncio.start_unix_server(
|
||||
_client_connected_cb,
|
||||
path=None if self._sock else address,
|
||||
self._incoming,
|
||||
path=address,
|
||||
ssl=ssl,
|
||||
backlog=1,
|
||||
limit=self._limit,
|
||||
sock=self._sock,
|
||||
)
|
||||
|
||||
server = await coro # Starts listening
|
||||
await connected.wait() # Waits for the callback to fire (and finish)
|
||||
assert server is None
|
||||
self._sock = None
|
||||
# Allow runstate watchers to witness 'CONNECTING' state; some
|
||||
# failures in the streaming layer are synchronous and will not
|
||||
# otherwise yield.
|
||||
await asyncio.sleep(0)
|
||||
|
||||
# This will start the server (bind(2), listen(2)). It will also
|
||||
# call accept(2) if we yield, but we don't block on that here.
|
||||
self._server = await coro
|
||||
self.logger.debug("Server listening on %s", address)
|
||||
|
||||
@upper_half
|
||||
async def _do_accept(self) -> None:
|
||||
"""
|
||||
Wait for and accept an incoming connection.
|
||||
|
||||
Requires that we have not yet accepted an incoming connection
|
||||
from the upper_half, but it's OK if the server is no longer
|
||||
running because the bottom_half has already accepted the
|
||||
connection.
|
||||
"""
|
||||
assert self._accepted is not None
|
||||
await self._accepted.wait()
|
||||
assert self._server is None
|
||||
self._accepted = None
|
||||
|
||||
self.logger.debug("Connection accepted.")
|
||||
|
||||
@ -532,6 +611,14 @@ class AsyncProtocol(Generic[T]):
|
||||
|
||||
:raise OSError: For stream-related errors.
|
||||
"""
|
||||
assert self.runstate == Runstate.IDLE
|
||||
self._set_state(Runstate.CONNECTING)
|
||||
|
||||
# Allow runstate watchers to witness 'CONNECTING' state; some
|
||||
# failures in the streaming layer are synchronous and will not
|
||||
# otherwise yield.
|
||||
await asyncio.sleep(0)
|
||||
|
||||
self.logger.debug("Connecting to %s ...", address)
|
||||
|
||||
if isinstance(address, tuple):
|
||||
@ -644,6 +731,7 @@ class AsyncProtocol(Generic[T]):
|
||||
|
||||
self._reader = None
|
||||
self._writer = None
|
||||
self._accepted = None
|
||||
|
||||
# NB: _runstate_changed cannot be cleared because we still need it to
|
||||
# send the final runstate changed event ...!
|
||||
@ -667,6 +755,9 @@ class AsyncProtocol(Generic[T]):
|
||||
def _done(task: Optional['asyncio.Future[Any]']) -> bool:
|
||||
return task is not None and task.done()
|
||||
|
||||
# If the server is running, stop it.
|
||||
await self._stop_server()
|
||||
|
||||
# Are we already in an error pathway? If either of the tasks are
|
||||
# already done, or if we have no tasks but a reader/writer; we
|
||||
# must be.
|
||||
|
@ -41,12 +41,25 @@ class NullProtocol(AsyncProtocol[None]):
|
||||
self.trigger_input = asyncio.Event()
|
||||
await super()._establish_session()
|
||||
|
||||
async def _do_accept(self, address, ssl=None):
|
||||
if not self.fake_session:
|
||||
await super()._do_accept(address, ssl)
|
||||
async def _do_start_server(self, address, ssl=None):
|
||||
if self.fake_session:
|
||||
self._accepted = asyncio.Event()
|
||||
self._set_state(Runstate.CONNECTING)
|
||||
await asyncio.sleep(0)
|
||||
else:
|
||||
await super()._do_start_server(address, ssl)
|
||||
|
||||
async def _do_accept(self):
|
||||
if self.fake_session:
|
||||
self._accepted = None
|
||||
else:
|
||||
await super()._do_accept()
|
||||
|
||||
async def _do_connect(self, address, ssl=None):
|
||||
if not self.fake_session:
|
||||
if self.fake_session:
|
||||
self._set_state(Runstate.CONNECTING)
|
||||
await asyncio.sleep(0)
|
||||
else:
|
||||
await super()._do_connect(address, ssl)
|
||||
|
||||
async def _do_recv(self) -> None:
|
||||
@ -413,14 +426,14 @@ class Accept(Connect):
|
||||
assert family in ('INET', 'UNIX')
|
||||
|
||||
if family == 'INET':
|
||||
await self.proto.accept(('example.com', 1))
|
||||
await self.proto.start_server_and_accept(('example.com', 1))
|
||||
elif family == 'UNIX':
|
||||
await self.proto.accept('/dev/null')
|
||||
await self.proto.start_server_and_accept('/dev/null')
|
||||
|
||||
async def _hanging_connection(self):
|
||||
with TemporaryDirectory(suffix='.aqmp') as tmpdir:
|
||||
sock = os.path.join(tmpdir, type(self.proto).__name__ + ".sock")
|
||||
await self.proto.accept(sock)
|
||||
await self.proto.start_server_and_accept(sock)
|
||||
|
||||
|
||||
class FakeSession(TestBase):
|
||||
@ -449,13 +462,13 @@ class FakeSession(TestBase):
|
||||
@TestBase.async_test
|
||||
async def testFakeAccept(self):
|
||||
"""Test the full state lifecycle (via accept) with a no-op session."""
|
||||
await self.proto.accept('/not/a/real/path')
|
||||
await self.proto.start_server_and_accept('/not/a/real/path')
|
||||
self.assertEqual(self.proto.runstate, Runstate.RUNNING)
|
||||
|
||||
@TestBase.async_test
|
||||
async def testFakeRecv(self):
|
||||
"""Test receiving a fake/null message."""
|
||||
await self.proto.accept('/not/a/real/path')
|
||||
await self.proto.start_server_and_accept('/not/a/real/path')
|
||||
|
||||
logname = self.proto.logger.name
|
||||
with self.assertLogs(logname, level='DEBUG') as context:
|
||||
@ -471,7 +484,7 @@ class FakeSession(TestBase):
|
||||
@TestBase.async_test
|
||||
async def testFakeSend(self):
|
||||
"""Test sending a fake/null message."""
|
||||
await self.proto.accept('/not/a/real/path')
|
||||
await self.proto.start_server_and_accept('/not/a/real/path')
|
||||
|
||||
logname = self.proto.logger.name
|
||||
with self.assertLogs(logname, level='DEBUG') as context:
|
||||
@ -493,7 +506,7 @@ class FakeSession(TestBase):
|
||||
):
|
||||
with self.assertRaises(StateError) as context:
|
||||
if accept:
|
||||
await self.proto.accept('/not/a/real/path')
|
||||
await self.proto.start_server_and_accept('/not/a/real/path')
|
||||
else:
|
||||
await self.proto.connect('/not/a/real/path')
|
||||
|
||||
@ -504,7 +517,7 @@ class FakeSession(TestBase):
|
||||
@TestBase.async_test
|
||||
async def testAcceptRequireRunning(self):
|
||||
"""Test that accept() cannot be called when Runstate=RUNNING"""
|
||||
await self.proto.accept('/not/a/real/path')
|
||||
await self.proto.start_server_and_accept('/not/a/real/path')
|
||||
|
||||
await self._prod_session_api(
|
||||
Runstate.RUNNING,
|
||||
@ -515,7 +528,7 @@ class FakeSession(TestBase):
|
||||
@TestBase.async_test
|
||||
async def testConnectRequireRunning(self):
|
||||
"""Test that connect() cannot be called when Runstate=RUNNING"""
|
||||
await self.proto.accept('/not/a/real/path')
|
||||
await self.proto.start_server_and_accept('/not/a/real/path')
|
||||
|
||||
await self._prod_session_api(
|
||||
Runstate.RUNNING,
|
||||
@ -526,7 +539,7 @@ class FakeSession(TestBase):
|
||||
@TestBase.async_test
|
||||
async def testAcceptRequireDisconnecting(self):
|
||||
"""Test that accept() cannot be called when Runstate=DISCONNECTING"""
|
||||
await self.proto.accept('/not/a/real/path')
|
||||
await self.proto.start_server_and_accept('/not/a/real/path')
|
||||
|
||||
# Cheat: force a disconnect.
|
||||
await self.proto.simulate_disconnect()
|
||||
@ -541,7 +554,7 @@ class FakeSession(TestBase):
|
||||
@TestBase.async_test
|
||||
async def testConnectRequireDisconnecting(self):
|
||||
"""Test that connect() cannot be called when Runstate=DISCONNECTING"""
|
||||
await self.proto.accept('/not/a/real/path')
|
||||
await self.proto.start_server_and_accept('/not/a/real/path')
|
||||
|
||||
# Cheat: force a disconnect.
|
||||
await self.proto.simulate_disconnect()
|
||||
@ -576,7 +589,7 @@ class SimpleSession(TestBase):
|
||||
async def testSmoke(self):
|
||||
with TemporaryDirectory(suffix='.aqmp') as tmpdir:
|
||||
sock = os.path.join(tmpdir, type(self.proto).__name__ + ".sock")
|
||||
server_task = create_task(self.server.accept(sock))
|
||||
server_task = create_task(self.server.start_server_and_accept(sock))
|
||||
|
||||
# give the server a chance to start listening [...]
|
||||
await asyncio.sleep(0)
|
||||
|
@ -4,7 +4,7 @@ import os
|
||||
import sys
|
||||
|
||||
sys.path.append(os.path.join(os.path.dirname(__file__), '..', '..', 'python'))
|
||||
from qemu.qmp import qmp_shell
|
||||
from qemu.aqmp import qmp_shell
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
Loading…
Reference in New Issue
Block a user