extmod/asyncio: Support gather of tasks that finish early.
Adds support to asyncio.gather() for the case that one or more (or all) sub-tasks finish and/or raise an exception before the gather starts. Signed-off-by: Damien George <damien@micropython.org>
This commit is contained in:
parent
1e8cc6c503
commit
a2e9ab362b
@ -219,6 +219,11 @@ def run_until_complete(main_task=None):
|
|||||||
elif t.state is None:
|
elif t.state is None:
|
||||||
# Task is already finished and nothing await'ed on the task,
|
# Task is already finished and nothing await'ed on the task,
|
||||||
# so call the exception handler.
|
# so call the exception handler.
|
||||||
|
|
||||||
|
# Save exception raised by the coro for later use.
|
||||||
|
t.data = exc
|
||||||
|
|
||||||
|
# Create exception context and call the exception handler.
|
||||||
_exc_context["exception"] = exc
|
_exc_context["exception"] = exc
|
||||||
_exc_context["future"] = t
|
_exc_context["future"] = t
|
||||||
Loop.call_exception_handler(_exc_context)
|
Loop.call_exception_handler(_exc_context)
|
||||||
|
@ -63,9 +63,6 @@ class _Remove:
|
|||||||
|
|
||||||
# async
|
# async
|
||||||
def gather(*aws, return_exceptions=False):
|
def gather(*aws, return_exceptions=False):
|
||||||
if not aws:
|
|
||||||
return []
|
|
||||||
|
|
||||||
def done(t, er):
|
def done(t, er):
|
||||||
# Sub-task "t" has finished, with exception "er".
|
# Sub-task "t" has finished, with exception "er".
|
||||||
nonlocal state
|
nonlocal state
|
||||||
@ -86,20 +83,33 @@ def gather(*aws, return_exceptions=False):
|
|||||||
# Gather waiting is done, schedule the main gather task.
|
# Gather waiting is done, schedule the main gather task.
|
||||||
core._task_queue.push(gather_task)
|
core._task_queue.push(gather_task)
|
||||||
|
|
||||||
|
# Prepare the sub-tasks for the gather.
|
||||||
|
# The `state` variable counts the number of tasks to wait for, and can be negative
|
||||||
|
# if the gather should not run at all (because a task already had an exception).
|
||||||
ts = [core._promote_to_task(aw) for aw in aws]
|
ts = [core._promote_to_task(aw) for aw in aws]
|
||||||
|
state = 0
|
||||||
for i in range(len(ts)):
|
for i in range(len(ts)):
|
||||||
if ts[i].state is not True:
|
if ts[i].state is True:
|
||||||
# Task is not running, gather not currently supported for this case.
|
# Task is running, register the callback to call when the task is done.
|
||||||
raise RuntimeError("can't gather")
|
|
||||||
# Register the callback to call when the task is done.
|
|
||||||
ts[i].state = done
|
ts[i].state = done
|
||||||
|
state += 1
|
||||||
|
elif not ts[i].state:
|
||||||
|
# Task finished already.
|
||||||
|
if not isinstance(ts[i].data, StopIteration):
|
||||||
|
# Task finished by raising an exception.
|
||||||
|
if not return_exceptions:
|
||||||
|
# Do not run this gather at all.
|
||||||
|
state = -len(ts)
|
||||||
|
else:
|
||||||
|
# Task being waited on, gather not currently supported for this case.
|
||||||
|
raise RuntimeError("can't gather")
|
||||||
|
|
||||||
# Set the state for execution of the gather.
|
# Set the state for execution of the gather.
|
||||||
gather_task = core.cur_task
|
gather_task = core.cur_task
|
||||||
state = len(ts)
|
|
||||||
cancel_all = False
|
cancel_all = False
|
||||||
|
|
||||||
# Wait for the a sub-task to need attention.
|
# Wait for a sub-task to need attention (if there are any to wait for).
|
||||||
|
if state > 0:
|
||||||
gather_task.data = _Remove
|
gather_task.data = _Remove
|
||||||
try:
|
try:
|
||||||
yield
|
yield
|
||||||
@ -118,8 +128,13 @@ def gather(*aws, return_exceptions=False):
|
|||||||
# Sub-task ran to completion, get its return value.
|
# Sub-task ran to completion, get its return value.
|
||||||
ts[i] = ts[i].data.value
|
ts[i] = ts[i].data.value
|
||||||
else:
|
else:
|
||||||
# Sub-task had an exception with return_exceptions==True, so get its exception.
|
# Sub-task had an exception.
|
||||||
|
if return_exceptions:
|
||||||
|
# Get the sub-task exception to return in the list of return values.
|
||||||
ts[i] = ts[i].data
|
ts[i] = ts[i].data
|
||||||
|
elif isinstance(state, int):
|
||||||
|
# Raise the sub-task exception, if there is not already an exception to raise.
|
||||||
|
state = ts[i].data
|
||||||
|
|
||||||
# Either this gather was cancelled, or one of the sub-tasks raised an exception with
|
# Either this gather was cancelled, or one of the sub-tasks raised an exception with
|
||||||
# return_exceptions==False, so reraise the exception here.
|
# return_exceptions==False, so reraise the exception here.
|
||||||
|
65
tests/extmod/asyncio_gather_finished_early.py
Normal file
65
tests/extmod/asyncio_gather_finished_early.py
Normal file
@ -0,0 +1,65 @@
|
|||||||
|
# Test asyncio.gather() when a task is already finished before the gather starts.
|
||||||
|
|
||||||
|
try:
|
||||||
|
import asyncio
|
||||||
|
except ImportError:
|
||||||
|
print("SKIP")
|
||||||
|
raise SystemExit
|
||||||
|
|
||||||
|
|
||||||
|
# CPython and MicroPython differ in when they signal (and print) that a task raised an
|
||||||
|
# uncaught exception. So define an empty custom_handler() to suppress this output.
|
||||||
|
def custom_handler(loop, context):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
async def task_that_finishes_early(id, event, fail):
|
||||||
|
print("task_that_finishes_early", id)
|
||||||
|
event.set()
|
||||||
|
if fail:
|
||||||
|
raise ValueError("intentional exception", id)
|
||||||
|
|
||||||
|
|
||||||
|
async def task_that_runs():
|
||||||
|
for i in range(5):
|
||||||
|
print("task_that_runs", i)
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
|
||||||
|
|
||||||
|
async def main(start_task_that_runs, task_fail, return_exceptions):
|
||||||
|
print("== start", start_task_that_runs, task_fail, return_exceptions)
|
||||||
|
|
||||||
|
# Set exception handler to suppress exception output.
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
loop.set_exception_handler(custom_handler)
|
||||||
|
|
||||||
|
# Create tasks.
|
||||||
|
event_a = asyncio.Event()
|
||||||
|
event_b = asyncio.Event()
|
||||||
|
tasks = []
|
||||||
|
if start_task_that_runs:
|
||||||
|
tasks.append(asyncio.create_task(task_that_runs()))
|
||||||
|
tasks.append(asyncio.create_task(task_that_finishes_early("a", event_a, task_fail)))
|
||||||
|
tasks.append(asyncio.create_task(task_that_finishes_early("b", event_b, task_fail)))
|
||||||
|
|
||||||
|
# Make sure task_that_finishes_early() are both done, before calling gather().
|
||||||
|
await event_a.wait()
|
||||||
|
await event_b.wait()
|
||||||
|
|
||||||
|
# Gather the tasks.
|
||||||
|
try:
|
||||||
|
result = "complete", await asyncio.gather(*tasks, return_exceptions=return_exceptions)
|
||||||
|
except Exception as er:
|
||||||
|
result = "exception", er, start_task_that_runs and tasks[0].done()
|
||||||
|
|
||||||
|
# Wait for the final task to finish (if it was started).
|
||||||
|
if start_task_that_runs:
|
||||||
|
await tasks[0]
|
||||||
|
|
||||||
|
# Print results.
|
||||||
|
print(result)
|
||||||
|
|
||||||
|
|
||||||
|
# Run the test in the 8 different combinations of its arguments.
|
||||||
|
for i in range(8):
|
||||||
|
asyncio.run(main(bool(i & 4), bool(i & 2), bool(i & 1)))
|
Loading…
x
Reference in New Issue
Block a user