diff --git a/extmod/modwebsocket.c b/extmod/modwebsocket.c index 6cd4f515df..9b0c19d6d5 100644 --- a/extmod/modwebsocket.c +++ b/extmod/modwebsocket.c @@ -37,6 +37,7 @@ #if MICROPY_PY_WEBSOCKET enum { FRAME_HEADER, FRAME_OPT, PAYLOAD }; +enum { BLOCKING_WRITE = 1 }; typedef struct _mp_obj_websocket_t { mp_obj_base_t base; @@ -48,10 +49,11 @@ typedef struct _mp_obj_websocket_t { byte mask_pos; byte buf_pos; byte buf[6]; + byte opts; } mp_obj_websocket_t; STATIC mp_obj_t websocket_make_new(const mp_obj_type_t *type, size_t n_args, size_t n_kw, const mp_obj_t *args) { - mp_arg_check_num(n_args, n_kw, 1, 1, false); + mp_arg_check_num(n_args, n_kw, 1, 2, false); mp_obj_websocket_t *o = m_new_obj(mp_obj_websocket_t); o->base.type = type; o->sock = args[0]; @@ -59,6 +61,10 @@ STATIC mp_obj_t websocket_make_new(const mp_obj_type_t *type, size_t n_args, siz o->to_recv = 2; o->mask_pos = 0; o->buf_pos = 0; + o->opts = 0; + if (n_args > 1 && args[1] == mp_const_true) { + o->opts |= BLOCKING_WRITE; + } return o; } @@ -157,11 +163,24 @@ STATIC mp_uint_t websocket_write(mp_obj_t self_in, const void *buf, mp_uint_t si assert(size < 126); byte header[] = {0x81, size}; - mp_uint_t out_sz = mp_stream_writeall(self->sock, header, sizeof(header), errcode); - if (out_sz == MP_STREAM_ERROR) { - return MP_STREAM_ERROR; + mp_obj_t dest[3]; + if (self->opts & BLOCKING_WRITE) { + mp_load_method(self->sock, MP_QSTR_setblocking, dest); + dest[2] = mp_const_true; + mp_call_method_n_kw(1, 0, dest); } - return mp_stream_writeall(self->sock, buf, size, errcode); + + mp_uint_t out_sz = mp_stream_writeall(self->sock, header, sizeof(header), errcode); + if (out_sz != MP_STREAM_ERROR) { + out_sz = mp_stream_writeall(self->sock, buf, size, errcode); + } + + if (self->opts & BLOCKING_WRITE) { + dest[2] = mp_const_false; + mp_call_method_n_kw(1, 0, dest); + } + + return out_sz; } STATIC const mp_map_elem_t websocket_locals_dict_table[] = {