diff --git a/py/objlist.c b/py/objlist.c index ade062e07c..0ff3b1d53d 100644 --- a/py/objlist.c +++ b/py/objlist.c @@ -31,6 +31,7 @@ #include "py/objlist.h" #include "py/runtime0.h" #include "py/runtime.h" +#include "py/stackctrl.h" STATIC mp_obj_t mp_obj_new_list_iterator(mp_obj_list_t *list, mp_uint_t cur); STATIC mp_obj_list_t *list_new(mp_uint_t n); @@ -284,16 +285,15 @@ STATIC mp_obj_t list_pop(mp_uint_t n_args, const mp_obj_t *args) { return ret; } -// TODO make this conform to CPython's definition of sort -STATIC void mp_quicksort(mp_obj_t *head, mp_obj_t *tail, mp_obj_t key_fn, bool reversed) { - mp_uint_t op = reversed ? MP_BINARY_OP_MORE : MP_BINARY_OP_LESS; +STATIC void mp_quicksort(mp_obj_t *head, mp_obj_t *tail, mp_obj_t key_fn, mp_obj_t binop_less_result) { + MP_STACK_CHECK(); while (head < tail) { mp_obj_t *h = head - 1; mp_obj_t *t = tail; - mp_obj_t v = key_fn == NULL ? tail[0] : mp_call_function_1(key_fn, tail[0]); // get pivot using key_fn + mp_obj_t v = key_fn == MP_OBJ_NULL ? tail[0] : mp_call_function_1(key_fn, tail[0]); // get pivot using key_fn for (;;) { - do ++h; while (mp_binary_op(op, key_fn == NULL ? h[0] : mp_call_function_1(key_fn, h[0]), v) == mp_const_true); - do --t; while (h < t && mp_binary_op(op, v, key_fn == NULL ? t[0] : mp_call_function_1(key_fn, t[0])) == mp_const_true); + do ++h; while (h < t && mp_binary_op(MP_BINARY_OP_LESS, key_fn == MP_OBJ_NULL ? h[0] : mp_call_function_1(key_fn, h[0]), v) == binop_less_result); + do --t; while (h < t && mp_binary_op(MP_BINARY_OP_LESS, v, key_fn == MP_OBJ_NULL ? t[0] : mp_call_function_1(key_fn, t[0])) == binop_less_result); if (h >= t) break; mp_obj_t x = h[0]; h[0] = t[0]; @@ -302,27 +302,38 @@ STATIC void mp_quicksort(mp_obj_t *head, mp_obj_t *tail, mp_obj_t key_fn, bool r mp_obj_t x = h[0]; h[0] = tail[0]; tail[0] = x; - mp_quicksort(head, t, key_fn, reversed); - head = h + 1; + // do the smaller recursive call first, to keep stack within O(log(N)) + if (t - head < tail - h - 1) { + mp_quicksort(head, t, key_fn, binop_less_result); + head = h + 1; + } else { + mp_quicksort(h + 1, tail, key_fn, binop_less_result); + tail = t; + } } } -mp_obj_t mp_obj_list_sort(mp_uint_t n_args, const mp_obj_t *args, mp_map_t *kwargs) { - assert(n_args >= 1); - assert(MP_OBJ_IS_TYPE(args[0], &mp_type_list)); - if (n_args > 1) { - nlr_raise(mp_obj_new_exception_msg(&mp_type_TypeError, - "list.sort takes no positional arguments")); - } - mp_obj_list_t *self = args[0]; +// TODO Python defines sort to be stable but ours is not +mp_obj_t mp_obj_list_sort(mp_uint_t n_args, const mp_obj_t *pos_args, mp_map_t *kw_args) { + static const mp_arg_t allowed_args[] = { + { MP_QSTR_key, MP_ARG_KW_ONLY | MP_ARG_OBJ, {.u_obj = mp_const_none} }, + { MP_QSTR_reverse, MP_ARG_KW_ONLY | MP_ARG_BOOL, {.u_bool = false} }, + }; + + // parse args + mp_arg_val_t args[MP_ARRAY_SIZE(allowed_args)]; + mp_arg_parse_all(n_args - 1, pos_args + 1, kw_args, MP_ARRAY_SIZE(allowed_args), allowed_args, args); + + mp_obj_list_t *self = pos_args[0]; + assert(MP_OBJ_IS_TYPE(self, &mp_type_list)); + if (self->len > 1) { - mp_map_elem_t *keyfun = mp_map_lookup(kwargs, MP_OBJ_NEW_QSTR(MP_QSTR_key), MP_MAP_LOOKUP); - mp_map_elem_t *reverse = mp_map_lookup(kwargs, MP_OBJ_NEW_QSTR(MP_QSTR_reverse), MP_MAP_LOOKUP); mp_quicksort(self->items, self->items + self->len - 1, - keyfun ? keyfun->value : NULL, - reverse && reverse->value ? mp_obj_is_true(reverse->value) : false); + args[0].u_obj == mp_const_none ? MP_OBJ_NULL : args[0].u_obj, + args[1].u_bool ? mp_const_false : mp_const_true); } - return mp_const_none; // return None, as per CPython + + return mp_const_none; } STATIC mp_obj_t list_clear(mp_obj_t self_in) { @@ -412,7 +423,7 @@ STATIC MP_DEFINE_CONST_FUN_OBJ_3(list_insert_obj, list_insert); STATIC MP_DEFINE_CONST_FUN_OBJ_VAR_BETWEEN(list_pop_obj, 1, 2, list_pop); STATIC MP_DEFINE_CONST_FUN_OBJ_2(list_remove_obj, list_remove); STATIC MP_DEFINE_CONST_FUN_OBJ_1(list_reverse_obj, list_reverse); -STATIC MP_DEFINE_CONST_FUN_OBJ_KW(list_sort_obj, 0, mp_obj_list_sort); +STATIC MP_DEFINE_CONST_FUN_OBJ_KW(list_sort_obj, 1, mp_obj_list_sort); STATIC const mp_map_elem_t list_locals_dict_table[] = { { MP_OBJ_NEW_QSTR(MP_QSTR_append), (mp_obj_t)&list_append_obj }, diff --git a/tests/basics/list_sort.py b/tests/basics/list_sort.py index e323ff1c2c..c185ddcd16 100644 --- a/tests/basics/list_sort.py +++ b/tests/basics/list_sort.py @@ -26,3 +26,24 @@ l.sort(reverse=False) print(l) print(l == sorted(l, reverse=False)) +# test large lists (should not stack overflow) +l = list(range(2000)) +l.sort() +print(l[0], l[-1]) +l.sort(reverse=True) +print(l[0], l[-1]) + +# test user-defined ordering +class A: + def __init__(self, x): + self.x = x + def __lt__(self, other): + return self.x > other.x + def __repr__(self): + return str(self.x) +l = [A(5), A(2), A(1), A(3), A(4)] +print(l) +l.sort() +print(l) +l.sort(reverse=True) +print(l)