From ce101dcaf73ff6d610593230d41b63c163a91519 Mon Sep 17 00:00:00 2001 From: Roberto Ierusalimschy Date: Wed, 30 Dec 2020 11:20:22 -0300 Subject: [PATCH] Handles '__close' errors in coroutines in "coroutine style" Errors in '__close' metamethods in coroutines are handled by the same logic that handles other errors, through 'recover'. --- ldo.c | 66 ++++++++++++++++++++++++++++++-------------- testes/coroutine.lua | 41 +++++++++++++++++++++++++-- 2 files changed, 85 insertions(+), 22 deletions(-) diff --git a/ldo.c b/ldo.c index 5e3828f4..ba0c93b8 100644 --- a/ldo.c +++ b/ldo.c @@ -103,7 +103,7 @@ void luaD_seterrorobj (lua_State *L, int errcode, StkId oldtop) { break; } default: { - lua_assert(errcode >= LUA_ERRRUN); /* real error */ + lua_assert(errorstatus(errcode)); /* real error */ setobjs2s(L, oldtop, L->top - 1); /* error message on current top */ break; } @@ -593,15 +593,11 @@ static void finishCcall (lua_State *L, int status) { /* ** Executes "full continuation" (everything in the stack) of a ** previously interrupted coroutine until the stack is empty (or another -** interruption long-jumps out of the loop). If the coroutine is -** recovering from an error, 'ud' points to the error status, which must -** be passed to the first continuation function (otherwise the default -** status is LUA_YIELD). +** interruption long-jumps out of the loop). */ static void unroll (lua_State *L, void *ud) { CallInfo *ci; - if (ud != NULL) /* error status? */ - finishCcall(L, *(int *)ud); /* finish 'lua_pcallk' callee */ + UNUSED(ud); while ((ci = L->ci) != &L->base_ci) { /* something in the stack */ if (!isLua(ci)) /* C function? */ finishCcall(L, LUA_YIELD); /* complete its execution */ @@ -628,21 +624,36 @@ static CallInfo *findpcall (lua_State *L) { /* -** Recovers from an error in a coroutine. Finds a recover point (if -** there is one) and completes the execution of the interrupted -** 'luaD_pcall'. If there is no recover point, returns zero. +** Auxiliary structure to call 'recover' in protected mode. */ -static int recover (lua_State *L, int status) { - CallInfo *ci = findpcall(L); - if (ci == NULL) return 0; /* no recovery point */ +struct RecoverS { + int status; + CallInfo *ci; +}; + + +/* +** Recovers from an error in a coroutine: completes the execution of the +** interrupted 'luaD_pcall', completes the interrupted C function which +** called 'lua_pcallk', and continues running the coroutine. If there is +** an error in 'luaF_close', this function will be called again and the +** coroutine will continue from where it left. +*/ +static void recover (lua_State *L, void *ud) { + struct RecoverS *r = cast(struct RecoverS *, ud); + int status = r->status; + CallInfo *ci = r->ci; /* recover point */ + StkId func = restorestack(L, ci->u2.funcidx); /* "finish" luaD_pcall */ L->ci = ci; L->allowhook = getoah(ci->callstatus); /* restore original 'allowhook' */ - status = luaD_closeprotected(L, ci->u2.funcidx, status); - luaD_seterrorobj(L, status, restorestack(L, ci->u2.funcidx)); + luaF_close(L, func, status); /* may change the stack */ + func = restorestack(L, ci->u2.funcidx); + luaD_seterrorobj(L, status, func); luaD_shrinkstack(L); /* restore stack size in case of overflow */ L->errfunc = ci->u.c.old_errfunc; - return 1; /* continue running the coroutine */ + finishCcall(L, status); /* finish 'lua_pcallk' callee */ + unroll(L, NULL); /* continue running the coroutine */ } @@ -692,6 +703,24 @@ static void resume (lua_State *L, void *ud) { } } + +/* +** Calls 'recover' in protected mode, repeating while there are +** recoverable errors, that is, errors inside a protected call. (Any +** error interrupts 'recover', and this loop protects it again so it +** can continue.) Stops with a normal end (status == LUA_OK), an yield +** (status == LUA_YIELD), or an unprotected error ('findpcall' doesn't +** find a recover point). +*/ +static int p_recover (lua_State *L, int status) { + struct RecoverS r; + r.status = status; + while (errorstatus(status) && (r.ci = findpcall(L)) != NULL) + r.status = luaD_rawrunprotected(L, recover, &r); + return r.status; +} + + LUA_API int lua_resume (lua_State *L, lua_State *from, int nargs, int *nresults) { int status; @@ -709,10 +738,7 @@ LUA_API int lua_resume (lua_State *L, lua_State *from, int nargs, api_checknelems(L, (L->status == LUA_OK) ? nargs + 1 : nargs); status = luaD_rawrunprotected(L, resume, &nargs); /* continue running after recoverable errors */ - while (errorstatus(status) && recover(L, status)) { - /* unroll continuation */ - status = luaD_rawrunprotected(L, unroll, &status); - } + status = p_recover(L, status); if (likely(!errorstatus(status))) lua_assert(status == L->status); /* normal end or yield */ else { /* unrecoverable error */ diff --git a/testes/coroutine.lua b/testes/coroutine.lua index 0a970e98..fbeabd07 100644 --- a/testes/coroutine.lua +++ b/testes/coroutine.lua @@ -123,7 +123,7 @@ assert(#a == 22 and a[#a] == 79) x, a = nil --- coroutine closing +print("to-be-closed variables in coroutines") local function func2close (f) return setmetatable({}, {__close = f}) @@ -189,7 +189,6 @@ do local st, msg = coroutine.close(co) assert(st == false and coroutine.status(co) == "dead" and msg == 200) assert(x == 200) - end do @@ -207,6 +206,44 @@ do local st1, st2, err = coroutine.resume(co) assert(st1 and not st2 and err == 43) assert(X == 43 and Y.name == "pcall") + + -- recovering from errors in __close metamethods + local track = {} + + local function h (o) + local hv = o + return 1 + end + + local function foo () + local x = func2close(function(_,msg) + track[#track + 1] = msg or false + error(20) + end) + local y = func2close(function(_,msg) + track[#track + 1] = msg or false + return 1000 + end) + local z = func2close(function(_,msg) + track[#track + 1] = msg or false + error(10) + end) + coroutine.yield(1) + h(func2close(function(_,msg) + track[#track + 1] = msg or false + error(2) + end)) + end + + local co = coroutine.create(pcall) + + local st, res = coroutine.resume(co, foo) -- call 'foo' protected + assert(st and res == 1) -- yield 1 + local st, res1, res2 = coroutine.resume(co) -- continue + assert(coroutine.status(co) == "dead") + assert(st and not res1 and res2 == 20) -- last error (20) + assert(track[1] == false and track[2] == 2 and track[3] == 10 and + track[4] == 10) end