diff --git a/py/objstr.c b/py/objstr.c index 33bfcc3756..7549dedb7b 100644 --- a/py/objstr.c +++ b/py/objstr.c @@ -357,7 +357,8 @@ STATIC mp_obj_t str_subscr(mp_obj_t self_in, mp_obj_t index, mp_obj_t value) { } STATIC mp_obj_t str_join(mp_obj_t self_in, mp_obj_t arg) { - assert(MP_OBJ_IS_STR(self_in)); + assert(is_str_or_bytes(self_in)); + const mp_obj_type_t *self_type = mp_obj_get_type(self_in); // get separation string GET_STR_DATA_LEN(self_in, sep_str, sep_len); @@ -379,8 +380,9 @@ STATIC mp_obj_t str_join(mp_obj_t self_in, mp_obj_t arg) { // count required length int required_len = 0; for (int i = 0; i < seq_len; i++) { - if (!MP_OBJ_IS_STR(seq_items[i])) { - nlr_raise(mp_obj_new_exception_msg(&mp_type_TypeError, "join expected a list of str's")); + if (mp_obj_get_type(seq_items[i]) != self_type) { + nlr_raise(mp_obj_new_exception_msg(&mp_type_TypeError, + "join expects a list of str/bytes objects consistent with self object")); } if (i > 0) { required_len += sep_len; @@ -391,7 +393,7 @@ STATIC mp_obj_t str_join(mp_obj_t self_in, mp_obj_t arg) { // make joined string byte *data; - mp_obj_t joined_str = mp_obj_str_builder_start(mp_obj_get_type(self_in), required_len, &data); + mp_obj_t joined_str = mp_obj_str_builder_start(self_type, required_len, &data); for (int i = 0; i < seq_len; i++) { if (i > 0) { memcpy(data, sep_str, sep_len); diff --git a/tests/basics/string-join.py b/tests/basics/string-join.py index 275a804c64..49bbfc5ca0 100644 --- a/tests/basics/string-join.py +++ b/tests/basics/string-join.py @@ -10,3 +10,15 @@ print(''.join('')) print(''.join('abc')) print(','.join('abc')) print(','.join('abc' for i in range(5))) + +print(b','.join([b'abc', b'123'])) + +try: + print(b','.join(['abc', b'123'])) +except TypeError: + print("TypeError") + +try: + print(','.join([b'abc', b'123'])) +except TypeError: + print("TypeError")