From b80077b8f3e27a94c6afa895b41a9f8b52c42e61 Mon Sep 17 00:00:00 2001 From: Roberto Ierusalimschy Date: Fri, 26 Jul 2019 14:59:39 -0300 Subject: [PATCH] Change in the handling of 'L->top' when calling metamethods Instead of updating 'L->top' in every place that may call a metamethod, the metamethod functions themselves (luaT_trybinTM and luaT_callorderTM) correct the top. (When calling metamethods from the C API, however, the callers must preserve 'L->top'.) --- lapi.c | 2 ++ lobject.c | 2 ++ lopcodes.c | 2 +- ltm.c | 12 +++++++++--- ltm.h | 1 + lvm.c | 44 ++++++++++++++++++++++++-------------------- testes/api.lua | 17 +++++++++++++++++ testes/coroutine.lua | 2 +- testes/events.lua | 15 +++++++++++++-- testes/strings.lua | 7 +++++-- 10 files changed, 75 insertions(+), 29 deletions(-) diff --git a/lapi.c b/lapi.c index 0ea3dc0f..a9ffad80 100644 --- a/lapi.c +++ b/lapi.c @@ -329,12 +329,14 @@ LUA_API int lua_compare (lua_State *L, int index1, int index2, int op) { o1 = index2value(L, index1); o2 = index2value(L, index2); if (isvalid(L, o1) && isvalid(L, o2)) { + ptrdiff_t top = savestack(L, L->top); switch (op) { case LUA_OPEQ: i = luaV_equalobj(L, o1, o2); break; case LUA_OPLT: i = luaV_lessthan(L, o1, o2); break; case LUA_OPLE: i = luaV_lessequal(L, o1, o2); break; default: api_check(L, 0, "invalid option"); } + L->top = restorestack(L, top); } lua_unlock(L); return i; diff --git a/lobject.c b/lobject.c index b4efae4f..b376ab15 100644 --- a/lobject.c +++ b/lobject.c @@ -127,7 +127,9 @@ void luaO_arith (lua_State *L, int op, const TValue *p1, const TValue *p2, StkId res) { if (!luaO_rawarith(L, op, p1, p2, s2v(res))) { /* could not perform raw operation; try metamethod */ + ptrdiff_t top = savestack(L, L->top); luaT_trybinTM(L, p1, p2, res, cast(TMS, (op - LUA_OPADD) + TM_ADD)); + L->top = restorestack(L, top); } } diff --git a/lopcodes.c b/lopcodes.c index 23c3a6e4..ee795786 100644 --- a/lopcodes.c +++ b/lopcodes.c @@ -101,7 +101,7 @@ LUAI_DDEF const lu_byte luaP_opmodes[NUM_OPCODES] = { ,opmode(0, 1, 0, 0, iABC) /* OP_SETLIST */ ,opmode(0, 0, 0, 1, iABx) /* OP_CLOSURE */ ,opmode(1, 0, 0, 1, iABC) /* OP_VARARG */ - ,opmode(0, 0, 0, 1, iABC) /* OP_VARARGPREP */ + ,opmode(0, 1, 0, 1, iABC) /* OP_VARARGPREP */ ,opmode(0, 0, 0, 0, iAx) /* OP_EXTRAARG */ }; diff --git a/ltm.c b/ltm.c index 24739444..19233a87 100644 --- a/ltm.c +++ b/ltm.c @@ -147,11 +147,9 @@ static int callbinTM (lua_State *L, const TValue *p1, const TValue *p2, void luaT_trybinTM (lua_State *L, const TValue *p1, const TValue *p2, StkId res, TMS event) { + L->top = L->ci->top; if (!callbinTM(L, p1, p2, res, event)) { switch (event) { - case TM_CONCAT: - luaG_concaterror(L, p1, p2); - /* call never returns, but to avoid warnings: *//* FALLTHROUGH */ case TM_BAND: case TM_BOR: case TM_BXOR: case TM_SHL: case TM_SHR: case TM_BNOT: { if (ttisnumber(p1) && ttisnumber(p2)) @@ -167,6 +165,13 @@ void luaT_trybinTM (lua_State *L, const TValue *p1, const TValue *p2, } +void luaT_tryconcatTM (lua_State *L) { + StkId top = L->top; + if (!callbinTM(L, s2v(top - 2), s2v(top - 1), top - 2, TM_CONCAT)) + luaG_concaterror(L, s2v(top - 2), s2v(top - 1)); +} + + void luaT_trybinassocTM (lua_State *L, const TValue *p1, const TValue *p2, StkId res, int flip, TMS event) { if (flip) @@ -186,6 +191,7 @@ void luaT_trybiniTM (lua_State *L, const TValue *p1, lua_Integer i2, int luaT_callorderTM (lua_State *L, const TValue *p1, const TValue *p2, TMS event) { + L->top = L->ci->top; if (callbinTM(L, p1, p2, L->top, event)) /* try original event */ return !l_isfalse(s2v(L->top)); #if defined(LUA_COMPAT_LT_LE) diff --git a/ltm.h b/ltm.h index e308fb80..51dfe793 100644 --- a/ltm.h +++ b/ltm.h @@ -75,6 +75,7 @@ LUAI_FUNC void luaT_callTMres (lua_State *L, const TValue *f, const TValue *p1, const TValue *p2, StkId p3); LUAI_FUNC void luaT_trybinTM (lua_State *L, const TValue *p1, const TValue *p2, StkId res, TMS event); +LUAI_FUNC void luaT_tryconcatTM (lua_State *L); LUAI_FUNC void luaT_trybinassocTM (lua_State *L, const TValue *p1, const TValue *p2, StkId res, int inv, TMS event); LUAI_FUNC void luaT_trybiniTM (lua_State *L, const TValue *p1, lua_Integer i2, diff --git a/lvm.c b/lvm.c index 26477c2c..f177ce6a 100644 --- a/lvm.c +++ b/lvm.c @@ -515,8 +515,11 @@ int luaV_equalobj (lua_State *L, const TValue *t1, const TValue *t2) { } if (tm == NULL) /* no TM? */ return 0; /* objects are different */ - luaT_callTMres(L, tm, t1, t2, L->top); /* call TM */ - return !l_isfalse(s2v(L->top)); + else { + L->top = L->ci->top; + luaT_callTMres(L, tm, t1, t2, L->top); /* call TM */ + return !l_isfalse(s2v(L->top)); + } } @@ -548,7 +551,7 @@ void luaV_concat (lua_State *L, int total) { int n = 2; /* number of elements handled in this pass (at least 2) */ if (!(ttisstring(s2v(top - 2)) || cvt2str(s2v(top - 2))) || !tostring(L, s2v(top - 1))) - luaT_trybinTM(L, s2v(top - 2), s2v(top - 1), top - 2, TM_CONCAT); + luaT_tryconcatTM(L); else if (isemptystr(s2v(top - 1))) /* second operand is empty? */ cast_void(tostring(L, s2v(top - 2))); /* result is first operand */ else if (isemptystr(s2v(top - 2))) { /* first operand is empty string? */ @@ -747,7 +750,7 @@ void luaV_finishOp (lua_State *L) { break; } case OP_CONCAT: { - StkId top = L->top - 1; /* top when 'luaT_trybinTM' was called */ + StkId top = L->top - 1; /* top when 'luaT_tryconcatTM' was called */ int a = GETARG_A(inst); /* first element to concatenate */ int total = cast_int(top - 1 - (base + a)); /* yet to concatenate */ setobjs2s(L, top - 2, top); /* put TM result in proper position */ @@ -801,7 +804,7 @@ void luaV_finishOp (lua_State *L) { setfltvalue(s2v(ra), fop(L, nb, fimm)); \ } \ else \ - Protect(luaT_trybiniTM(L, v1, imm, flip, ra, tm)); } + ProtectNT(luaT_trybiniTM(L, v1, imm, flip, ra, tm)); } /* @@ -836,7 +839,7 @@ void luaV_finishOp (lua_State *L) { setfltvalue(s2v(ra), fop(L, n1, n2)); \ } \ else \ - Protect(luaT_trybinTM(L, v1, v2, ra, tm)); } + ProtectNT(luaT_trybinTM(L, v1, v2, ra, tm)); } /* @@ -877,7 +880,7 @@ void luaV_finishOp (lua_State *L) { setfltvalue(s2v(ra), fop(L, n1, n2)); \ } \ else \ - Protect(luaT_trybinassocTM(L, v1, v2, ra, flip, tm)); } } + ProtectNT(luaT_trybinassocTM(L, v1, v2, ra, flip, tm)); } } /* @@ -891,7 +894,7 @@ void luaV_finishOp (lua_State *L) { setfltvalue(s2v(ra), fop(L, n1, n2)); \ } \ else \ - Protect(luaT_trybinTM(L, v1, v2, ra, tm)); } + ProtectNT(luaT_trybinTM(L, v1, v2, ra, tm)); } /* @@ -906,7 +909,7 @@ void luaV_finishOp (lua_State *L) { setivalue(s2v(ra), op(L, i1, i2)); \ } \ else \ - Protect(luaT_trybiniTM(L, v1, i2, TESTARG_k(i), ra, tm)); } + ProtectNT(luaT_trybiniTM(L, v1, i2, TESTARG_k(i), ra, tm)); } /* @@ -920,7 +923,7 @@ void luaV_finishOp (lua_State *L) { setivalue(s2v(ra), op(L, i1, i2)); \ } \ else \ - Protect(luaT_trybinTM(L, v1, v2, ra, tm)); } + ProtectNT(luaT_trybinTM(L, v1, v2, ra, tm)); } /* @@ -937,7 +940,7 @@ void luaV_finishOp (lua_State *L) { else if (ttisnumber(s2v(ra)) && ttisnumber(rb)) \ cond = opf(s2v(ra), rb); \ else \ - Protect(cond = other(L, s2v(ra), rb)); \ + ProtectNT(cond = other(L, s2v(ra), rb)); \ docondjump(); } @@ -956,7 +959,7 @@ void luaV_finishOp (lua_State *L) { } \ else { \ int isf = GETARG_C(i); \ - Protect(cond = luaT_callorderiTM(L, s2v(ra), im, inv, isf, tm)); \ + ProtectNT(cond = luaT_callorderiTM(L, s2v(ra), im, inv, isf, tm)); \ } \ docondjump(); } @@ -1094,7 +1097,8 @@ void luaV_execute (lua_State *L, CallInfo *ci) { vmfetch(); lua_assert(base == ci->func + 1); lua_assert(base <= L->top && L->top < L->stack + L->stacksize); - lua_assert(ci->top < L->stack + L->stacksize); + /* invalidate top for instructions not expecting it */ + lua_assert(isIT(i) || (L->top = base)); vmdispatch (GET_OPCODE(i)) { vmcase(OP_MOVE) { setobjs2s(L, ra, RB(i)); @@ -1359,7 +1363,7 @@ void luaV_execute (lua_State *L, CallInfo *ci) { if (TESTARG_k(i)) { ic = -ic; ev = TM_SHL; } - Protect(luaT_trybiniTM(L, rb, ic, 0, ra, ev)); + ProtectNT(luaT_trybiniTM(L, rb, ic, 0, ra, ev)); } vmbreak; } @@ -1371,7 +1375,7 @@ void luaV_execute (lua_State *L, CallInfo *ci) { setivalue(s2v(ra), luaV_shiftl(ic, ib)); } else - Protect(luaT_trybiniTM(L, rb, ic, 1, ra, TM_SHL)); + ProtectNT(luaT_trybiniTM(L, rb, ic, 1, ra, TM_SHL)); vmbreak; } vmcase(OP_ADD) { @@ -1422,7 +1426,7 @@ void luaV_execute (lua_State *L, CallInfo *ci) { setivalue(s2v(ra), luaV_shiftl(ib, -ic)); } else - Protect(luaT_trybinTM(L, rb, rc, ra, TM_SHR)); + ProtectNT(luaT_trybinTM(L, rb, rc, ra, TM_SHR)); vmbreak; } vmcase(OP_SHL) { @@ -1433,7 +1437,7 @@ void luaV_execute (lua_State *L, CallInfo *ci) { setivalue(s2v(ra), luaV_shiftl(ib, ic)); } else - Protect(luaT_trybinTM(L, rb, rc, ra, TM_SHL)); + ProtectNT(luaT_trybinTM(L, rb, rc, ra, TM_SHL)); vmbreak; } vmcase(OP_UNM) { @@ -1447,7 +1451,7 @@ void luaV_execute (lua_State *L, CallInfo *ci) { setfltvalue(s2v(ra), luai_numunm(L, nb)); } else - Protect(luaT_trybinTM(L, rb, rb, ra, TM_UNM)); + ProtectNT(luaT_trybinTM(L, rb, rb, ra, TM_UNM)); vmbreak; } vmcase(OP_BNOT) { @@ -1457,7 +1461,7 @@ void luaV_execute (lua_State *L, CallInfo *ci) { setivalue(s2v(ra), intop(^, ~l_castS2U(0), ib)); } else - Protect(luaT_trybinTM(L, rb, rb, ra, TM_BNOT)); + ProtectNT(luaT_trybinTM(L, rb, rb, ra, TM_BNOT)); vmbreak; } vmcase(OP_NOT) { @@ -1493,7 +1497,7 @@ void luaV_execute (lua_State *L, CallInfo *ci) { vmcase(OP_EQ) { int cond; TValue *rb = vRB(i); - Protect(cond = luaV_equalobj(L, s2v(ra), rb)); + ProtectNT(cond = luaV_equalobj(L, s2v(ra), rb)); docondjump(); vmbreak; } diff --git a/testes/api.lua b/testes/api.lua index 5da03641..0966ed19 100644 --- a/testes/api.lua +++ b/testes/api.lua @@ -241,6 +241,23 @@ assert(a == 20 and b == false) a,b = T.testC("compare LE 5 -6, return 2", a1, 2, 2, a1, 2, 20) assert(a == 20 and b == true) + +do -- testing lessthan and lessequal with metamethods + local mt = {__lt = function (a,b) return a[1] < b[1] end, + __le = function (a,b) return a[1] <= b[1] end, + __eq = function (a,b) return a[1] == b[1] end} + local function O (x) + return setmetatable({x}, mt) + end + + local a, b = T.testC("compare LT 2 3; pushint 10; return 2", O(1), O(2)) + assert(a == true and b == 10) + local a, b = T.testC("compare LE 2 3; pushint 10; return 2", O(3), O(2)) + assert(a == false and b == 10) + local a, b = T.testC("compare EQ 2 3; pushint 10; return 2", O(3), O(3)) + assert(a == true and b == 10) +end + -- testing length local t = setmetatable({x = 20}, {__len = function (t) return t.x end}) a,b,c = T.testC([[ diff --git a/testes/coroutine.lua b/testes/coroutine.lua index e04207c8..00531d8e 100644 --- a/testes/coroutine.lua +++ b/testes/coroutine.lua @@ -809,7 +809,7 @@ assert(run(function () -- tests for coroutine API if T==nil then (Message or print)('\n >>> testC not active: skipping coroutine API tests <<<\n') - return + print "OK"; return end print('testing coroutine API') diff --git a/testes/events.lua b/testes/events.lua index cf68d1e9..7fb54c9a 100644 --- a/testes/events.lua +++ b/testes/events.lua @@ -217,9 +217,16 @@ t.__le = function (a,b,c) return a<=b, "dummy" end +t.__eq = function (a,b,c) + assert(c == nil) + if type(a) == 'table' then a = a.x end + if type(b) == 'table' then b = b.x end + return a == b, "dummy" +end + function Op(x) return setmetatable({x=x}, t) end -local function test () +local function test (a, b, c) assert(not(Op(1)= Op(1)) and not(1 >= Op(2)) and (Op(2) >= 1)) assert((Op('a')>=Op('a')) and not(Op('a')>=Op('b')) and (Op('b')>=Op('a'))) assert(('a' >= Op('a')) and not(Op('a') >= 'b') and (Op('b') >= Op('a'))) + assert(Op(1) == Op(1) and Op(1) ~= Op(2)) + assert(Op('a') == Op('a') and Op('a') ~= Op('b')) + assert(a == a and a ~= b) + assert(Op(3) == c) end -test() +test(Op(1), Op(2), Op(3)) -- test `partial order' diff --git a/testes/strings.lua b/testes/strings.lua index 0e7874bf..aa039c4f 100644 --- a/testes/strings.lua +++ b/testes/strings.lua @@ -167,8 +167,11 @@ do -- tests for '%p' format local t1 = {}; local t2 = {} assert(string.format("%p", t1) ~= string.format("%p", t2)) end - assert(string.format("%p", string.rep("a", 10)) == - string.format("%p", string.rep("a", 10))) -- short strings + do -- short strings + local s1 = string.rep("a", 10) + local s2 = string.rep("a", 10) + assert(string.format("%p", s1) == string.format("%p", s2)) + end do -- long strings local s1 = string.rep("a", 300); local s2 = string.rep("a", 300) assert(string.format("%p", s1) ~= string.format("%p", s2))