diff --git a/py/objstr.c b/py/objstr.c index 3e275afb9a..e0cbea577c 100644 --- a/py/objstr.c +++ b/py/objstr.c @@ -375,33 +375,58 @@ STATIC mp_obj_t str_subscr(mp_obj_t self_in, mp_obj_t index, mp_obj_t value) { return mp_obj_new_str_of_type(type, self_data + slice.start, slice.stop - slice.start); } #endif - // TODO: Don't use mp_get_index() here - uint index_val = mp_get_index(type, unichar_charlen((const char *)self_data, self_len), index, false); if (type == &mp_type_bytes) { + uint index_val = mp_get_index(type, self_len, index, false); return MP_OBJ_NEW_SMALL_INT((mp_small_int_t)self_data[index_val]); - } else { - // Count non-continuation bytes to count characters. - // Assumes that the string is correctly formed - will run past the - // end of the buffer if there aren't that many characters in it - const char *s; - for (s=(const char *)self_data; index_val; ++s) { + } + const char *s, *top = (const char *)self_data + self_len; + machine_int_t i; + // Copied from mp_get_index; I don't want bounds checking, just give me + // the integer as-is. (I can't bounds-check without scanning the whole + // string; an out-of-bounds index will be caught in the loops below.) + if (MP_OBJ_IS_SMALL_INT(index)) { + i = MP_OBJ_SMALL_INT_VALUE(index); + } else if (!mp_obj_get_int_maybe(index, &i)) { + nlr_raise(mp_obj_new_exception_msg_varg(&mp_type_TypeError, "%s indices must be integers, not %s", qstr_str(type->name), mp_obj_get_type_str(index))); + } + if (i < 0) + { + // Negative indexing is performed by counting from the end of the string. + for (s = top - 1; i; --s) { + if (s < (const char *)self_data) { + nlr_raise(mp_obj_new_exception_msg_varg(&mp_type_IndexError, "string index out of range")); + } if (!UTF8_IS_CONT(*s)) { - --index_val; + ++i; + } + } + ++s; + } else { + // Positive indexing, correspondingly, counts from the start of the string. + // It's assumed that negative indexing will generally be used with small + // absolute values (eg str[-1], not str[-1000000]), which means it'll be + // more efficient this way. + for (s = (const char *)self_data; i; ++s) { + if (s >= top) { + nlr_raise(mp_obj_new_exception_msg_varg(&mp_type_IndexError, "string index out of range")); + } + if (!UTF8_IS_CONT(*s)) { + --i; } } // Skip continuation bytes after the last lead byte while (UTF8_IS_CONT(*s)) { ++s; } - int len = 1; - if (UTF8_IS_NONASCII(*s)) { - // Count the number of 1 bits (after the first) - for (char mask = 0x40; *s & mask; mask >>= 1) { - ++len; - } - } - return mp_obj_new_str(s, len, true); // This will create a one-character string } + int len = 1; + if (UTF8_IS_NONASCII(*s)) { + // Count the number of 1 bits (after the first) + for (char mask = 0x40; *s & mask; mask >>= 1) { + ++len; + } + } + return mp_obj_new_str(s, len, true); // This will create a one-character string } else { return MP_OBJ_NULL; // op not supported } diff --git a/tests/basics/unicode.py b/tests/basics/unicode.py index ee66679b8d..777525e06f 100644 --- a/tests/basics/unicode.py +++ b/tests/basics/unicode.py @@ -5,5 +5,5 @@ for i in range(len(s)): # Test all three forms of Unicode escape, and # all blocks of UTF-8 byte patterns s = "a\xA9\xFF\u0123\u0800\uFFEE\U0001F44C" -for i in range(len(s)): +for i in range(-len(s), len(s)): print("s[%d]: %s %X"%(i, s[i], ord(s[i])))