in 'table.sort': tighter checks for invalid order function +

"random" pivot for larger intervals (to avoid attacks with
bad data)
This commit is contained in:
Roberto Ierusalimschy 2015-11-12 16:07:25 -02:00
parent 330d426ffd
commit bde03eeb48

View File

@ -1,5 +1,5 @@
/*
** $Id: ltablib.c,v 1.83 2015/09/17 15:53:50 roberto Exp roberto $
** $Id: ltablib.c,v 1.84 2015/11/06 16:07:14 roberto Exp roberto $
** Library for Table Manipulation
** See Copyright Notice in lua.h
*/
@ -233,7 +233,6 @@ static int unpack (lua_State *L) {
** =======================================================
*/
static void set2 (lua_State *L, int i, int j) {
lua_seti(L, 1, i);
lua_seti(L, 1, j);
@ -269,14 +268,14 @@ static int partition (lua_State *L, int lo, int up) {
for (;;) {
/* next loop: repeat ++i while a[i] < P */
while (lua_geti(L, 1, ++i), sort_comp(L, -1, -2)) {
if (i >= up)
if (i == up - 1) /* a[i] < P but a[up - 1] == P ?? */
luaL_error(L, "invalid order function for sorting");
lua_pop(L, 1); /* remove a[i] */
}
/* after the loop, a[i] >= P and a[lo .. i - 1] < P */
/* next loop: repeat --j while P < a[j] */
while (lua_geti(L, 1, --j), sort_comp(L, -3, -1)) {
if (j < lo)
if (j < i) /* j < i but a[j] > P ?? */
luaL_error(L, "invalid order function for sorting");
lua_pop(L, 1); /* remove a[j] */
}
@ -294,6 +293,20 @@ static int partition (lua_State *L, int lo, int up) {
}
/*
** Choose a "random" pivot in the middle part of the interval [lo, up].
** Use 'time' and 'clock' as sources of "randomness".
*/
static int choosePivot (int lo, int up) {
unsigned int t = (unsigned int)(unsigned long)time(NULL); /* time */
unsigned int c = (unsigned int)(unsigned long)clock(); /* clock */
unsigned int r4 = (unsigned int)(up - lo) / 4u; /* range/4 */
unsigned int p = (c + t) % (r4 * 2) + (lo + r4);
lua_assert(lo + r4 <= p && p <= up - r4);
return (int)p;
}
static void auxsort (lua_State *L, int lo, int up) {
while (lo < up) { /* loop for tail recursion */
int p;
@ -306,7 +319,10 @@ static void auxsort (lua_State *L, int lo, int up) {
lua_pop(L, 2); /* remove both values */
if (up - lo == 1) /* only 2 elements? */
return; /* already sorted */
p = (lo + up)/2;
if (up - lo < 100) /* small interval? */
p = (lo + up)/2; /* middle element is a good pivot */
else /* for larger intervals, it is worth a random pivot */
p = choosePivot(lo, up);
lua_geti(L, 1, p);
lua_geti(L, 1, lo);
if (sort_comp(L, -2, -1)) /* a[p] < a[lo]? */
@ -338,6 +354,7 @@ static void auxsort (lua_State *L, int lo, int up) {
} /* tail call auxsort(L, lo, up) */
}
static int sort (lua_State *L) {
int n = (int)aux_getn(L, 1, TAB_RW);
luaL_checkstack(L, 50, ""); /* assume array is smaller than 2^50 */