diff --git a/extmod/modwebrepl.c b/extmod/modwebrepl.c index c160abea2f..e9e60b7b41 100644 --- a/extmod/modwebrepl.c +++ b/extmod/modwebrepl.c @@ -250,8 +250,8 @@ STATIC mp_uint_t _webrepl_read(mp_obj_t self_in, void *buf, mp_uint_t size, int DEBUG_printf("webrepl: Writing %lu bytes to file\n", buf_sz); int err; - mp_uint_t res = mp_stream_writeall(self->cur_file, filebuf, buf_sz, &err); - if(res == MP_STREAM_ERROR) { + mp_uint_t res = mp_stream_write_exactly(self->cur_file, filebuf, buf_sz, &err); + if (err != 0 || res != buf_sz) { assert(0); } diff --git a/extmod/modwebsocket.c b/extmod/modwebsocket.c index 344933ded3..949a5fed91 100644 --- a/extmod/modwebsocket.c +++ b/extmod/modwebsocket.c @@ -240,9 +240,9 @@ STATIC mp_uint_t websocket_write(mp_obj_t self_in, const void *buf, mp_uint_t si mp_call_method_n_kw(1, 0, dest); } - mp_uint_t out_sz = mp_stream_writeall(self->sock, header, hdr_sz, errcode); - if (out_sz != MP_STREAM_ERROR) { - out_sz = mp_stream_writeall(self->sock, buf, size, errcode); + mp_uint_t out_sz = mp_stream_write_exactly(self->sock, header, hdr_sz, errcode); + if (*errcode == 0) { + out_sz = mp_stream_write_exactly(self->sock, buf, size, errcode); } if (self->opts & BLOCKING_WRITE) { @@ -250,6 +250,9 @@ STATIC mp_uint_t websocket_write(mp_obj_t self_in, const void *buf, mp_uint_t si mp_call_method_n_kw(1, 0, dest); } + if (*errcode != 0) { + return MP_STREAM_ERROR; + } return out_sz; } diff --git a/py/modio.c b/py/modio.c index 423315081f..2fbe6bc1e1 100644 --- a/py/modio.c +++ b/py/modio.c @@ -78,10 +78,13 @@ STATIC mp_uint_t bufwriter_write(mp_obj_t self_in, const void *buf, mp_uint_t si memcpy(self->buf + self->len, buf, rem); buf = (byte*)buf + rem; size -= rem; - mp_uint_t out_sz = mp_stream_writeall(self->stream, self->buf, self->alloc, errcode); - if (out_sz == MP_STREAM_ERROR) { + mp_uint_t out_sz = mp_stream_write_exactly(self->stream, self->buf, self->alloc, errcode); + if (*errcode != 0) { return MP_STREAM_ERROR; } + // TODO: try to recover from a case of non-blocking stream, e.g. move + // remaining chunk to the beginning of buffer. + assert(out_sz == self->alloc); self->len = 0; } @@ -93,9 +96,12 @@ STATIC mp_obj_t bufwriter_flush(mp_obj_t self_in) { if (self->len != 0) { int err; - mp_uint_t out_sz = mp_stream_writeall(self->stream, self->buf, self->len, &err); + mp_uint_t out_sz = mp_stream_write_exactly(self->stream, self->buf, self->len, &err); + // TODO: try to recover from a case of non-blocking stream, e.g. move + // remaining chunk to the beginning of buffer. + assert(out_sz == self->len); self->len = 0; - if (out_sz == MP_STREAM_ERROR) { + if (err != 0) { nlr_raise(mp_obj_new_exception_arg1(&mp_type_OSError, MP_OBJ_NEW_SMALL_INT(err))); } } diff --git a/py/stream.c b/py/stream.c index a3df1b8fdd..9b1d5fd2de 100644 --- a/py/stream.c +++ b/py/stream.c @@ -49,6 +49,48 @@ STATIC mp_obj_t stream_readall(mp_obj_t self_in); #define STREAM_CONTENT_TYPE(stream) (((stream)->is_text) ? &mp_type_str : &mp_type_bytes) +// Returns error condition in *errcode, if non-zero, return value is number of bytes written +// before error condition occured. If *errcode == 0, returns total bytes written (which will +// be equal to input size). +mp_uint_t mp_stream_rw(mp_obj_t stream, void *buf_, mp_uint_t size, int *errcode, byte flags) { + byte *buf = buf_; + mp_obj_base_t* s = (mp_obj_base_t*)MP_OBJ_TO_PTR(stream); + typedef mp_uint_t (*io_func_t)(mp_obj_t obj, void *buf, mp_uint_t size, int *errcode); + io_func_t io_func; + if (flags & MP_STREAM_RW_WRITE) { + io_func = (io_func_t)s->type->stream_p->write; + } else { + io_func = s->type->stream_p->read; + } + + *errcode = 0; + mp_uint_t done = 0; + while (size > 0) { + mp_uint_t out_sz = io_func(stream, buf, size, errcode); + // For read, out_sz == 0 means EOF. For write, it's unspecified + // what it means, but we don't make any progress, so returning + // is still the best option. + if (out_sz == 0) { + return done; + } + if (out_sz == MP_STREAM_ERROR) { + // If we read something before getting EAGAIN, don't leak it + if (mp_is_nonblocking_error(*errcode) && done != 0) { + *errcode = 0; + } + return done; + } + if (flags & MP_STREAM_RW_ONCE) { + return out_sz; + } + + buf += out_sz; + size -= out_sz; + done += out_sz; + } + return done; +} + const mp_stream_p_t *mp_get_stream_raise(mp_obj_t self_in, int flags) { mp_obj_base_t *o = (mp_obj_base_t*)MP_OBJ_TO_PTR(self_in); const mp_stream_p_t *stream_p = o->type->stream_p; @@ -62,7 +104,7 @@ const mp_stream_p_t *mp_get_stream_raise(mp_obj_t self_in, int flags) { return stream_p; } -STATIC mp_obj_t stream_read(size_t n_args, const mp_obj_t *args) { +STATIC mp_obj_t stream_read_generic(size_t n_args, const mp_obj_t *args, byte flags) { const mp_stream_p_t *stream_p = mp_get_stream_raise(args[0], MP_STREAM_OP_READ); // What to do if sz < -1? Python docs don't specify this case. @@ -94,8 +136,8 @@ STATIC mp_obj_t stream_read(size_t n_args, const mp_obj_t *args) { nlr_raise(mp_obj_new_exception_msg_varg(&mp_type_MemoryError, "out of memory")); } int error; - mp_uint_t out_sz = stream_p->read(args[0], p, more_bytes, &error); - if (out_sz == MP_STREAM_ERROR) { + mp_uint_t out_sz = mp_stream_read_exactly(args[0], p, more_bytes, &error); + if (error != 0) { vstr_cut_tail_bytes(&vstr, more_bytes); if (mp_is_nonblocking_error(error)) { // With non-blocking streams, we read as much as we can. @@ -165,8 +207,8 @@ STATIC mp_obj_t stream_read(size_t n_args, const mp_obj_t *args) { vstr_t vstr; vstr_init_len(&vstr, sz); int error; - mp_uint_t out_sz = stream_p->read(args[0], vstr.buf, sz, &error); - if (out_sz == MP_STREAM_ERROR) { + mp_uint_t out_sz = mp_stream_rw(args[0], vstr.buf, sz, &error, flags); + if (error != 0) { vstr_clear(&vstr); if (mp_is_nonblocking_error(error)) { // https://docs.python.org/3.4/library/io.html#io.RawIOBase.read @@ -182,20 +224,27 @@ STATIC mp_obj_t stream_read(size_t n_args, const mp_obj_t *args) { return mp_obj_new_str_from_vstr(STREAM_CONTENT_TYPE(stream_p), &vstr); } } + +STATIC mp_obj_t stream_read(size_t n_args, const mp_obj_t *args) { + return stream_read_generic(n_args, args, MP_STREAM_RW_READ); +} MP_DEFINE_CONST_FUN_OBJ_VAR_BETWEEN(mp_stream_read_obj, 1, 2, stream_read); -mp_obj_t mp_stream_write(mp_obj_t self_in, const void *buf, size_t len) { - const mp_stream_p_t *stream_p = mp_get_stream_raise(self_in, MP_STREAM_OP_WRITE); +STATIC mp_obj_t stream_read1(size_t n_args, const mp_obj_t *args) { + return stream_read_generic(n_args, args, MP_STREAM_RW_READ | MP_STREAM_RW_ONCE); +} +MP_DEFINE_CONST_FUN_OBJ_VAR_BETWEEN(mp_stream_read1_obj, 1, 2, stream_read1); + +mp_obj_t mp_stream_write(mp_obj_t self_in, const void *buf, size_t len, byte flags) { + mp_get_stream_raise(self_in, MP_STREAM_OP_WRITE); int error; - mp_uint_t out_sz = stream_p->write(self_in, buf, len, &error); - if (out_sz == MP_STREAM_ERROR) { + mp_uint_t out_sz = mp_stream_rw(self_in, (void*)buf, len, &error, flags); + if (error != 0) { if (mp_is_nonblocking_error(error)) { // http://docs.python.org/3/library/io.html#io.RawIOBase.write // "None is returned if the raw stream is set not to block and // no single byte could be readily written to it." - // This is for consistency with read() behavior, still weird, - // see abobe. return mp_const_none; } nlr_raise(mp_obj_new_exception_arg1(&mp_type_OSError, MP_OBJ_NEW_SMALL_INT(error))); @@ -206,33 +255,25 @@ mp_obj_t mp_stream_write(mp_obj_t self_in, const void *buf, size_t len) { // XXX hack void mp_stream_write_adaptor(void *self, const char *buf, size_t len) { - mp_stream_write(MP_OBJ_FROM_PTR(self), buf, len); -} - -// Works only with blocking streams -mp_uint_t mp_stream_writeall(mp_obj_t stream, const byte *buf, mp_uint_t size, int *errcode) { - mp_obj_base_t* s = (mp_obj_base_t*)MP_OBJ_TO_PTR(stream); - mp_uint_t org_size = size; - while (size > 0) { - mp_uint_t out_sz = s->type->stream_p->write(stream, buf, size, errcode); - if (out_sz == MP_STREAM_ERROR) { - return MP_STREAM_ERROR; - } - buf += out_sz; - size -= out_sz; - } - return org_size; + mp_stream_write(MP_OBJ_FROM_PTR(self), buf, len, MP_STREAM_RW_WRITE); } STATIC mp_obj_t stream_write_method(mp_obj_t self_in, mp_obj_t arg) { mp_buffer_info_t bufinfo; mp_get_buffer_raise(arg, &bufinfo, MP_BUFFER_READ); - return mp_stream_write(self_in, bufinfo.buf, bufinfo.len); + return mp_stream_write(self_in, bufinfo.buf, bufinfo.len, MP_STREAM_RW_WRITE); } MP_DEFINE_CONST_FUN_OBJ_2(mp_stream_write_obj, stream_write_method); +STATIC mp_obj_t stream_write1_method(mp_obj_t self_in, mp_obj_t arg) { + mp_buffer_info_t bufinfo; + mp_get_buffer_raise(arg, &bufinfo, MP_BUFFER_READ); + return mp_stream_write(self_in, bufinfo.buf, bufinfo.len, MP_STREAM_RW_WRITE | MP_STREAM_RW_ONCE); +} +MP_DEFINE_CONST_FUN_OBJ_2(mp_stream_write1_obj, stream_write1_method); + STATIC mp_obj_t stream_readinto(size_t n_args, const mp_obj_t *args) { - const mp_stream_p_t *stream_p = mp_get_stream_raise(args[0], MP_STREAM_OP_READ); + mp_get_stream_raise(args[0], MP_STREAM_OP_READ); mp_buffer_info_t bufinfo; mp_get_buffer_raise(args[1], &bufinfo, MP_BUFFER_WRITE); @@ -248,8 +289,8 @@ STATIC mp_obj_t stream_readinto(size_t n_args, const mp_obj_t *args) { } int error; - mp_uint_t out_sz = stream_p->read(args[0], bufinfo.buf, len, &error); - if (out_sz == MP_STREAM_ERROR) { + mp_uint_t out_sz = mp_stream_read_exactly(args[0], bufinfo.buf, len, &error); + if (error != 0) { if (mp_is_nonblocking_error(error)) { return mp_const_none; } diff --git a/py/stream.h b/py/stream.h index df6e94adfd..9202c64f31 100644 --- a/py/stream.h +++ b/py/stream.h @@ -48,11 +48,13 @@ struct mp_stream_seek_t { }; MP_DECLARE_CONST_FUN_OBJ(mp_stream_read_obj); +MP_DECLARE_CONST_FUN_OBJ(mp_stream_read1_obj); MP_DECLARE_CONST_FUN_OBJ(mp_stream_readinto_obj); MP_DECLARE_CONST_FUN_OBJ(mp_stream_readall_obj); MP_DECLARE_CONST_FUN_OBJ(mp_stream_unbuffered_readline_obj); MP_DECLARE_CONST_FUN_OBJ(mp_stream_unbuffered_readlines_obj); MP_DECLARE_CONST_FUN_OBJ(mp_stream_write_obj); +MP_DECLARE_CONST_FUN_OBJ(mp_stream_write1_obj); MP_DECLARE_CONST_FUN_OBJ(mp_stream_seek_obj); MP_DECLARE_CONST_FUN_OBJ(mp_stream_tell_obj); MP_DECLARE_CONST_FUN_OBJ(mp_stream_ioctl_obj); @@ -67,10 +69,15 @@ const mp_stream_p_t *mp_get_stream_raise(mp_obj_t self_in, int flags); // Iterator which uses mp_stream_unbuffered_readline_obj mp_obj_t mp_stream_unbuffered_iter(mp_obj_t self); -mp_obj_t mp_stream_write(mp_obj_t self_in, const void *buf, size_t len); +mp_obj_t mp_stream_write(mp_obj_t self_in, const void *buf, size_t len, byte flags); -// Helper function to write entire buf to *blocking* stream -mp_uint_t mp_stream_writeall(mp_obj_t stream, const byte *buf, mp_uint_t size, int *errcode); +// C-level helper functions +#define MP_STREAM_RW_READ 0 +#define MP_STREAM_RW_WRITE 2 +#define MP_STREAM_RW_ONCE 1 +mp_uint_t mp_stream_rw(mp_obj_t stream, void *buf, mp_uint_t size, int *errcode, byte flags); +#define mp_stream_write_exactly(stream, buf, size, err) mp_stream_rw(stream, (byte*)buf, size, err, MP_STREAM_RW_WRITE) +#define mp_stream_read_exactly(stream, buf, size, err) mp_stream_rw(stream, buf, size, err, MP_STREAM_RW_READ) #if MICROPY_STREAMS_NON_BLOCK // TODO: This is POSIX-specific (but then POSIX is the only real thing,