From 88a50ffa715483e7187c0d7d6caaf708ebacf756 Mon Sep 17 00:00:00 2001 From: Roberto Ierusalimschy Date: Fri, 29 Mar 2024 15:10:50 -0300 Subject: [PATCH] Fixed dangling 'StkId' in 'luaV_finishget' Bug introduced in 05932567. --- lobject.h | 2 ++ ltm.c | 43 ++++++++++++++++++++++++------------------- ltm.h | 4 ++-- lvm.c | 10 +++++----- testes/events.lua | 9 +++++++++ 5 files changed, 42 insertions(+), 26 deletions(-) diff --git a/lobject.h b/lobject.h index b42539cf..169512f8 100644 --- a/lobject.h +++ b/lobject.h @@ -256,6 +256,8 @@ typedef union { #define l_isfalse(o) (ttisfalse(o) || ttisnil(o)) +#define tagisfalse(t) ((t) == LUA_VFALSE || novariant(t) == LUA_TNIL) + #define setbfvalue(obj) settt_(obj, LUA_VFALSE) diff --git a/ltm.c b/ltm.c index c28f9122..236f3bb4 100644 --- a/ltm.c +++ b/ltm.c @@ -116,8 +116,8 @@ void luaT_callTM (lua_State *L, const TValue *f, const TValue *p1, } -void luaT_callTMres (lua_State *L, const TValue *f, const TValue *p1, - const TValue *p2, StkId res) { +int luaT_callTMres (lua_State *L, const TValue *f, const TValue *p1, + const TValue *p2, StkId res) { ptrdiff_t result = savestack(L, res); StkId func = L->top.p; setobj2s(L, func, f); /* push function (assume EXTRA_STACK) */ @@ -131,6 +131,7 @@ void luaT_callTMres (lua_State *L, const TValue *f, const TValue *p1, luaD_callnoyield(L, func, 1); res = restorestack(L, result); setobjs2s(L, res, --L->top.p); /* move result to its place */ + return ttypetag(s2v(res)); /* return tag of the result */ } @@ -139,15 +140,16 @@ static int callbinTM (lua_State *L, const TValue *p1, const TValue *p2, const TValue *tm = luaT_gettmbyobj(L, p1, event); /* try first operand */ if (notm(tm)) tm = luaT_gettmbyobj(L, p2, event); /* try second operand */ - if (notm(tm)) return 0; - luaT_callTMres(L, tm, p1, p2, res); - return 1; + if (notm(tm)) + return -1; /* tag method not found */ + else /* call tag method and return the tag of the result */ + return luaT_callTMres(L, tm, p1, p2, res); } void luaT_trybinTM (lua_State *L, const TValue *p1, const TValue *p2, StkId res, TMS event) { - if (l_unlikely(!callbinTM(L, p1, p2, res, event))) { + if (l_unlikely(callbinTM(L, p1, p2, res, event) < 0)) { switch (event) { case TM_BAND: case TM_BOR: case TM_BXOR: case TM_SHL: case TM_SHR: case TM_BNOT: { @@ -164,11 +166,14 @@ void luaT_trybinTM (lua_State *L, const TValue *p1, const TValue *p2, } +/* +** The use of 'p1' after 'callbinTM' is safe because, when a tag +** method is not found, 'callbinTM' cannot change the stack. +*/ void luaT_tryconcatTM (lua_State *L) { - StkId top = L->top.p; - if (l_unlikely(!callbinTM(L, s2v(top - 2), s2v(top - 1), top - 2, - TM_CONCAT))) - luaG_concaterror(L, s2v(top - 2), s2v(top - 1)); + StkId p1 = L->top.p - 2; /* first argument */ + if (l_unlikely(callbinTM(L, s2v(p1), s2v(p1 + 1), p1, TM_CONCAT) < 0)) + luaG_concaterror(L, s2v(p1), s2v(p1 + 1)); } @@ -200,17 +205,17 @@ void luaT_trybiniTM (lua_State *L, const TValue *p1, lua_Integer i2, */ int luaT_callorderTM (lua_State *L, const TValue *p1, const TValue *p2, TMS event) { - if (callbinTM(L, p1, p2, L->top.p, event)) /* try original event */ - return !l_isfalse(s2v(L->top.p)); + int tag = callbinTM(L, p1, p2, L->top.p, event); /* try original event */ + if (tag >= 0) /* found tag method? */ + return !tagisfalse(tag); #if defined(LUA_COMPAT_LT_LE) else if (event == TM_LE) { - /* try '!(p2 < p1)' for '(p1 <= p2)' */ - L->ci->callstatus |= CIST_LEQ; /* mark it is doing 'lt' for 'le' */ - if (callbinTM(L, p2, p1, L->top.p, TM_LT)) { - L->ci->callstatus ^= CIST_LEQ; /* clear mark */ - return l_isfalse(s2v(L->top.p)); - } - /* else error will remove this 'ci'; no need to clear mark */ + /* try '!(p2 < p1)' for '(p1 <= p2)' */ + L->ci->callstatus |= CIST_LEQ; /* mark it is doing 'lt' for 'le' */ + tag = callbinTM(L, p2, p1, L->top.p, TM_LT); + L->ci->callstatus ^= CIST_LEQ; /* clear mark */ + if (tag >= 0) /* found tag method? */ + return tagisfalse(tag); } #endif luaG_ordererror(L, p1, p2); /* no metamethod found */ diff --git a/ltm.h b/ltm.h index 3c49713a..df05b741 100644 --- a/ltm.h +++ b/ltm.h @@ -81,8 +81,8 @@ LUAI_FUNC void luaT_init (lua_State *L); LUAI_FUNC void luaT_callTM (lua_State *L, const TValue *f, const TValue *p1, const TValue *p2, const TValue *p3); -LUAI_FUNC void luaT_callTMres (lua_State *L, const TValue *f, - const TValue *p1, const TValue *p2, StkId p3); +LUAI_FUNC int luaT_callTMres (lua_State *L, const TValue *f, + const TValue *p1, const TValue *p2, StkId p3); LUAI_FUNC void luaT_trybinTM (lua_State *L, const TValue *p1, const TValue *p2, StkId res, TMS event); LUAI_FUNC void luaT_tryconcatTM (lua_State *L); diff --git a/lvm.c b/lvm.c index cfa9961b..37023afb 100644 --- a/lvm.c +++ b/lvm.c @@ -308,8 +308,8 @@ int luaV_finishget (lua_State *L, const TValue *t, TValue *key, StkId val, /* else will try the metamethod */ } if (ttisfunction(tm)) { /* is metamethod a function? */ - luaT_callTMres(L, tm, t, key, val); /* call it */ - return ttypetag(s2v(val)); + tag = luaT_callTMres(L, tm, t, key, val); /* call it */ + return tag; /* return tag of the result */ } t = tm; /* else try to access 'tm[key]' */ luaV_fastget(t, key, s2v(val), luaH_get, tag); @@ -606,8 +606,8 @@ int luaV_equalobj (lua_State *L, const TValue *t1, const TValue *t2) { if (tm == NULL) /* no TM? */ return 0; /* objects are different */ else { - luaT_callTMres(L, tm, t1, t2, L->top.p); /* call TM */ - return !l_isfalse(s2v(L->top.p)); + int tag = luaT_callTMres(L, tm, t1, t2, L->top.p); /* call TM */ + return !tagisfalse(tag); } } @@ -914,7 +914,7 @@ void luaV_finishOp (lua_State *L) { /* ** Auxiliary function for arithmetic operations over floats and others -** with two register operands. +** with two operands. */ #define op_arithf_aux(L,v1,v2,fop) { \ lua_Number n1; lua_Number n2; \ diff --git a/testes/events.lua b/testes/events.lua index 8d8563b9..5360ac30 100644 --- a/testes/events.lua +++ b/testes/events.lua @@ -248,6 +248,15 @@ end test(Op(1), Op(2), Op(3)) +do -- test nil as false + local x = setmetatable({12}, {__eq= function (a,b) + return a[1] == b[1] or nil + end}) + assert(not (x == {20})) + assert(x == {12}) +end + + -- test `partial order' local function rawSet(x)