python/aqmp: Fix disconnect during capabilities negotiation

If we receive ConnectionResetError (ECONNRESET) while attempting to
perform capabilities negotiation -- prior to the establishment of the
async reader/writer tasks -- the disconnect function is not aware that
we are in an error pathway.

As a result, when attempting to close the StreamWriter, we'll see the
same ConnectionResetError that caused us to initiate a disconnect in the
first place, which will cause the disconnect task itself to fail, which
emits a CRITICAL logging event.

I still don't know if there's a smarter way to check to see if an
exception received at this point is "the same" exception as the one that
caused the initial disconnect, but for now the problem can be avoided by
improving the error pathway detection in the exit path.

Reported-by: Thomas Huth <thuth@redhat.com>
Signed-off-by: John Snow <jsnow@redhat.com>
Tested-by: Thomas Huth <thuth@redhat.com>
Message-id: 20211111143719.2162525-2-jsnow@redhat.com
Signed-off-by: John Snow <jsnow@redhat.com>
This commit is contained in:
John Snow 2021-11-11 09:37:15 -05:00
parent 2b22e7540d
commit f26bd6ff21

View File

@ -623,13 +623,21 @@ class AsyncProtocol(Generic[T]):
def _done(task: Optional['asyncio.Future[Any]']) -> bool:
return task is not None and task.done()
# NB: We can't rely on _bh_tasks being done() here, it may not
# yet have had a chance to run and gather itself.
# Are we already in an error pathway? If either of the tasks are
# already done, or if we have no tasks but a reader/writer; we
# must be.
#
# NB: We can't use _bh_tasks to check for premature task
# completion, because it may not yet have had a chance to run
# and gather itself.
tasks = tuple(filter(None, (self._writer_task, self._reader_task)))
error_pathway = _done(self._reader_task) or _done(self._writer_task)
if not tasks:
error_pathway |= bool(self._reader) or bool(self._writer)
try:
# Try to flush the writer, if possible:
# Try to flush the writer, if possible.
# This *may* cause an error and force us over into the error path.
if not error_pathway:
await self._bh_flush_writer()
except BaseException as err:
@ -639,7 +647,7 @@ class AsyncProtocol(Generic[T]):
self.logger.debug("%s:\n%s\n", emsg, pretty_traceback())
raise
finally:
# Cancel any still-running tasks:
# Cancel any still-running tasks (Won't raise):
if self._writer_task is not None and not self._writer_task.done():
self.logger.debug("Cancelling writer task.")
self._writer_task.cancel()
@ -652,7 +660,7 @@ class AsyncProtocol(Generic[T]):
self.logger.debug("Waiting for tasks to complete ...")
await asyncio.wait(tasks)
# Lastly, close the stream itself. (May raise):
# Lastly, close the stream itself. (*May raise*!):
await self._bh_close_stream(error_pathway)
self.logger.debug("Disconnected.")