Better handling of size limit when resizing a table

Avoid silent conversions from int to unsigned int when calling
'luaH_resize'; avoid silent conversions from lua_Integer to int in
'table.create'; MAXASIZE corrected for the new implementation of arrays;
'luaH_resize' checks explicitly whether new size respects MAXASIZE.
(Even constructors were bypassing that check.)
This commit is contained in:
Roberto Ierusalimschy 2024-02-07 13:39:54 -03:00
parent c31d6774ac
commit 0c9bec0d38
6 changed files with 53 additions and 37 deletions

2
lapi.c
View File

@ -781,7 +781,7 @@ LUA_API int lua_rawgetp (lua_State *L, int idx, const void *p) {
}
LUA_API void lua_createtable (lua_State *L, int narray, int nrec) {
LUA_API void lua_createtable (lua_State *L, unsigned narray, unsigned nrec) {
Table *t;
lua_lock(L);
t = luaH_new(L);

View File

@ -61,18 +61,25 @@ typedef union {
/*
** MAXABITS is the largest integer such that MAXASIZE fits in an
** MAXABITS is the largest integer such that 2^MAXABITS fits in an
** unsigned int.
*/
#define MAXABITS cast_int(sizeof(int) * CHAR_BIT - 1)
/*
** MAXASIZE is the maximum size of the array part. It is the minimum
** between 2^MAXABITS and the maximum size that, measured in bytes,
** fits in a 'size_t'.
** MAXASIZEB is the maximum number of elements in the array part such
** that the size of the array fits in 'size_t'.
*/
#define MAXASIZE luaM_limitN(1u << MAXABITS, TValue)
#define MAXASIZEB ((MAX_SIZET/sizeof(ArrayCell)) * NM)
/*
** MAXASIZE is the maximum size of the array part. It is the minimum
** between 2^MAXABITS and MAXASIZEB.
*/
#define MAXASIZE \
(((1u << MAXABITS) < MAXASIZEB) ? (1u << MAXABITS) : cast_uint(MAXASIZEB))
/*
** MAXHBITS is the largest integer such that 2^MAXHBITS fits in a
@ -663,6 +670,8 @@ void luaH_resize (lua_State *L, Table *t, unsigned int newasize,
Table newt; /* to keep the new hash part */
unsigned int oldasize = setlimittosize(t);
ArrayCell *newarray;
if (newasize > MAXASIZE)
luaG_runerror(L, "table overflow");
/* create new hash part with appropriate size into 'newt' */
newt.flags = 0;
setnodevector(L, &newt, nhsize);

View File

@ -59,8 +59,10 @@ static void checktab (lua_State *L, int arg, int what) {
static int tcreate (lua_State *L) {
int sizeseq = (int)luaL_checkinteger(L, 1);
int sizerest = (int)luaL_optinteger(L, 2, 0);
lua_Unsigned sizeseq = (lua_Unsigned)luaL_checkinteger(L, 1);
lua_Unsigned sizerest = (lua_Unsigned)luaL_optinteger(L, 2, 0);
luaL_argcheck(L, sizeseq <= UINT_MAX, 1, "out of range");
luaL_argcheck(L, sizerest <= UINT_MAX, 2, "out of range");
lua_createtable(L, sizeseq, sizerest);
return 1;
}

2
lua.h
View File

@ -268,7 +268,7 @@ LUA_API int (lua_rawget) (lua_State *L, int idx);
LUA_API int (lua_rawgeti) (lua_State *L, int idx, lua_Integer n);
LUA_API int (lua_rawgetp) (lua_State *L, int idx, const void *p);
LUA_API void (lua_createtable) (lua_State *L, int narr, int nrec);
LUA_API void (lua_createtable) (lua_State *L, unsigned narr, unsigned nrec);
LUA_API void *(lua_newuserdatauv) (lua_State *L, size_t sz, int nuvalue);
LUA_API int (lua_getmetatable) (lua_State *L, int objindex);
LUA_API int (lua_getiuservalue) (lua_State *L, int idx, int n);

View File

@ -3234,7 +3234,7 @@ Values at other positions are not affected.
}
@APIEntry{void lua_createtable (lua_State *L, int nseq, int nrec);|
@APIEntry{void lua_createtable (lua_State *L, unsigned nseq, unsigned nrec);|
@apii{0,1,m}
Creates a new empty table and pushes it onto the stack.

View File

@ -3,33 +3,6 @@
print "testing (parts of) table library"
do print "testing 'table.create'"
collectgarbage()
local m = collectgarbage("count") * 1024
local t = table.create(10000)
local memdiff = collectgarbage("count") * 1024 - m
assert(memdiff > 10000 * 4)
for i = 1, 20 do
assert(#t == i - 1)
t[i] = 0
end
for i = 1, 20 do t[#t + 1] = i * 10 end
assert(#t == 40 and t[39] == 190)
assert(not T or T.querytab(t) == 10000)
t = nil
collectgarbage()
m = collectgarbage("count") * 1024
t = table.create(0, 1024)
memdiff = collectgarbage("count") * 1024 - m
assert(memdiff > 1024 * 12)
assert(not T or select(2, T.querytab(t)) == 1024)
end
print "testing unpack"
local unpack = table.unpack
local maxI = math.maxinteger
local minI = math.mininteger
@ -40,6 +13,38 @@ local function checkerror (msg, f, ...)
end
do print "testing 'table.create'"
local N = 10000
collectgarbage()
local m = collectgarbage("count") * 1024
local t = table.create(N)
local memdiff = collectgarbage("count") * 1024 - m
assert(memdiff > N * 4)
for i = 1, 20 do
assert(#t == i - 1)
t[i] = 0
end
for i = 1, 20 do t[#t + 1] = i * 10 end
assert(#t == 40 and t[39] == 190)
assert(not T or T.querytab(t) == N)
t = nil
collectgarbage()
m = collectgarbage("count") * 1024
t = table.create(0, 1024)
memdiff = collectgarbage("count") * 1024 - m
assert(memdiff > 1024 * 12)
assert(not T or select(2, T.querytab(t)) == 1024)
checkerror("table overflow", table.create, (1<<31) + 1)
checkerror("table overflow", table.create, 0, (1<<31) + 1)
end
print "testing unpack"
local unpack = table.unpack
checkerror("wrong number of arguments", table.insert, {}, 2, 3, 4)
local x,y,z,a,n