'lua_upvalueid' returns NULL on invalid upvalue index

This commit is contained in:
Roberto Ierusalimschy 2020-10-12 14:51:28 -03:00
parent 9a89fb1c9d
commit 30528049f1
4 changed files with 33 additions and 15 deletions

3
.gitignore vendored
View File

@ -10,3 +10,6 @@ testes/time.txt
testes/time-debug.txt testes/time-debug.txt
testes/libs/all testes/libs/all
temp
lua

15
lapi.c
View File

@ -1383,13 +1383,16 @@ LUA_API const char *lua_setupvalue (lua_State *L, int funcindex, int n) {
static UpVal **getupvalref (lua_State *L, int fidx, int n, LClosure **pf) { static UpVal **getupvalref (lua_State *L, int fidx, int n, LClosure **pf) {
static const UpVal *const nullup = NULL;
LClosure *f; LClosure *f;
TValue *fi = index2value(L, fidx); TValue *fi = index2value(L, fidx);
api_check(L, ttisLclosure(fi), "Lua function expected"); api_check(L, ttisLclosure(fi), "Lua function expected");
f = clLvalue(fi); f = clLvalue(fi);
api_check(L, (1 <= n && n <= f->p->sizeupvalues), "invalid upvalue index");
if (pf) *pf = f; if (pf) *pf = f;
if (1 <= n && n <= f->p->sizeupvalues)
return &f->upvals[n - 1]; /* get its upvalue pointer */ return &f->upvals[n - 1]; /* get its upvalue pointer */
else
return (UpVal**)&nullup;
} }
@ -1401,11 +1404,14 @@ LUA_API void *lua_upvalueid (lua_State *L, int fidx, int n) {
} }
case LUA_VCCL: { /* C closure */ case LUA_VCCL: { /* C closure */
CClosure *f = clCvalue(fi); CClosure *f = clCvalue(fi);
api_check(L, 1 <= n && n <= f->nupvalues, "invalid upvalue index"); if (1 <= n && n <= f->nupvalues)
return &f->upvalue[n - 1]; return &f->upvalue[n - 1];
} /* else */
} /* FALLTHROUGH */
case LUA_VLCF:
return NULL; /* light C functions have no upvalues */
default: { default: {
api_check(L, 0, "closure expected"); api_check(L, 0, "function expected");
return NULL; return NULL;
} }
} }
@ -1417,6 +1423,7 @@ LUA_API void lua_upvaluejoin (lua_State *L, int fidx1, int n1,
LClosure *f1; LClosure *f1;
UpVal **up1 = getupvalref(L, fidx1, n1, &f1); UpVal **up1 = getupvalref(L, fidx1, n1, &f1);
UpVal **up2 = getupvalref(L, fidx2, n2, NULL); UpVal **up2 = getupvalref(L, fidx2, n2, NULL);
api_check(L, *up1 != NULL && *up2 != NULL, "invalid upvalue index");
*up1 = *up2; *up1 = *up2;
luaC_objbarrier(L, f1, *up1); luaC_objbarrier(L, f1, *up1);
} }

View File

@ -281,25 +281,33 @@ static int db_setupvalue (lua_State *L) {
** Check whether a given upvalue from a given closure exists and ** Check whether a given upvalue from a given closure exists and
** returns its index ** returns its index
*/ */
static int checkupval (lua_State *L, int argf, int argnup) { static void *checkupval (lua_State *L, int argf, int argnup, int *pnup) {
void *id;
int nup = (int)luaL_checkinteger(L, argnup); /* upvalue index */ int nup = (int)luaL_checkinteger(L, argnup); /* upvalue index */
luaL_checktype(L, argf, LUA_TFUNCTION); /* closure */ luaL_checktype(L, argf, LUA_TFUNCTION); /* closure */
luaL_argcheck(L, (lua_getupvalue(L, argf, nup) != NULL), argnup, id = lua_upvalueid(L, argf, nup);
"invalid upvalue index"); if (pnup) {
return nup; luaL_argcheck(L, id != NULL, argnup, "invalid upvalue index");
*pnup = nup;
}
return id;
} }
static int db_upvalueid (lua_State *L) { static int db_upvalueid (lua_State *L) {
int n = checkupval(L, 1, 2); void *id = checkupval(L, 1, 2, NULL);
lua_pushlightuserdata(L, lua_upvalueid(L, 1, n)); if (id != NULL)
lua_pushlightuserdata(L, id);
else
luaL_pushfail(L);
return 1; return 1;
} }
static int db_upvaluejoin (lua_State *L) { static int db_upvaluejoin (lua_State *L) {
int n1 = checkupval(L, 1, 2); int n1, n2;
int n2 = checkupval(L, 3, 4); checkupval(L, 1, 2, &n1);
checkupval(L, 3, 4, &n2);
luaL_argcheck(L, !lua_iscfunction(L, 1), 1, "Lua function expected"); luaL_argcheck(L, !lua_iscfunction(L, 1), 1, "Lua function expected");
luaL_argcheck(L, !lua_iscfunction(L, 3), 3, "Lua function expected"); luaL_argcheck(L, !lua_iscfunction(L, 3), 3, "Lua function expected");
lua_upvaluejoin(L, 1, n1, 3, n2); lua_upvaluejoin(L, 1, n1, 3, n2);

View File

@ -242,7 +242,7 @@ end
assert(debug.upvalueid(foo1, 1)) assert(debug.upvalueid(foo1, 1))
assert(debug.upvalueid(foo1, 2)) assert(debug.upvalueid(foo1, 2))
assert(not pcall(debug.upvalueid, foo1, 3)) assert(not debug.upvalueid(foo1, 3))
assert(debug.upvalueid(foo1, 1) == debug.upvalueid(foo2, 2)) assert(debug.upvalueid(foo1, 1) == debug.upvalueid(foo2, 2))
assert(debug.upvalueid(foo1, 2) == debug.upvalueid(foo2, 1)) assert(debug.upvalueid(foo1, 2) == debug.upvalueid(foo2, 1))
assert(debug.upvalueid(foo3, 1)) assert(debug.upvalueid(foo3, 1))