diff --git a/py/modstruct.c b/py/modstruct.c index eabc951aef..2016add17e 100644 --- a/py/modstruct.c +++ b/py/modstruct.c @@ -103,30 +103,24 @@ STATIC mp_obj_t struct_calcsize(mp_obj_t fmt_in) { char fmt_type = get_fmt_type(&fmt); mp_uint_t size; for (size = 0; *fmt; fmt++) { - mp_uint_t align = 1; mp_uint_t cnt = 1; if (unichar_isdigit(*fmt)) { cnt = get_fmt_num(&fmt); } - mp_uint_t sz = 0; if (*fmt == 's') { - sz = cnt; - cnt = 1; - } - - while (cnt--) { - // If we already have size for 's' case, don't set it again - if (sz == 0) { - sz = (mp_uint_t)mp_binary_get_size(fmt_type, *fmt, &align); - } + size += cnt; + } else { + mp_uint_t align; + size_t sz = mp_binary_get_size(fmt_type, *fmt, &align); if (sz == 0) { nlr_raise(mp_obj_new_exception_msg(&mp_type_ValueError, "unsupported format")); } - // Apply alignment - size = (size + align - 1) & ~(align - 1); - size += sz; - sz = 0; + while (cnt--) { + // Apply alignment + size = (size + align - 1) & ~(align - 1); + size += sz; + } } } return MP_OBJ_NEW_SMALL_INT(size); diff --git a/tests/basics/struct2.py b/tests/basics/struct2.py new file mode 100644 index 0000000000..f438bb55d2 --- /dev/null +++ b/tests/basics/struct2.py @@ -0,0 +1,28 @@ +# test ustruct with a count specified before the type + +try: + import ustruct as struct +except: + import struct + +print(struct.calcsize('0s')) +print(struct.unpack('0s', b'')) +print(struct.pack('0s', b'123')) + +print(struct.calcsize('2s')) +print(struct.unpack('2s', b'12')) +print(struct.pack('2s', b'123')) + +print(struct.calcsize('2H')) +print(struct.unpack('<2H', b'1234')) +print(struct.pack('<2H', 258, 515)) + +print(struct.calcsize('0s1s0H2H')) +print(struct.unpack('<0s1s0H2H', b'01234')) +print(struct.pack('<0s1s0H2H', b'abc', b'abc', 258, 515)) + +# check that zero of an unknown type raises an exception +try: + struct.calcsize('0z') +except: + print('Exception')