From 686afc5c0aaf2bc5a8d2547b703ab3177e0ea569 Mon Sep 17 00:00:00 2001 From: Damien George Date: Fri, 11 Apr 2014 09:13:30 +0100 Subject: [PATCH] py: Check that sequence has 2 elements for dict iterable constructor. --- py/obj.c | 2 +- py/objdict.c | 6 ++++-- tests/basics/dict-from-iter.py | 10 ++++++++++ 3 files changed, 15 insertions(+), 3 deletions(-) diff --git a/py/obj.c b/py/obj.c index 844ec41216..623b396422 100644 --- a/py/obj.c +++ b/py/obj.c @@ -280,7 +280,7 @@ void mp_obj_get_array_fixed_n(mp_obj_t o, uint len, mp_obj_t **items) { mp_obj_list_get(o, &seq_len, items); } if (seq_len != len) { - nlr_raise(mp_obj_new_exception_msg_varg(&mp_type_IndexError, "requested length %d but object has length %d", len, seq_len)); + nlr_raise(mp_obj_new_exception_msg_varg(&mp_type_ValueError, "requested length %d but object has length %d", len, seq_len)); } } else { nlr_raise(mp_obj_new_exception_msg_varg(&mp_type_TypeError, "object '%s' is not a tuple or list", mp_obj_get_type_str(o))); diff --git a/py/objdict.c b/py/objdict.c index 4dffa53da9..963e188074 100644 --- a/py/objdict.c +++ b/py/objdict.c @@ -50,9 +50,11 @@ STATIC mp_obj_t dict_make_new(mp_obj_t type_in, uint n_args, uint n_kw, const mp mp_obj_t iterable = mp_getiter(args[0]); mp_obj_t dict = mp_obj_new_dict(0); // TODO: support arbitrary seq as a pair - mp_obj_tuple_t *item; + mp_obj_t item; while ((item = mp_iternext(iterable)) != MP_OBJ_NULL) { - mp_obj_dict_store(dict, item->items[0], item->items[1]); + mp_obj_t *sub_items; + mp_obj_get_array_fixed_n(item, 2, &sub_items); + mp_obj_dict_store(dict, sub_items[0], sub_items[1]); } return dict; } diff --git a/tests/basics/dict-from-iter.py b/tests/basics/dict-from-iter.py index 8215969224..dc76801ff6 100644 --- a/tests/basics/dict-from-iter.py +++ b/tests/basics/dict-from-iter.py @@ -2,3 +2,13 @@ print(dict([(1, "foo")])) d = dict([("foo", "foo2"), ("bar", "baz")]) print(sorted(d.keys())) print(sorted(d.values())) + +try: + dict(((1,),)) +except ValueError: + print("ValueError") + +try: + dict(((1, 2, 3),)) +except ValueError: + print("ValueError")