network/stack: Clean up socket_receive.

* Reshuffle variable declarations, most are placed much
   closer to first usages now.

 * Turn a confusingly worded comment into an ASSERT(),
   and remove another one that was outdated.

 * Fix some minor code style problems.

 * Make the copying logic more consistent between first
   and then subsequent copies.

 * Make it possible for B_BAD_ADDRESS (EFAULT) to be
   returned. This is not listed in POSIX, but as per
   online sources, at least Linux does do this.

Change-Id: Idcfbed30531c1ab4796c4ee37f7f4ce8078e535b
Reviewed-on: https://review.haiku-os.org/c/haiku/+/7147
Reviewed-by: waddlesplash <waddlesplash@gmail.com>
Reviewed-by: Jérôme Duval <jerome.duval@gmail.com>
Tested-by: Commit checker robot <no-reply+buildbot@haiku-os.org>
This commit is contained in:
Augustin Cavalier 2023-11-24 23:37:41 -05:00 committed by waddlesplash
parent 111528d18d
commit 783aa308c3

View File

@ -1162,6 +1162,8 @@ ssize_t
socket_receive(net_socket* socket, msghdr* header, void* data, size_t length, socket_receive(net_socket* socket, msghdr* header, void* data, size_t length,
int flags) int flags)
{ {
const int originalFlags = flags;
// MSG_NOSIGNAL is only meaningful for send(), not receive(), but it is // MSG_NOSIGNAL is only meaningful for send(), not receive(), but it is
// sometimes specified anyway. Mask it off to avoid unnecessary errors. // sometimes specified anyway. Mask it off to avoid unnecessary errors.
flags &= ~MSG_NOSIGNAL; flags &= ~MSG_NOSIGNAL;
@ -1170,22 +1172,19 @@ socket_receive(net_socket* socket, msghdr* header, void* data, size_t length,
if (socket->first_info->read_data_no_buffer != NULL) if (socket->first_info->read_data_no_buffer != NULL)
return socket_receive_no_buffer(socket, header, data, length, flags); return socket_receive_no_buffer(socket, header, data, length, flags);
const int originalFlags = flags; // Mask off flags handled in this function.
flags &= ~MSG_TRUNC; flags &= ~(MSG_TRUNC);
size_t totalLength = length; size_t totalLength = length;
net_buffer* buffer; if (header != NULL) {
int i; ASSERT(data == header->msg_iov[0].iov_base);
// the convention to this function is that have header been
// present, { data, length } would have been iovec[0] and is
// always considered like that
if (header) {
// calculate the length considering all of the extra buffers // calculate the length considering all of the extra buffers
for (i = 1; i < header->msg_iovlen; i++) for (int i = 1; i < header->msg_iovlen; i++)
totalLength += header->msg_iov[i].iov_len; totalLength += header->msg_iov[i].iov_len;
} }
net_buffer* buffer;
status_t status = socket->first_info->read_data( status_t status = socket->first_info->read_data(
socket->first_protocol, totalLength, flags, &buffer); socket->first_protocol, totalLength, flags, &buffer);
if (status != B_OK) if (status != B_OK)
@ -1210,11 +1209,9 @@ socket_receive(net_socket* socket, msghdr* header, void* data, size_t length,
// TODO: - returning a NULL buffer when received 0 bytes // TODO: - returning a NULL buffer when received 0 bytes
// may not make much sense as we still need the address // may not make much sense as we still need the address
// - gNetBufferModule.read() uses memcpy() instead of user_memcpy
size_t nameLen = 0; size_t nameLen = 0;
if (header != NULL) {
if (header) {
// TODO: - consider the control buffer options // TODO: - consider the control buffer options
nameLen = header->msg_namelen; nameLen = header->msg_namelen;
header->msg_namelen = 0; header->msg_namelen = 0;
@ -1224,24 +1221,27 @@ socket_receive(net_socket* socket, msghdr* header, void* data, size_t length,
if (buffer == NULL) if (buffer == NULL)
return 0; return 0;
size_t bytesReceived = buffer->size, bytesCopied = 0; const size_t bytesReceived = buffer->size;
size_t bytesCopied = 0;
length = min_c(bytesReceived, length); size_t toRead = min_c(bytesReceived, length);
if (gNetBufferModule.read(buffer, 0, data, length) < B_OK) { status = gNetBufferModule.read(buffer, 0, data, toRead);
if (status != B_OK) {
gNetBufferModule.free(buffer); gNetBufferModule.free(buffer);
if (status == B_BAD_ADDRESS)
return status;
return ENOBUFS; return ENOBUFS;
} }
// if first copy was a success, proceed to following // if first copy was a success, proceed to following copies as required
// copies as required bytesCopied += toRead;
bytesCopied += length;
if (header) { if (header != NULL) {
// we only start considering at iovec[1] // We start at iovec[1] as { data, length } is iovec[0].
// as { data, length } is iovec[0] for (int i = 1; i < header->msg_iovlen && bytesCopied < bytesReceived; i++) {
for (i = 1; i < header->msg_iovlen && bytesCopied < bytesReceived; i++) {
iovec& vec = header->msg_iov[i]; iovec& vec = header->msg_iov[i];
size_t toRead = min_c(bytesReceived - bytesCopied, vec.iov_len); toRead = min_c(bytesReceived - bytesCopied, vec.iov_len);
if (gNetBufferModule.read(buffer, bytesCopied, vec.iov_base, if (gNetBufferModule.read(buffer, bytesCopied, vec.iov_base,
toRead) < B_OK) { toRead) < B_OK) {
break; break;
@ -1259,7 +1259,7 @@ socket_receive(net_socket* socket, msghdr* header, void* data, size_t length,
gNetBufferModule.free(buffer); gNetBufferModule.free(buffer);
if (bytesCopied < bytesReceived) { if (bytesCopied < bytesReceived) {
if (header) if (header != NULL)
header->msg_flags = MSG_TRUNC; header->msg_flags = MSG_TRUNC;
if ((originalFlags & MSG_TRUNC) != 0) if ((originalFlags & MSG_TRUNC) != 0)
@ -1274,13 +1274,10 @@ ssize_t
socket_send(net_socket* socket, msghdr* header, const void* data, size_t length, socket_send(net_socket* socket, msghdr* header, const void* data, size_t length,
int flags) int flags)
{ {
const sockaddr* address = NULL;
socklen_t addressLength = 0;
size_t bytesLeft = length;
const bool nosignal = ((flags & MSG_NOSIGNAL) != 0); const bool nosignal = ((flags & MSG_NOSIGNAL) != 0);
flags &= ~MSG_NOSIGNAL; flags &= ~MSG_NOSIGNAL;
size_t bytesLeft = length;
if (length > SSIZE_MAX) if (length > SSIZE_MAX)
return B_BAD_VALUE; return B_BAD_VALUE;
@ -1289,6 +1286,8 @@ socket_send(net_socket* socket, msghdr* header, const void* data, size_t length,
ancillary_data_container, void, delete_ancillary_data_container> ancillary_data_container, void, delete_ancillary_data_container>
ancillaryDataDeleter; ancillaryDataDeleter;
const sockaddr* address = NULL;
socklen_t addressLength = 0;
if (header != NULL) { if (header != NULL) {
address = (const sockaddr*)header->msg_name; address = (const sockaddr*)header->msg_name;
addressLength = header->msg_namelen; addressLength = header->msg_namelen;