net: validate msghdr contents

This commit is contained in:
K. Lange 2024-02-09 15:09:17 +09:00
parent df74cb6b55
commit 60bd809a40
4 changed files with 22 additions and 3 deletions

View File

@ -60,4 +60,4 @@ size_t mmu_used_memory(void);
void * sbrk(size_t); void * sbrk(size_t);
union PML * mmu_get_page_other(union PML * root, uintptr_t virtAddr); union PML * mmu_get_page_other(union PML * root, uintptr_t virtAddr);
int mmu_validate_user_pointer(void * addr, size_t size, int flags); int mmu_validate_user_pointer(const void * addr, size_t size, int flags);

View File

@ -822,7 +822,7 @@ int mmu_copy_on_write(uintptr_t address) {
return 1; return 1;
} }
int mmu_validate_user_pointer(void * addr, size_t size, int flags) { int mmu_validate_user_pointer(const void * addr, size_t size, int flags) {
//printf("mmu_validate_user_pointer(%#zx, %lu, %u);\n", (uintptr_t)addr, size, flags); //printf("mmu_validate_user_pointer(%#zx, %lu, %u);\n", (uintptr_t)addr, size, flags);
if (addr == NULL && !(flags & MMU_PTR_NULL)) return 0; if (addr == NULL && !(flags & MMU_PTR_NULL)) return 0;
if (size > 0x800000000000) return 0; if (size > 0x800000000000) return 0;

View File

@ -1272,7 +1272,7 @@ int mmu_copy_on_write(uintptr_t address) {
* @param flags Control what constitutes a failure. * @param flags Control what constitutes a failure.
* @returns 0 on failure, 1 if process has access. * @returns 0 on failure, 1 if process has access.
*/ */
int mmu_validate_user_pointer(void * addr, size_t size, int flags) { int mmu_validate_user_pointer(const void * addr, size_t size, int flags) {
if (addr == NULL && !(flags & MMU_PTR_NULL)) return 0; if (addr == NULL && !(flags & MMU_PTR_NULL)) return 0;
if (size > 0x800000000000) return 0; if (size > 0x800000000000) return 0;

View File

@ -280,9 +280,27 @@ long net_connect(int sockfd, const struct sockaddr *addr, socklen_t addrlen) {
return node->sock_connect(node,addr,addrlen); return node->sock_connect(node,addr,addrlen);
} }
static int validate_msg(const struct msghdr * msg, int readonly) {
int flags = readonly ? 0 : MMU_PTR_WRITE;
if (!mmu_validate_user_pointer(msg,sizeof(struct msghdr),flags)) return 1;
if (msg->msg_iovlen) {
/* Check iovec structures */
if (!mmu_validate_user_pointer(msg->msg_iov, (size_t)(msg->msg_iovlen * sizeof(struct iovec)),flags)) return 1;
/* Check all the buffers in there */
for (size_t i = 0; i < msg->msg_iovlen; ++i) {
if (!mmu_validate_user_pointer(msg->msg_iov[i].iov_base, (size_t)(msg->msg_iov[i].iov_len), flags)) return 1;
}
}
/* Check control message space */
if (msg->msg_controllen && !mmu_validate_user_pointer(msg->msg_control, (size_t)(msg->msg_controllen), flags)) return 1;
return 0;
}
long net_recv(int sockfd, struct msghdr * msg, int flags) { long net_recv(int sockfd, struct msghdr * msg, int flags) {
CHECK_SOCK(sockfd); CHECK_SOCK(sockfd);
PTR_VALIDATE(msg); PTR_VALIDATE(msg);
if (validate_msg(msg,0)) return -EFAULT;
sock_t * node = (sock_t*)FD_ENTRY(sockfd); sock_t * node = (sock_t*)FD_ENTRY(sockfd);
return node->sock_recv(node,msg,flags); return node->sock_recv(node,msg,flags);
} }
@ -290,6 +308,7 @@ long net_recv(int sockfd, struct msghdr * msg, int flags) {
long net_send(int sockfd, const struct msghdr * msg, int flags) { long net_send(int sockfd, const struct msghdr * msg, int flags) {
CHECK_SOCK(sockfd); CHECK_SOCK(sockfd);
PTR_VALIDATE(msg); PTR_VALIDATE(msg);
if (validate_msg(msg,1)) return -EFAULT;
sock_t * node = (sock_t*)FD_ENTRY(sockfd); sock_t * node = (sock_t*)FD_ENTRY(sockfd);
return node->sock_send(node,msg,flags); return node->sock_send(node,msg,flags);
} }