qemu/subprojects/libvhost-user/libvhost-user.c
Stefano Garzarella 92b58bc7e9 libvhost-user: fail vu_message_write() if sendmsg() is failing
In vu_message_write() we use sendmsg() to send the message header,
then a write() to send the payload.

If sendmsg() fails we should avoid sending the payload, since we
were unable to send the header.

Discovered before fixing the issue with the previous patch, where
sendmsg() failed on macOS due to wrong parameters, but the frontend
still sent the payload which the backend incorrectly interpreted
as a wrong header.

Reviewed-by: Daniel P. Berrangé <berrange@redhat.com>
Reviewed-by: Philippe Mathieu-Daudé <philmd@linaro.org>
Tested-by: Philippe Mathieu-Daudé <philmd@linaro.org>
Acked-by: Stefan Hajnoczi <stefanha@redhat.com>
Reviewed-by: David Hildenbrand <david@redhat.com>
Signed-off-by: Stefano Garzarella <sgarzare@redhat.com>
Message-Id: <20240618100043.144657-4-sgarzare@redhat.com>
Reviewed-by: Michael S. Tsirkin <mst@redhat.com>
Signed-off-by: Michael S. Tsirkin <mst@redhat.com>
2024-07-02 09:27:56 -04:00

3140 lines
84 KiB
C

/*
* Vhost User library
*
* Copyright IBM, Corp. 2007
* Copyright (c) 2016 Red Hat, Inc.
*
* Authors:
* Anthony Liguori <aliguori@us.ibm.com>
* Marc-André Lureau <mlureau@redhat.com>
* Victor Kaplansky <victork@redhat.com>
*
* This work is licensed under the terms of the GNU GPL, version 2 or
* later. See the COPYING file in the top-level directory.
*/
#ifndef _GNU_SOURCE
#define _GNU_SOURCE
#endif
/* this code avoids GLib dependency */
#include <stdlib.h>
#include <stdio.h>
#include <unistd.h>
#include <stdarg.h>
#include <errno.h>
#include <string.h>
#include <assert.h>
#include <inttypes.h>
#include <sys/types.h>
#include <sys/socket.h>
#include <sys/eventfd.h>
#include <sys/mman.h>
#include <endian.h>
/* Necessary to provide VIRTIO_F_VERSION_1 on system
* with older linux headers. Must appear before
* <linux/vhost.h> below.
*/
#include "standard-headers/linux/virtio_config.h"
#if defined(__linux__)
#include <sys/syscall.h>
#include <fcntl.h>
#include <sys/ioctl.h>
#include <linux/vhost.h>
#include <sys/vfs.h>
#include <linux/magic.h>
#ifdef __NR_userfaultfd
#include <linux/userfaultfd.h>
#endif
#endif
#include "include/atomic.h"
#include "libvhost-user.h"
/* usually provided by GLib */
/*
 * G_GNUC_PRINTF(format_idx, arg_idx): local stand-in for GLib's macro of the
 * same name (this file avoids a GLib dependency).  On GCC newer than 2.4 it
 * expands to a printf-style format-checking attribute; GCC 4.4 (non-clang)
 * uses the gnu_printf archetype instead of __printf__.  Other compilers get
 * an empty expansion.
 */
#if __GNUC__ > 2 || (__GNUC__ == 2 && __GNUC_MINOR__ > 4)
#if !defined(__clang__) && (__GNUC__ == 4 && __GNUC_MINOR__ == 4)
#define G_GNUC_PRINTF(format_idx, arg_idx) \
    __attribute__((__format__(gnu_printf, format_idx, arg_idx)))
#else
#define G_GNUC_PRINTF(format_idx, arg_idx) \
    __attribute__((__format__(__printf__, format_idx, arg_idx)))
#endif
#else /* !__GNUC__ */
#define G_GNUC_PRINTF(format_idx, arg_idx)
#endif /* !__GNUC__ */
/*
 * MIN(x, y): GNU statement-expression minimum that evaluates each argument
 * once.  The discarded comparison of &_min1 and &_min2 makes the compiler
 * warn when x and y have different types.
 */
#ifndef MIN
#define MIN(x, y) ({ \
    __typeof__(x) _min1 = (x); \
    __typeof__(y) _min2 = (y); \
    (void) (&_min1 == &_min2); \
    _min1 < _min2 ? _min1 : _min2; })
#endif
/* Round number down to multiple (m is evaluated twice; assumes m != 0) */
#define ALIGN_DOWN(n, m) ((n) / (m) * (m))
/* Round number up to multiple (m is evaluated more than once) */
#define ALIGN_UP(n, m) ALIGN_DOWN((n) + (m) - 1, (m))
/* Branch-prediction hint, only defined if no header provided one already. */
#ifndef unlikely
#define unlikely(x) __builtin_expect(!!(x), 0)
#endif
/* Align each region to cache line size in inflight buffer */
#define INFLIGHT_ALIGNMENT 64
/* The version of inflight buffer */
#define INFLIGHT_VERSION 1
/* The version of the protocol we support */
#define VHOST_USER_VERSION 1
/*
 * Debug tracing.  DPRINT() always compiles its fprintf() call, so the
 * arguments stay type-checked, but it only prints when LIBVHOST_USER_DEBUG
 * is non-zero.
 */
#define LIBVHOST_USER_DEBUG 0
#define DPRINT(...) \
    do { \
        if (LIBVHOST_USER_DEBUG) { \
            fprintf(stderr, __VA_ARGS__); \
        } \
    } while (0)
/*
 * Test whether bit @fbit (which must be below 64) is set in the 64-bit
 * feature mask @features.
 */
static inline
bool has_feature(uint64_t features, unsigned int fbit)
{
    uint64_t mask;

    assert(fbit < 64);
    mask = UINT64_C(1) << fbit;
    return (features & mask) != 0;
}
/* Test whether virtio feature bit @fbit is set in dev->features. */
static inline
bool vu_has_feature(VuDev *dev,
                    unsigned int fbit)
{
    const uint64_t negotiated = dev->features;

    return has_feature(negotiated, fbit);
}
/* Test whether vhost-user protocol feature bit @fbit is set on @dev. */
static inline bool vu_has_protocol_feature(VuDev *dev, unsigned int fbit)
{
    const uint64_t negotiated = dev->protocol_features;

    return has_feature(negotiated, fbit);
}
/*
 * Map a vhost-user request code to its symbolic name, for diagnostics.
 *
 * The table is built with designated initializers, so any code below
 * VHOST_USER_MAX that is not listed here leaves a NULL slot.  Such codes,
 * like codes >= VHOST_USER_MAX, yield "unknown".  The result is therefore
 * never NULL and is always safe to pass to a "%s" conversion.
 */
const char *
vu_request_to_string(unsigned int req)
{
#define REQ(req) [req] = #req
    static const char *const vu_request_str[] = {
        REQ(VHOST_USER_NONE),
        REQ(VHOST_USER_GET_FEATURES),
        REQ(VHOST_USER_SET_FEATURES),
        REQ(VHOST_USER_SET_OWNER),
        REQ(VHOST_USER_RESET_OWNER),
        REQ(VHOST_USER_SET_MEM_TABLE),
        REQ(VHOST_USER_SET_LOG_BASE),
        REQ(VHOST_USER_SET_LOG_FD),
        REQ(VHOST_USER_SET_VRING_NUM),
        REQ(VHOST_USER_SET_VRING_ADDR),
        REQ(VHOST_USER_SET_VRING_BASE),
        REQ(VHOST_USER_GET_VRING_BASE),
        REQ(VHOST_USER_SET_VRING_KICK),
        REQ(VHOST_USER_SET_VRING_CALL),
        REQ(VHOST_USER_SET_VRING_ERR),
        REQ(VHOST_USER_GET_PROTOCOL_FEATURES),
        REQ(VHOST_USER_SET_PROTOCOL_FEATURES),
        REQ(VHOST_USER_GET_QUEUE_NUM),
        REQ(VHOST_USER_SET_VRING_ENABLE),
        REQ(VHOST_USER_SEND_RARP),
        REQ(VHOST_USER_NET_SET_MTU),
        REQ(VHOST_USER_SET_BACKEND_REQ_FD),
        REQ(VHOST_USER_IOTLB_MSG),
        REQ(VHOST_USER_SET_VRING_ENDIAN),
        REQ(VHOST_USER_GET_CONFIG),
        REQ(VHOST_USER_SET_CONFIG),
        REQ(VHOST_USER_POSTCOPY_ADVISE),
        REQ(VHOST_USER_POSTCOPY_LISTEN),
        REQ(VHOST_USER_POSTCOPY_END),
        REQ(VHOST_USER_GET_INFLIGHT_FD),
        REQ(VHOST_USER_SET_INFLIGHT_FD),
        REQ(VHOST_USER_GPU_SET_SOCKET),
        REQ(VHOST_USER_VRING_KICK),
        REQ(VHOST_USER_GET_MAX_MEM_SLOTS),
        REQ(VHOST_USER_ADD_MEM_REG),
        REQ(VHOST_USER_REM_MEM_REG),
        REQ(VHOST_USER_GET_SHARED_OBJECT),
        REQ(VHOST_USER_MAX),
    };
#undef REQ

    /* Unlisted codes leave NULL holes in the table: never return those. */
    if (req < VHOST_USER_MAX && vu_request_str[req] != NULL) {
        return vu_request_str[req];
    }
    return "unknown";
}
/*
 * Report a fatal device/protocol error.
 *
 * Formats the printf-style message, marks @dev as broken and passes the
 * text to the application's dev->panic callback.  If formatting fails, the
 * callback receives NULL.  The formatted buffer is freed afterwards.
 */
static void G_GNUC_PRINTF(2, 3)
vu_panic(VuDev *dev, const char *msg, ...)
{
    va_list args;
    char *text;

    va_start(args, msg);
    if (vasprintf(&text, msg, args) < 0) {
        /* vasprintf() leaves the pointer undefined on failure. */
        text = NULL;
    }
    va_end(args);

    dev->broken = true;
    dev->panic(dev, text);
    free(text);

    /*
     * FIXME:
     * find a way to call virtio_error, or perhaps close the connection?
     */
}
/*
 * Find the memory region whose [gpa, gpa + size) range contains
 * @guest_addr, or return NULL if none does.
 *
 * Regions cannot overlap in guest physical address space and dev->regions
 * is kept sorted by GPA, so at most one region matches and a binary search
 * suffices.
 */
static VuDevRegion *
vu_gpa_to_mem_region(VuDev *dev, uint64_t guest_addr)
{
    int lo = 0;
    int hi = dev->nregions - 1;

    while (lo <= hi) {
        const int mid = lo + (hi - lo) / 2;
        VuDevRegion *region = &dev->regions[mid];
        const uint64_t region_end = region->gpa + region->size;

        if (guest_addr >= region->gpa && guest_addr < region_end) {
            return region;
        }
        /* Not inside: narrow the search towards the side that can match. */
        if (guest_addr >= region_end) {
            lo = mid + 1;
        }
        if (guest_addr < region->gpa) {
            hi = mid - 1;
        }
    }

    return NULL;
}
/*
 * Translate guest physical address @guest_addr to a pointer in our address
 * space.
 *
 * @plen: on entry, the number of bytes the caller wants to access; on
 *        return, clamped so that [guest_addr, guest_addr + *plen) lies
 *        entirely inside the single region that contains guest_addr.
 *
 * Returns NULL if *plen is 0 or if no region contains guest_addr.
 */
void *
vu_gpa_to_va(VuDev *dev, uint64_t *plen, uint64_t guest_addr)
{
    VuDevRegion *r;
    uint64_t remaining;

    if (*plen == 0) {
        return NULL;
    }

    r = vu_gpa_to_mem_region(dev, guest_addr);
    if (!r) {
        return NULL;
    }

    /*
     * Bytes left in the region from guest_addr.  Comparing *plen against
     * this avoids computing guest_addr + *plen, which can wrap around for a
     * huge length and skip the clamp.
     */
    remaining = r->gpa + r->size - guest_addr;
    if (*plen > remaining) {
        *plen = remaining;
    }

    /* Do the offset math on integers rather than on a void pointer. */
    return (void *)(uintptr_t)(guest_addr - r->gpa + r->mmap_addr +
                               r->mmap_offset);
}
/*
 * Translate an address in the frontend's (QEMU's) virtual address space to
 * a pointer in ours.  The regions are scanned linearly.  Returns NULL if no
 * region's [qva, qva + size) range contains @qemu_addr.
 */
static void *
qva_to_va(VuDev *dev, uint64_t qemu_addr)
{
    unsigned int idx = 0;

    while (idx < dev->nregions) {
        VuDevRegion *r = &dev->regions[idx++];

        if (qemu_addr < r->qva || qemu_addr >= r->qva + r->size) {
            continue;
        }
        return (void *)(uintptr_t)
            qemu_addr - r->qva + r->mmap_addr + r->mmap_offset;
    }

    return NULL;
}
/*
 * Unmap the host mapping of every region (size + mmap_offset bytes from
 * mmap_addr) and reset the region count to zero.
 */
static void
vu_remove_all_mem_regs(VuDev *dev)
{
    unsigned int idx = 0;

    while (idx < dev->nregions) {
        const VuDevRegion *reg = &dev->regions[idx++];

        munmap((void *)(uintptr_t)reg->mmap_addr,
               reg->size + reg->mmap_offset);
    }

    dev->nregions = 0;
}
/*
 * Recompute the descriptor, used and available ring pointers of @vq from
 * the frontend addresses stored in vq->vra.
 *
 * Note the inverted convention: returns true on FAILURE, i.e. when at
 * least one of the three addresses is not covered by a mapped region.
 */
static bool
map_ring(VuDev *dev, VuVirtq *vq)
{
    bool all_mapped;

    vq->vring.desc = qva_to_va(dev, vq->vra.desc_user_addr);
    vq->vring.used = qva_to_va(dev, vq->vra.used_user_addr);
    vq->vring.avail = qva_to_va(dev, vq->vra.avail_user_addr);

    DPRINT("Setting virtq addresses:\n");
    DPRINT(" vring_desc at %p\n", vq->vring.desc);
    DPRINT(" vring_used at %p\n", vq->vring.used);
    DPRINT(" vring_avail at %p\n", vq->vring.avail);

    all_mapped = vq->vring.desc != NULL &&
                 vq->vring.used != NULL &&
                 vq->vring.avail != NULL;
    return !all_mapped;
}
/*
 * Decide whether @vq can be accessed right now.
 *
 * Returns false if the device is broken.  If the rings are currently
 * unmapped but all three frontend ring addresses are known, tries to remap
 * them.  A failed remap panics and returns false.
 */
static bool
vu_is_vq_usable(VuDev *dev, VuVirtq *vq)
{
    bool addrs_known;

    if (unlikely(dev->broken)) {
        return false;
    }
    if (likely(vq->vring.avail)) {
        return true;
    }

    /*
     * In corner cases, we might temporarily remove a memory region that
     * mapped a ring. When removing a memory region we make sure to
     * unmap any rings that would be impacted. Let's try to remap if we
     * already succeeded mapping this ring once.
     */
    addrs_known = vq->vra.desc_user_addr && vq->vra.used_user_addr &&
                  vq->vra.avail_user_addr;
    if (!addrs_known) {
        return false;
    }

    if (!map_ring(dev, vq)) {
        return true;
    }
    vu_panic(dev, "remapping queue on access");
    return false;
}
/*
 * Forget the ring pointers of every queue with at least one ring
 * (descriptor table, used ring or available ring) inside the host mapping
 * of region @r, which the caller is about to munmap().
 *
 * Ring pointers come from qva_to_va() and therefore lie in
 * [mmap_addr + mmap_offset, mmap_addr + mmap_offset + size).  A queue is
 * affected as soon as ANY of its rings is in that range; leaving even one
 * pointer would make it dangle.  Clearing the pointers is harmless because
 * vu_is_vq_usable() remaps a queue on demand from vq->vra.
 */
static void
unmap_rings(VuDev *dev, VuDevRegion *r)
{
    const uint64_t start = r->mmap_addr + r->mmap_offset;
    const uint64_t end = start + r->size;
    int i;

    for (i = 0; i < dev->max_queues; i++) {
        VuVirtq *vq = &dev->vq[i];
        const uintptr_t desc = (uintptr_t)vq->vring.desc;
        const uintptr_t used = (uintptr_t)vq->vring.used;
        const uintptr_t avail = (uintptr_t)vq->vring.avail;
        const bool affected = (desc >= start && desc < end) ||
                              (used >= start && used < end) ||
                              (avail >= start && avail < end);

        if (!affected) {
            continue;
        }

        DPRINT("Unmapping rings of queue %d\n", i);
        vq->vring.desc = NULL;
        vq->vring.used = NULL;
        vq->vring.avail = NULL;
    }
}
/*
 * Return the huge page size of the hugetlbfs file system backing @fd, or 0
 * if @fd is not on hugetlbfs, fstatfs() fails, or the platform is not
 * Linux.  fstatfs() is retried when interrupted by a signal.
 */
static size_t
get_fd_hugepagesize(int fd)
{
#if defined(__linux__)
    struct statfs fs;
    int err;

    for (;;) {
        err = fstatfs(fd, &fs);
        if (err == 0 || errno != EINTR) {
            break;
        }
    }

    if (err == 0 && (unsigned int)fs.f_type == HUGETLBFS_MAGIC) {
        return fs.f_bsize;
    }
#endif
    return 0;
}
/*
 * Map the memory region described by @msg_region from @fd and insert it
 * into dev->regions, keeping the array sorted by guest physical address.
 *
 * On error (GPA range wrapping past 2^64, GPA overlap with an existing
 * region, mmap() failure) vu_panic() is called and nothing is added.  @fd
 * is not closed here; the caller keeps ownership.
 *
 * While postcopy is listening, the region is mapped PROT_NONE and
 * msg_region->userspace_addr is overwritten with our mapping address, so
 * the caller can hand it back to the frontend.
 */
static void
_vu_add_mem_reg(VuDev *dev, VhostUserMemoryRegion *msg_region, int fd)
{
    const uint64_t start_gpa = msg_region->guest_phys_addr;
    const uint64_t end_gpa = start_gpa + msg_region->memory_size;
    int prot = PROT_READ | PROT_WRITE;
    uint64_t mmap_offset, fd_offset;
    size_t hugepagesize;
    VuDevRegion *r;
    void *mmap_addr;
    int low = 0;
    int high = dev->nregions - 1;
    unsigned int idx;

    DPRINT("Adding region %d\n", dev->nregions);
    DPRINT(" guest_phys_addr: 0x%016"PRIx64"\n",
           msg_region->guest_phys_addr);
    DPRINT(" memory_size: 0x%016"PRIx64"\n",
           msg_region->memory_size);
    DPRINT(" userspace_addr: 0x%016"PRIx64"\n",
           msg_region->userspace_addr);
    DPRINT(" old mmap_offset: 0x%016"PRIx64"\n",
           msg_region->mmap_offset);

    /*
     * gpa + size is used for range checks here and in the lookup helpers.
     * A range whose end wraps around 2^64 would make all of those
     * comparisons silently wrong, so reject it.
     */
    if (end_gpa < start_gpa) {
        vu_panic(dev, "region guest physical address range wraps around: "
                 "0x%" PRIx64 " + 0x%" PRIx64,
                 msg_region->guest_phys_addr, msg_region->memory_size);
        return;
    }

    if (dev->postcopy_listening) {
        /*
         * In postcopy we're using PROT_NONE here to catch anyone
         * accessing it before we userfault
         */
        prot = PROT_NONE;
    }

    /*
     * We will add memory regions into the array sorted by GPA. Perform a
     * binary search to locate the insertion point: it will be at the low
     * index.
     */
    while (low <= high) {
        unsigned int mid = low + (high - low) / 2;
        VuDevRegion *cur = &dev->regions[mid];

        /* Overlap of GPA addresses. */
        if (start_gpa < cur->gpa + cur->size && cur->gpa < end_gpa) {
            vu_panic(dev, "regions with overlapping guest physical addresses");
            return;
        }
        if (start_gpa >= cur->gpa + cur->size) {
            low = mid + 1;
        }
        if (start_gpa < cur->gpa) {
            high = mid - 1;
        }
    }
    idx = low;

    /*
     * Convert most of msg_region->mmap_offset to fd_offset. In almost all
     * cases, this will leave us with mmap_offset == 0, mmap()'ing only
     * what we really need. Only if a memory region would partially cover
     * hugetlb pages, we'd get mmap_offset != 0, which usually doesn't happen
     * anymore (i.e., modern QEMU).
     *
     * Note that mmap() with hugetlb would fail if the offset into the file
     * is not aligned to the huge page size.
     */
    hugepagesize = get_fd_hugepagesize(fd);
    if (hugepagesize) {
        fd_offset = ALIGN_DOWN(msg_region->mmap_offset, hugepagesize);
        mmap_offset = msg_region->mmap_offset - fd_offset;
    } else {
        fd_offset = msg_region->mmap_offset;
        mmap_offset = 0;
    }

    DPRINT(" fd_offset: 0x%016"PRIx64"\n",
           fd_offset);
    DPRINT(" new mmap_offset: 0x%016"PRIx64"\n",
           mmap_offset);

    mmap_addr = mmap(0, msg_region->memory_size + mmap_offset,
                     prot, MAP_SHARED | MAP_NORESERVE, fd, fd_offset);
    if (mmap_addr == MAP_FAILED) {
        vu_panic(dev, "region mmap error: %s", strerror(errno));
        return;
    }
    DPRINT(" mmap_addr: 0x%016"PRIx64"\n",
           (uint64_t)(uintptr_t)mmap_addr);

#if defined(__linux__)
    /* Don't include all guest memory in a coredump. */
    madvise(mmap_addr, msg_region->memory_size + mmap_offset,
            MADV_DONTDUMP);
#endif

    /* Shift all affected entries by 1 to open a hole at idx. */
    r = &dev->regions[idx];
    memmove(r + 1, r, sizeof(VuDevRegion) * (dev->nregions - idx));
    r->gpa = msg_region->guest_phys_addr;
    r->size = msg_region->memory_size;
    r->qva = msg_region->userspace_addr;
    r->mmap_addr = (uint64_t)(uintptr_t)mmap_addr;
    r->mmap_offset = mmap_offset;
    dev->nregions++;

    if (dev->postcopy_listening) {
        /*
         * Return the address to QEMU so that it can translate the ufd
         * fault addresses back.
         */
        msg_region->userspace_addr = r->mmap_addr + r->mmap_offset;
    }
}
/*
 * Close the file descriptors fds[0 .. fd_num - 1] received with @vmsg.
 * vmsg->fd_num itself is left unchanged.
 */
static void
vmsg_close_fds(VhostUserMsg *vmsg)
{
    int n = 0;

    while (n < vmsg->fd_num) {
        close(vmsg->fds[n]);
        n++;
    }
}
/*
 * Turn @vmsg into a u64 reply carrying @val.  The request flags are
 * cleared and no file descriptors are attached.
 */
static void vmsg_set_reply_u64(VhostUserMsg *vmsg, uint64_t val)
{
    vmsg->payload.u64 = val;
    vmsg->size = sizeof(vmsg->payload.u64);
    vmsg->fd_num = 0;
    /* defaults will be set by vu_send_reply() */
    vmsg->flags = 0;
}
/*
 * Report whether userfaultfd with shmem and hugetlbfs MISSING-fault support
 * is usable.
 *
 * This needs both build-time support (the userfaultfd syscall number and
 * the UFFD_FEATURE_MISSING_* flags) and a running kernel that accepts the
 * UFFDIO_API handshake for those features.  Always false elsewhere.
 */
static bool
have_userfault(void)
{
#if defined(__linux__) && defined(__NR_userfaultfd) &&\
    defined(UFFD_FEATURE_MISSING_SHMEM) &&\
    defined(UFFD_FEATURE_MISSING_HUGETLBFS)
    struct uffdio_api api_struct;
    bool supported;
    int ufd;

    /* Now test the kernel we're running on really has the features */
    ufd = syscall(__NR_userfaultfd, O_CLOEXEC | O_NONBLOCK);
    if (ufd < 0) {
        return false;
    }

    api_struct.api = UFFD_API;
    api_struct.features = UFFD_FEATURE_MISSING_SHMEM |
                          UFFD_FEATURE_MISSING_HUGETLBFS;
    supported = ioctl(ufd, UFFDIO_API, &api_struct) == 0;
    close(ufd);

    return supported;
#else
    return false;
#endif
}
/*
 * Default read_msg implementation: receive one vhost-user message from
 * @conn_fd into @vmsg.
 *
 * A single recvmsg() receives the fixed-size header together with any
 * SCM_RIGHTS descriptors (at most VHOST_MEMORY_BASELINE_NREGIONS).  The
 * payload, whose length the header announces in vmsg->size, is then read
 * with read().  On success vmsg->fds[0 .. fd_num - 1] hold the received
 * descriptors and the caller owns them.
 *
 * Returns false after vu_panic() on any error.  Errors include the peer
 * closing the connection and a truncated header.  Descriptors received
 * with a rejected message are closed.
 */
static bool
vu_message_read_default(VuDev *dev, int conn_fd, VhostUserMsg *vmsg)
{
    char control[CMSG_SPACE(VHOST_MEMORY_BASELINE_NREGIONS * sizeof(int))] = {};
    struct iovec iov = {
        .iov_base = (char *)vmsg,
        .iov_len = VHOST_USER_HDR_SIZE,
    };
    struct msghdr msg = {
        .msg_iov = &iov,
        .msg_iovlen = 1,
        .msg_control = control,
        .msg_controllen = sizeof(control),
    };
    size_t fd_size;
    struct cmsghdr *cmsg;
    int rc;

    do {
        rc = recvmsg(conn_fd, &msg, 0);
    } while (rc < 0 && (errno == EINTR || errno == EAGAIN));

    if (rc < 0) {
        vu_panic(dev, "Error while recvmsg: %s", strerror(errno));
        return false;
    }

    if (rc == 0) {
        /* Orderly shutdown by the peer: no header bytes were received. */
        vu_panic(dev, "Error while recvmsg: connection closed by peer");
        return false;
    }

    vmsg->fd_num = 0;
    for (cmsg = CMSG_FIRSTHDR(&msg);
         cmsg != NULL;
         cmsg = CMSG_NXTHDR(&msg, cmsg))
    {
        if (cmsg->cmsg_level == SOL_SOCKET && cmsg->cmsg_type == SCM_RIGHTS) {
            fd_size = cmsg->cmsg_len - CMSG_LEN(0);
            vmsg->fd_num = fd_size / sizeof(int);
            assert(vmsg->fd_num <= VHOST_MEMORY_BASELINE_NREGIONS);
            memcpy(vmsg->fds, CMSG_DATA(cmsg), fd_size);
            break;
        }
    }

    /*
     * request/flags/size are only meaningful if the whole header arrived.
     * Reject a short one rather than parse partially stale fields.
     */
    if ((size_t)rc < (size_t)VHOST_USER_HDR_SIZE) {
        vu_panic(dev, "Error: short message header: %d bytes, expected %zu",
                 rc, (size_t)VHOST_USER_HDR_SIZE);
        goto fail;
    }

    if (vmsg->size > sizeof(vmsg->payload)) {
        vu_panic(dev,
                 "Error: too big message request: %d, size: vmsg->size: %u, "
                 "while sizeof(vmsg->payload) = %zu\n",
                 vmsg->request, vmsg->size, sizeof(vmsg->payload));
        goto fail;
    }

    if (vmsg->size) {
        do {
            rc = read(conn_fd, &vmsg->payload, vmsg->size);
        } while (rc < 0 && (errno == EINTR || errno == EAGAIN));

        if (rc <= 0) {
            vu_panic(dev, "Error while reading: %s", strerror(errno));
            goto fail;
        }

        assert((uint32_t)rc == vmsg->size);
    }

    return true;

fail:
    vmsg_close_fds(vmsg);

    return false;
}
/*
 * Send @vmsg on @conn_fd.
 *
 * The header goes out with sendmsg(), together with vmsg->fds[0..fd_num-1]
 * as SCM_RIGHTS ancillary data when fd_num > 0.  The vmsg->size payload
 * bytes follow via write(), taken from vmsg->data if it is set and
 * otherwise from the inline payload right after the header.
 *
 * Nothing is sent after a failed header send.  Short writes, which stream
 * sockets may return, are continued until every byte is out.  Returns false
 * after vu_panic() on any error.
 */
static bool
vu_message_write(VuDev *dev, int conn_fd, VhostUserMsg *vmsg)
{
    int rc;
    uint8_t *p = (uint8_t *)vmsg;
    const uint8_t *payload;
    size_t done;
    char control[CMSG_SPACE(VHOST_MEMORY_BASELINE_NREGIONS * sizeof(int))] = {};
    struct iovec iov = {
        .iov_base = (char *)vmsg,
        .iov_len = VHOST_USER_HDR_SIZE,
    };
    struct msghdr msg = {
        .msg_iov = &iov,
        .msg_iovlen = 1,
        .msg_control = control,
    };
    struct cmsghdr *cmsg;

    assert(vmsg->fd_num <= VHOST_MEMORY_BASELINE_NREGIONS);
    if (vmsg->fd_num > 0) {
        size_t fdsize = vmsg->fd_num * sizeof(int);
        msg.msg_controllen = CMSG_SPACE(fdsize);
        cmsg = CMSG_FIRSTHDR(&msg);
        cmsg->cmsg_len = CMSG_LEN(fdsize);
        cmsg->cmsg_level = SOL_SOCKET;
        cmsg->cmsg_type = SCM_RIGHTS;
        memcpy(CMSG_DATA(cmsg), vmsg->fds, fdsize);
    } else {
        /* No ancillary data: pass no control buffer at all. */
        msg.msg_controllen = 0;
        msg.msg_control = NULL;
    }

    do {
        rc = sendmsg(conn_fd, &msg, 0);
    } while (rc < 0 && (errno == EINTR || errno == EAGAIN));

    if (rc <= 0) {
        vu_panic(dev, "Error while writing: %s", strerror(errno));
        return false;
    }

    /*
     * sendmsg() returned > 0, so the ancillary data has been sent.  If only
     * part of the header was accepted, finish it with plain write() calls.
     */
    done = (size_t)rc;
    while (done < (size_t)VHOST_USER_HDR_SIZE) {
        rc = write(conn_fd, p + done, VHOST_USER_HDR_SIZE - done);
        if (rc < 0 && (errno == EINTR || errno == EAGAIN)) {
            continue;
        }
        if (rc <= 0) {
            vu_panic(dev, "Error while writing: %s", strerror(errno));
            return false;
        }
        done += (size_t)rc;
    }

    payload = vmsg->data ? vmsg->data : p + VHOST_USER_HDR_SIZE;
    done = 0;
    while (done < vmsg->size) {
        rc = write(conn_fd, payload + done, vmsg->size - done);
        if (rc < 0 && (errno == EINTR || errno == EAGAIN)) {
            continue;
        }
        if (rc <= 0) {
            vu_panic(dev, "Error while writing: %s", strerror(errno));
            return false;
        }
        done += (size_t)rc;
    }

    return true;
}
/*
 * Send @vmsg back to the frontend as a reply.  The protocol version is
 * stamped into the version bits of the flags and the reply bit is set
 * before writing.
 */
static bool
vu_send_reply(VuDev *dev, int conn_fd, VhostUserMsg *vmsg)
{
    uint32_t flags = vmsg->flags;

    flags &= ~VHOST_USER_VERSION_MASK;
    flags |= VHOST_USER_VERSION | VHOST_USER_REPLY_MASK;
    vmsg->flags = flags;

    return vu_message_write(dev, conn_fd, vmsg);
}
/*
 * Processes a reply on the backend channel.
 * Entered with backend_mutex held and releases it before exit.
 *
 * If @vmsg did not set VHOST_USER_NEED_REPLY_MASK, nothing is read and the
 * result is true.  Otherwise one message is read from dev->backend_fd.  The
 * result is true only if that reply has the same request code as @vmsg and
 * a u64 payload of 0.
 */
static bool
vu_process_message_reply(VuDev *dev, const VhostUserMsg *vmsg)
{
    VhostUserMsg msg_reply;
    bool result = false;

    if ((vmsg->flags & VHOST_USER_NEED_REPLY_MASK) == 0) {
        result = true;
        goto out;
    }

    if (!vu_message_read_default(dev, dev->backend_fd, &msg_reply)) {
        goto out;
    }

    /*
     * Only the ack value is used.  Close any descriptors that came with
     * the reply so they are not leaked.
     */
    vmsg_close_fds(&msg_reply);

    if (msg_reply.request != vmsg->request) {
        DPRINT("Received unexpected msg type. Expected %d received %d\n",
               vmsg->request, msg_reply.request);
        goto out;
    }

    result = msg_reply.payload.u64 == 0;

out:
    pthread_mutex_unlock(&dev->backend_mutex);
    return result;
}
/*
 * Signal the frontend's dirty-log eventfd, if one was provided.  A failed
 * eventfd write panics the device.
 */
static void
vu_log_kick(VuDev *dev)
{
    if (dev->log_call_fd == -1) {
        return;
    }

    DPRINT("Kicking the QEMU's log...\n");
    if (eventfd_write(dev->log_call_fd, 1) < 0) {
        vu_panic(dev, "Error writing eventfd: %s", strerror(errno));
    }
}
/*
 * Set the bit for @page in the dirty bitmap @log_table (eight pages per
 * byte) with an atomic OR.
 */
static void
vu_log_page(uint8_t *log_table, uint64_t page)
{
    uint8_t *byte = &log_table[page / 8];
    const unsigned int bit = page % 8;

    DPRINT("Logged dirty guest page: %"PRId64"\n", page);
    qatomic_or(byte, 1 << bit);
}
/*
 * Mark guest memory [address, address + length) dirty in the log bitmap,
 * one bit per VHOST_LOG_PAGE page, then kick the log eventfd.
 *
 * Does nothing unless VHOST_F_LOG_ALL is set in dev->features, a log
 * table is mapped and @length is non-zero.
 */
static void
vu_log_write(VuDev *dev, uint64_t address, uint64_t length)
{
    const uint64_t end = address + length;
    uint64_t page;

    if (!(dev->features & (1ULL << VHOST_F_LOG_ALL)) ||
        !dev->log_table || !length) {
        return;
    }

    assert(dev->log_size > ((address + length - 1) / VHOST_LOG_PAGE / 8));

    for (page = address / VHOST_LOG_PAGE; page * VHOST_LOG_PAGE < end;
         page++) {
        vu_log_page(dev->log_table, page);
    }

    vu_log_kick(dev);
}
/*
 * Watch callback for a queue's kick eventfd; @data carries the queue
 * index.  Drains the eventfd and runs the queue handler, if any.  A failed
 * read panics the device and removes the watch.
 */
static void
vu_kick_cb(VuDev *dev, int condition, void *data)
{
    const int index = (intptr_t)data;
    VuVirtq *vq = &dev->vq[index];
    eventfd_t kick_data;

    if (eventfd_read(vq->kick_fd, &kick_data) == -1) {
        vu_panic(dev, "kick eventfd_read(): %s", strerror(errno));
        dev->remove_watch(dev, dev->vq[index].kick_fd);
        return;
    }

    DPRINT("Got kick_data: %016"PRIx64" handler:%p idx:%d\n",
           kick_data, vq->handler, index);
    if (vq->handler) {
        vq->handler(dev, index);
    }
}
/*
 * Build the feature reply: the bits this library's virtqueue code
 * supports, ORed with the device's own bits from iface->get_features.
 * Returns true, so a u64 reply is sent.
 */
static bool
vu_get_features_exec(VuDev *dev, VhostUserMsg *vmsg)
{
    /*
     * The following VIRTIO feature bits are supported by our virtqueue
     * implementation:
     */
    uint64_t features = 1ULL << VIRTIO_F_NOTIFY_ON_EMPTY |
                        1ULL << VIRTIO_RING_F_INDIRECT_DESC |
                        1ULL << VIRTIO_RING_F_EVENT_IDX |
                        1ULL << VIRTIO_F_VERSION_1 |
                        /* vhost-user feature bits */
                        1ULL << VHOST_F_LOG_ALL |
                        1ULL << VHOST_USER_F_PROTOCOL_FEATURES;

    if (dev->iface->get_features) {
        features |= dev->iface->get_features(dev);
    }

    vmsg->payload.u64 = features;
    vmsg->size = sizeof(vmsg->payload.u64);
    vmsg->fd_num = 0;

    DPRINT("Sending back to guest u64: 0x%016"PRIx64"\n", vmsg->payload.u64);

    return true;
}
/* Set the enable flag of every queue of @dev to @enabled. */
static void
vu_set_enable_all_rings(VuDev *dev, bool enabled)
{
    uint16_t q = 0;

    while (q < dev->max_queues) {
        dev->vq[q++].enable = enabled;
    }
}
/*
 * Store the features acked by the frontend and forward them to
 * iface->set_features.
 *
 * A frontend without VIRTIO_F_VERSION_1 is rejected with vu_panic().  If
 * VHOST_USER_F_PROTOCOL_FEATURES was not negotiated, all rings are enabled
 * immediately.  Returns false (no reply payload).
 */
static bool
vu_set_features_exec(VuDev *dev, VhostUserMsg *vmsg)
{
    DPRINT("u64: 0x%016"PRIx64"\n", vmsg->payload.u64);

    dev->features = vmsg->payload.u64;
    if (!vu_has_feature(dev, VIRTIO_F_VERSION_1)) {
        /*
         * We only support devices conforming to VIRTIO 1.0 or
         * later
         */
        vu_panic(dev, "virtio legacy devices aren't supported by libvhost-user");
        return false;
    }

    /*
     * VHOST_USER_F_PROTOCOL_FEATURES is a bit number, not a mask: test it
     * with vu_has_feature() like the VERSION_1 check above.
     */
    if (!vu_has_feature(dev, VHOST_USER_F_PROTOCOL_FEATURES)) {
        vu_set_enable_all_rings(dev, true);
    }

    if (dev->iface->set_features) {
        dev->iface->set_features(dev, dev->features);
    }

    return false;
}
/* SET_OWNER handler: no state to change.  Returns false (no reply). */
static bool
vu_set_owner_exec(VuDev *dev, VhostUserMsg *vmsg)
{
    (void)dev;
    (void)vmsg;
    return false;
}
/*
 * Release dirty-logging resources: unmap the log table (reporting a failed
 * munmap() on stderr) and close the log eventfd.  Both are reset to their
 * "absent" values (NULL / -1).
 */
static void
vu_close_log(VuDev *dev)
{
    if (dev->log_table != NULL) {
        const int err = munmap(dev->log_table, dev->log_size);

        if (err != 0) {
            perror("close log munmap() error");
        }
        dev->log_table = NULL;
    }

    if (dev->log_call_fd == -1) {
        return;
    }
    close(dev->log_call_fd);
    dev->log_call_fd = -1;
}
/* Reset handler: disable every ring.  Returns false (no reply payload). */
static bool
vu_reset_device_exec(VuDev *dev, VhostUserMsg *vmsg)
{
    const bool enable = false;

    (void)vmsg;
    vu_set_enable_all_rings(dev, enable);
    return false;
}
/*
 * Arm postcopy fault handling for every mapped region.
 *
 * For each region this discards existing pages (MADV_DONTNEED) and
 * disables transparent huge pages (MADV_NOHUGEPAGE).  Failures of either
 * madvise() are only reported on stderr.  The region is then registered
 * with dev->postcopy_ufd in MISSING mode, the kernel must report
 * UFFDIO_COPY support for it, and finally it is made PROT_READ|PROT_WRITE
 * again.  While postcopy is listening, regions are mapped PROT_NONE when
 * added.
 *
 * Returns false after vu_panic() if registration, the COPY capability
 * check or mprotect() fails; true otherwise.  Without UFFDIO_REGISTER at
 * build time the loop body is compiled out and this simply returns true.
 */
static bool
generate_faults(VuDev *dev) {
    unsigned int i;
    for (i = 0; i < dev->nregions; i++) {
#ifdef UFFDIO_REGISTER
        VuDevRegion *dev_region = &dev->regions[i];
        int ret;
        struct uffdio_register reg_struct;
        /*
         * We should already have an open ufd. Mark each memory
         * range as ufd.
         * Discard any mapping we have here; note I can't use MADV_REMOVE
         * or fallocate to make the hole since I don't want to lose
         * data that's already arrived in the shared process.
         * TODO: How to do hugepage
         */
        ret = madvise((void *)(uintptr_t)dev_region->mmap_addr,
                      dev_region->size + dev_region->mmap_offset,
                      MADV_DONTNEED);
        if (ret) {
            fprintf(stderr,
                    "%s: Failed to madvise(DONTNEED) region %d: %s\n",
                    __func__, i, strerror(errno));
        }
        /*
         * Turn off transparent hugepages so we don't get lost wakeups
         * in neighbouring pages.
         * TODO: Turn this back on later.
         */
        ret = madvise((void *)(uintptr_t)dev_region->mmap_addr,
                      dev_region->size + dev_region->mmap_offset,
                      MADV_NOHUGEPAGE);
        if (ret) {
            /*
             * Note: This can happen legally on kernels that are configured
             * without madvise'able hugepages
             */
            fprintf(stderr,
                    "%s: Failed to madvise(NOHUGEPAGE) region %d: %s\n",
                    __func__, i, strerror(errno));
        }
        /* Register the whole host mapping (size + mmap_offset bytes). */
        reg_struct.range.start = (uintptr_t)dev_region->mmap_addr;
        reg_struct.range.len = dev_region->size + dev_region->mmap_offset;
        reg_struct.mode = UFFDIO_REGISTER_MODE_MISSING;
        if (ioctl(dev->postcopy_ufd, UFFDIO_REGISTER, &reg_struct)) {
            vu_panic(dev, "%s: Failed to userfault region %d "
                          "@%" PRIx64 " + size:%" PRIx64 " offset: %" PRIx64
                          ": (ufd=%d)%s\n",
                     __func__, i,
                     dev_region->mmap_addr,
                     dev_region->size, dev_region->mmap_offset,
                     dev->postcopy_ufd, strerror(errno));
            return false;
        }
        /* The kernel reports the ioctls usable on this range in .ioctls. */
        if (!(reg_struct.ioctls & (1ULL << _UFFDIO_COPY))) {
            vu_panic(dev, "%s Region (%d) doesn't support COPY",
                     __func__, i);
            return false;
        }
        DPRINT("%s: region %d: Registered userfault for %"
               PRIx64 " + %" PRIx64 "\n", __func__, i,
               (uint64_t)reg_struct.range.start,
               (uint64_t)reg_struct.range.len);
        /* Now it's registered we can let the client at it */
        if (mprotect((void *)(uintptr_t)dev_region->mmap_addr,
                     dev_region->size + dev_region->mmap_offset,
                     PROT_READ | PROT_WRITE)) {
            vu_panic(dev, "failed to mprotect region %d for postcopy (%s)",
                     i, strerror(errno));
            return false;
        }
        /* TODO: Stash 'zero' support flags somewhere */
#endif
    }
    return true;
}
/*
 * VHOST_USER_ADD_MEM_REG handler.
 *
 * A normal message carries one region description of at least
 * VHOST_USER_MEM_REG_SIZE bytes plus exactly one fd to map.  While postcopy
 * is listening, a message of this type with a bare u64 payload of 0 means
 * that all postcopy client bases have been received, and fault generation
 * starts.
 *
 * Returns true when a reply must be sent.  That happens in postcopy mode,
 * where the region is echoed back with userspace_addr replaced by our
 * mapping address.
 */
static bool
vu_add_mem_reg(VuDev *dev, VhostUserMsg *vmsg) {
    VhostUserMemoryRegion m = vmsg->payload.memreg.region, *msg_region = &m;

    /*
     * If we are in postcopy mode and we receive a u64 payload with a 0 value
     * we know all the postcopy client bases have been received, and we
     * should start generating faults.
     *
     * This must be checked before the generic validation below.  A u64
     * payload is smaller than a region description, so the size check
     * would otherwise reject it and this path could never run.
     * NOTE(review): the ack presumably carries no fd; any that arrive are
     * closed here.
     */
    if (dev->postcopy_listening &&
        vmsg->size == sizeof(vmsg->payload.u64) &&
        vmsg->payload.u64 == 0) {
        vmsg_close_fds(vmsg);
        (void)generate_faults(dev);
        return false;
    }

    if (vmsg->fd_num != 1) {
        vmsg_close_fds(vmsg);
        vu_panic(dev, "VHOST_USER_ADD_MEM_REG received %d fds - only 1 fd "
                      "should be sent for this message type", vmsg->fd_num);
        return false;
    }

    if (vmsg->size < VHOST_USER_MEM_REG_SIZE) {
        close(vmsg->fds[0]);
        vu_panic(dev, "VHOST_USER_ADD_MEM_REG requires a message size of at "
                      "least %zu bytes and only %d bytes were received",
                      VHOST_USER_MEM_REG_SIZE, vmsg->size);
        return false;
    }

    if (dev->nregions == VHOST_USER_MAX_RAM_SLOTS) {
        close(vmsg->fds[0]);
        vu_panic(dev, "failing attempt to hot add memory via "
                      "VHOST_USER_ADD_MEM_REG message because the backend has "
                      "no free ram slots available");
        return false;
    }

    _vu_add_mem_reg(dev, msg_region, vmsg->fds[0]);
    close(vmsg->fds[0]);

    if (dev->postcopy_listening) {
        /*
         * Send the message back to qemu with the addresses filled in.
         * _vu_add_mem_reg() updated the local copy m, so copy it back into
         * the message that is about to be sent as the reply.
         */
        vmsg->payload.memreg.region = m;
        vmsg->fd_num = 0;
        DPRINT("Successfully added new region in postcopy\n");
        return true;
    }
    DPRINT("Successfully added new region\n");

    return false;
}
/*
 * True if our region @vudev_reg has the same guest physical address,
 * frontend virtual address and size as the region described in @msg_reg.
 */
static inline bool reg_equal(VuDevRegion *vudev_reg,
                             VhostUserMemoryRegion *msg_reg)
{
    return vudev_reg->gpa == msg_reg->guest_phys_addr &&
           vudev_reg->qva == msg_reg->userspace_addr &&
           vudev_reg->size == msg_reg->memory_size;
}
/*
 * VHOST_USER_REM_MEM_REG handler.
 *
 * Validates the message (at most one fd, at least VHOST_USER_MEM_REG_SIZE
 * bytes) and looks up the region by GPA.  The region must match the
 * description exactly.  Rings inside it are then forgotten, the region is
 * unmapped and removed from the sorted array.  Any received fd is closed
 * on every path.  Returns false (no reply payload).
 */
static bool
vu_rem_mem_reg(VuDev *dev, VhostUserMsg *vmsg) {
    VhostUserMemoryRegion m = vmsg->payload.memreg.region, *msg_region = &m;
    VuDevRegion *victim;
    unsigned int pos, trailing;

    if (vmsg->fd_num > 1) {
        vmsg_close_fds(vmsg);
        vu_panic(dev, "VHOST_USER_REM_MEM_REG received %d fds - at most 1 fd "
                      "should be sent for this message type", vmsg->fd_num);
        return false;
    }
    if (vmsg->size < VHOST_USER_MEM_REG_SIZE) {
        vmsg_close_fds(vmsg);
        vu_panic(dev, "VHOST_USER_REM_MEM_REG requires a message size of at "
                      "least %zu bytes and only %d bytes were received",
                      VHOST_USER_MEM_REG_SIZE, vmsg->size);
        return false;
    }

    DPRINT("Removing region:\n");
    DPRINT(" guest_phys_addr: 0x%016"PRIx64"\n",
           msg_region->guest_phys_addr);
    DPRINT(" memory_size: 0x%016"PRIx64"\n",
           msg_region->memory_size);
    DPRINT(" userspace_addr 0x%016"PRIx64"\n",
           msg_region->userspace_addr);
    DPRINT(" mmap_offset 0x%016"PRIx64"\n",
           msg_region->mmap_offset);

    victim = vu_gpa_to_mem_region(dev, msg_region->guest_phys_addr);
    if (victim == NULL || !reg_equal(victim, msg_region)) {
        vmsg_close_fds(vmsg);
        vu_panic(dev, "Specified region not found\n");
        return false;
    }

    /*
     * Rings may legitimately live in a region that is removed temporarily
     * (to be re-added) or that is no longer used before new ring addresses
     * are set.  Drop the affected ring pointers; they are remapped on
     * demand later.
     */
    unmap_rings(dev, victim);
    munmap((void *)(uintptr_t)victim->mmap_addr,
           victim->size + victim->mmap_offset);

    pos = victim - dev->regions;
    assert(pos < dev->nregions);
    /* Close the hole by shifting every later entry down by one. */
    trailing = dev->nregions - pos - 1;
    memmove(victim, victim + 1, sizeof(VuDevRegion) * trailing);
    DPRINT("Successfully removed a region\n");
    dev->nregions--;

    vmsg_close_fds(vmsg);
    return false;
}
/*
 * VHOST_USER_GET_SHARED_OBJECT handler.  Asks iface->get_shared_object
 * for a dmabuf fd matching the UUID in the payload.  If one is returned
 * (!= -1) it is attached to the reply as the only fd; otherwise the reply
 * carries no fd.  Always returns true (reply).
 */
static bool
vu_get_shared_object(VuDev *dev, VhostUserMsg *vmsg)
{
    int dmabuf_fd = -1;

    if (dev->iface->get_shared_object) {
        dmabuf_fd = dev->iface->get_shared_object(
            dev, &vmsg->payload.object.uuid[0]);
    }

    if (dmabuf_fd == -1) {
        vmsg->fd_num = 0;
        return true;
    }

    DPRINT("dmabuf_fd found for requested UUID\n");
    vmsg->fds[0] = dmabuf_fd;
    vmsg->fd_num = 1;
    return true;
}
/*
 * VHOST_USER_SET_MEM_TABLE handler: replace all memory regions.
 *
 * Each described region is mapped from the received fd with the same
 * index, and every received fd is closed.  In postcopy mode the table is
 * sent back with each userspace_addr replaced by our mapping address.  We
 * then wait for the frontend's u64 0 ack and start fault generation.
 * Otherwise, rings that were mapped before are remapped against the new
 * table.  Returns false; the postcopy reply is sent directly.
 */
static bool
vu_set_mem_table_exec(VuDev *dev, VhostUserMsg *vmsg)
{
    VhostUserMemory m = vmsg->payload.memory, *memory = &m;
    unsigned int i;

    /*
     * nregions is peer-controlled and indexes both memory->regions[] and
     * vmsg->fds[].  Each region needs its own fd, so nregions may not
     * exceed fd_num.  It must also stay within VHOST_MEMORY_BASELINE_NREGIONS.
     * NOTE(review): assumes regions[] holds that many entries.  Validate
     * before touching the current table.
     */
    if (vmsg->fd_num < 0 ||
        memory->nregions > VHOST_MEMORY_BASELINE_NREGIONS ||
        memory->nregions > (unsigned int)vmsg->fd_num) {
        vmsg_close_fds(vmsg);
        vu_panic(dev, "Invalid set_mem_table message: %u regions, %d fds",
                 memory->nregions, vmsg->fd_num);
        return false;
    }

    vu_remove_all_mem_regs(dev);

    DPRINT("Nregions: %u\n", memory->nregions);
    for (i = 0; i < memory->nregions; i++) {
        _vu_add_mem_reg(dev, &memory->regions[i], vmsg->fds[i]);
        close(vmsg->fds[i]);
    }
    /* Descriptors beyond nregions are unused; don't leak them. */
    for (; i < (unsigned int)vmsg->fd_num; i++) {
        close(vmsg->fds[i]);
    }

    if (dev->postcopy_listening) {
        /*
         * Send the message back to qemu with the addresses filled in.
         * _vu_add_mem_reg() updated the local copy m, so copy it back into
         * the message before replying.
         */
        vmsg->payload.memory = m;
        vmsg->fd_num = 0;
        if (!vu_send_reply(dev, dev->sock, vmsg)) {
            vu_panic(dev, "failed to respond to set-mem-table for postcopy");
            return false;
        }

        /*
         * Wait for QEMU to confirm that it's registered the handler for the
         * faults.
         */
        if (!dev->read_msg(dev, dev->sock, vmsg) ||
            vmsg->size != sizeof(vmsg->payload.u64) ||
            vmsg->payload.u64 != 0) {
            vu_panic(dev, "failed to receive valid ack for postcopy set-mem-table");
            return false;
        }

        /* OK, now we can go and register the memory and generate faults */
        (void)generate_faults(dev);
        return false;
    }

    for (i = 0; i < dev->max_queues; i++) {
        if (dev->vq[i].vring.desc) {
            if (map_ring(dev, &dev->vq[i])) {
                vu_panic(dev, "remapping queue %d during setmemtable", i);
            }
        }
    }

    return false;
}
/*
 * VHOST_USER_SET_LOG_BASE handler: map the dirty-log bitmap shared by the
 * frontend (one fd plus mmap_size/mmap_offset), replacing any previous
 * one.
 *
 * The received fd is always closed.  If the mmap() fails, no table is
 * stored: log_table becomes NULL and the device is panicked.  Storing
 * MAP_FAILED would have vu_log_write() and vu_close_log(), which only test
 * log_table for NULL, write to and unmap (void *)-1.  Always returns true,
 * so a u64 reply is sent.
 */
static bool
vu_set_log_base_exec(VuDev *dev, VhostUserMsg *vmsg)
{
    int fd;
    uint64_t log_mmap_size, log_mmap_offset;
    void *rc;

    if (vmsg->fd_num != 1 ||
        vmsg->size != sizeof(vmsg->payload.log)) {
        /*
         * Don't leak the received descriptors.  Don't echo them back in
         * the reply either, which the true return value still requests.
         */
        vmsg_close_fds(vmsg);
        vmsg->fd_num = 0;
        vu_panic(dev, "Invalid log_base message");
        return true;
    }

    fd = vmsg->fds[0];
    log_mmap_offset = vmsg->payload.log.mmap_offset;
    log_mmap_size = vmsg->payload.log.mmap_size;
    DPRINT("Log mmap_offset: %"PRIu64"\n", log_mmap_offset);
    DPRINT("Log mmap_size: %"PRIu64"\n", log_mmap_size);

    rc = mmap(0, log_mmap_size, PROT_READ | PROT_WRITE, MAP_SHARED, fd,
              log_mmap_offset);
    close(fd);
    if (rc == MAP_FAILED) {
        int err = errno;

        perror("log mmap error");
        rc = NULL;
        log_mmap_size = 0;
        vu_panic(dev, "Failed to mmap dirty log: %s", strerror(err));
    }

    if (dev->log_table) {
        munmap(dev->log_table, dev->log_size);
    }
    dev->log_table = rc;
    dev->log_size = log_mmap_size;

    vmsg->size = sizeof(vmsg->payload.u64);
    vmsg->fd_num = 0;

    return true;
}
/*
 * VHOST_USER_SET_LOG_FD handler: adopt the received fd as the dirty-log
 * eventfd (used by vu_log_kick()), closing any previous one.
 *
 * A message that does not carry exactly one fd panics the device, and
 * whatever fds it did carry are closed rather than leaked.  Returns false
 * (no reply payload).
 */
static bool
vu_set_log_fd_exec(VuDev *dev, VhostUserMsg *vmsg)
{
    if (vmsg->fd_num != 1) {
        vmsg_close_fds(vmsg);
        vu_panic(dev, "Invalid log_fd message");
        return false;
    }

    if (dev->log_call_fd != -1) {
        close(dev->log_call_fd);
    }
    dev->log_call_fd = vmsg->fds[0];
    DPRINT("Got log_call_fd: %d\n", vmsg->fds[0]);

    return false;
}
/*
 * VHOST_USER_SET_VRING_NUM handler: record the ring size of a queue.
 *
 * The queue index comes from the peer and is validated against
 * dev->max_queues before dev->vq[] is indexed.  An out-of-range index
 * panics the device.  Returns false (no reply payload).
 */
static bool
vu_set_vring_num_exec(VuDev *dev, VhostUserMsg *vmsg)
{
    unsigned int index = vmsg->payload.state.index;
    unsigned int num = vmsg->payload.state.num;

    DPRINT("State.index: %u\n", index);
    DPRINT("State.num: %u\n", num);

    if (index >= dev->max_queues) {
        vu_panic(dev, "Invalid queue index: %u", index);
        return false;
    }

    dev->vq[index].vring.num = num;

    return false;
}
/*
 * VHOST_USER_SET_VRING_ADDR handler: store a queue's ring addresses and
 * map them.
 *
 * The peer-supplied queue index is validated before dev->vq[] is indexed.
 * After a successful mapping, used_idx is read from the used ring.  If it
 * differs from last_avail_idx and the device reports that the queue is
 * processed in order, the avail indices resume from used_idx.  Returns
 * false (no reply payload).
 */
static bool
vu_set_vring_addr_exec(VuDev *dev, VhostUserMsg *vmsg)
{
    struct vhost_vring_addr addr = vmsg->payload.addr, *vra = &addr;
    unsigned int index = vra->index;
    VuVirtq *vq;

    DPRINT("vhost_vring_addr:\n");
    DPRINT(" index: %d\n", vra->index);
    DPRINT(" flags: %d\n", vra->flags);
    DPRINT(" desc_user_addr: 0x%016" PRIx64 "\n", (uint64_t)vra->desc_user_addr);
    DPRINT(" used_user_addr: 0x%016" PRIx64 "\n", (uint64_t)vra->used_user_addr);
    DPRINT(" avail_user_addr: 0x%016" PRIx64 "\n", (uint64_t)vra->avail_user_addr);
    DPRINT(" log_guest_addr: 0x%016" PRIx64 "\n", (uint64_t)vra->log_guest_addr);

    if (index >= dev->max_queues) {
        vu_panic(dev, "Invalid queue index: %u", index);
        return false;
    }
    vq = &dev->vq[index];

    vq->vra = *vra;
    vq->vring.flags = vra->flags;
    vq->vring.log_guest_addr = vra->log_guest_addr;

    if (map_ring(dev, vq)) {
        vu_panic(dev, "Invalid vring_addr message");
        return false;
    }

    vq->used_idx = le16toh(vq->vring.used->idx);

    if (vq->last_avail_idx != vq->used_idx) {
        bool resume = dev->iface->queue_is_processed_in_order &&
            dev->iface->queue_is_processed_in_order(dev, index);

        DPRINT("Last avail index != used index: %u != %u%s\n",
               vq->last_avail_idx, vq->used_idx,
               resume ? ", resuming" : "");

        if (resume) {
            vq->shadow_avail_idx = vq->last_avail_idx = vq->used_idx;
        }
    }

    return false;
}
/*
 * VHOST_USER_SET_VRING_BASE handler: set a queue's last (and shadow)
 * available index.  The peer-supplied queue index is validated before
 * dev->vq[] is indexed.  Returns false (no reply payload).
 */
static bool
vu_set_vring_base_exec(VuDev *dev, VhostUserMsg *vmsg)
{
    unsigned int index = vmsg->payload.state.index;
    unsigned int num = vmsg->payload.state.num;

    DPRINT("State.index: %u\n", index);
    DPRINT("State.num: %u\n", num);

    if (index >= dev->max_queues) {
        vu_panic(dev, "Invalid queue index: %u", index);
        return false;
    }

    dev->vq[index].shadow_avail_idx = dev->vq[index].last_avail_idx = num;

    return false;
}
/*
 * VHOST_USER_GET_VRING_BASE handler: stop a queue and report its
 * last_avail_idx.
 *
 * The queue is marked not started and the device is told via
 * iface->queue_set_started.  Its call fd is closed, and its kick fd is
 * unwatched and closed.  Returns true (the state is sent back).  An
 * out-of-range peer-supplied index panics the device and returns false
 * before dev->vq[] is touched.
 */
static bool
vu_get_vring_base_exec(VuDev *dev, VhostUserMsg *vmsg)
{
    unsigned int index = vmsg->payload.state.index;

    DPRINT("State.index: %u\n", index);

    if (index >= dev->max_queues) {
        vu_panic(dev, "Invalid queue index: %u", index);
        return false;
    }

    vmsg->payload.state.num = dev->vq[index].last_avail_idx;
    vmsg->size = sizeof(vmsg->payload.state);

    dev->vq[index].started = false;
    if (dev->iface->queue_set_started) {
        dev->iface->queue_set_started(dev, index, false);
    }

    if (dev->vq[index].call_fd != -1) {
        close(dev->vq[index].call_fd);
        dev->vq[index].call_fd = -1;
    }
    if (dev->vq[index].kick_fd != -1) {
        dev->remove_watch(dev, dev->vq[index].kick_fd);
        close(dev->vq[index].kick_fd);
        dev->vq[index].kick_fd = -1;
    }

    return true;
}
/*
 * Validate a per-queue fd message (kick/call/err style).  The u64 payload
 * holds the queue index in its low bits plus a "no fd" flag.
 *
 * Returns false after closing the fds and panicking if the index is out of
 * range, or if an fd is expected but not exactly one was received.  When
 * the "no fd" flag is set, any received fds are closed and true is
 * returned.
 */
static bool
vu_check_queue_msg_file(VuDev *dev, VhostUserMsg *vmsg)
{
    const int index = vmsg->payload.u64 & VHOST_USER_VRING_IDX_MASK;
    const bool nofd = vmsg->payload.u64 & VHOST_USER_VRING_NOFD_MASK;

    if (index >= dev->max_queues) {
        vmsg_close_fds(vmsg);
        vu_panic(dev, "Invalid queue index: %u", index);
        return false;
    }

    if (!nofd && vmsg->fd_num != 1) {
        vmsg_close_fds(vmsg);
        vu_panic(dev, "Invalid fds in request: %d", vmsg->request);
        return false;
    }

    if (nofd) {
        vmsg_close_fds(vmsg);
    }
    return true;
}
/*
 * qsort() comparator for VuVirtqInflightDesc that puts larger counters
 * first.
 *
 * Returns 0 for equal counters.  Returns 1 (a sorts after b) when b's
 * counter exceeds a's by less than VIRTQUEUE_MAX_SIZE * 2, and -1
 * otherwise.  Equal counters must compare as 0: returning -1 for both
 * argument orders would give qsort() an inconsistent ordering, which the C
 * standard does not permit.
 */
static int
inflight_desc_compare(const void *a, const void *b)
{
    const VuVirtqInflightDesc *desc0 = a;
    const VuVirtqInflightDesc *desc1 = b;

    if (desc0->counter == desc1->counter) {
        return 0;
    }

    if (desc1->counter > desc0->counter &&
        (desc1->counter - desc0->counter) < VIRTQUEUE_MAX_SIZE * 2) {
        return 1;
    }

    return -1;
}
/*
 * Rebuild a queue's inflight bookkeeping from the shared inflight area
 * after a (re)connect.  Only active when INFLIGHT_SHMFD was negotiated.
 *
 * - A fresh area (version 0) is just stamped with INFLIGHT_VERSION.
 * - Otherwise it resyncs used_idx, counts descriptors still marked
 *   inflight, and rewinds last_avail_idx accordingly.
 * - It builds resubmit_list sorted newest-first (inflight_desc_compare) and
 *   finally kicks the queue so pending work is picked up.
 *
 * Returns 0 on success, -1 on failure (no inflight area, allocation
 * failure, or kick failure).
 */
static int
vu_check_queue_inflights(VuDev *dev, VuVirtq *vq)
{
    int i = 0;

    if (!vu_has_protocol_feature(dev, VHOST_USER_PROTOCOL_F_INFLIGHT_SHMFD)) {
        return 0;
    }

    if (unlikely(!vq->inflight)) {
        return -1;
    }

    if (unlikely(!vq->inflight->version)) {
        /* initialize the buffer */
        vq->inflight->version = INFLIGHT_VERSION;
        return 0;
    }

    vq->used_idx = le16toh(vq->vring.used->idx);
    vq->resubmit_num = 0;
    vq->resubmit_list = NULL;
    vq->counter = 0;

    if (unlikely(vq->inflight->used_idx != vq->used_idx)) {
        vq->inflight->desc[vq->inflight->last_batch_head].inflight = 0;

        barrier();

        vq->inflight->used_idx = vq->used_idx;
    }

    for (i = 0; i < vq->inflight->desc_num; i++) {
        if (vq->inflight->desc[i].inflight == 1) {
            vq->inuse++;
        }
    }

    vq->shadow_avail_idx = vq->last_avail_idx = vq->inuse + vq->used_idx;

    if (vq->inuse) {
        vq->resubmit_list = calloc(vq->inuse, sizeof(VuVirtqInflightDesc));
        if (!vq->resubmit_list) {
            return -1;
        }

        for (i = 0; i < vq->inflight->desc_num; i++) {
            /*
             * Use the same predicate as the counting loop above and never
             * exceed the allocation.  The inflight area is shared memory,
             * so its contents are not trusted to stay identical between
             * the two passes.
             */
            if (vq->inflight->desc[i].inflight == 1 &&
                vq->resubmit_num < vq->inuse) {
                vq->resubmit_list[vq->resubmit_num].index = i;
                vq->resubmit_list[vq->resubmit_num].counter =
                                        vq->inflight->desc[i].counter;
                vq->resubmit_num++;
            }
        }

        if (vq->resubmit_num > 1) {
            qsort(vq->resubmit_list, vq->resubmit_num,
                  sizeof(VuVirtqInflightDesc), inflight_desc_compare);
        }
        vq->counter = vq->resubmit_list[0].counter + 1;
    }

    /*
     * In case of I/O hang after reconnecting.  There is nothing to kick
     * when the frontend installed no kick fd (NOFD).
     */
    if (vq->kick_fd != -1 && eventfd_write(vq->kick_fd, 1)) {
        return -1;
    }

    return 0;
}
/*
 * VHOST_USER_SET_VRING_KICK: install (or clear, with NOFD) a queue's kick fd.
 *
 * Any previous kick fd is unwatched and closed.  The queue is then marked
 * started (notifying iface->queue_set_started when present).  If a fd and a
 * handler are both present, the fd is watched with vu_kick_cb.  Finally the
 * inflight state is rechecked; a failure there panics the device.
 * Never sends a reply.
 */
static bool
vu_set_vring_kick_exec(VuDev *dev, VhostUserMsg *vmsg)
{
    const uint64_t u64 = vmsg->payload.u64;
    int index = u64 & VHOST_USER_VRING_IDX_MASK;
    bool nofd = u64 & VHOST_USER_VRING_NOFD_MASK;
    VuVirtq *vq;

    DPRINT("u64: 0x%016"PRIx64"\n", u64);

    if (!vu_check_queue_msg_file(dev, vmsg)) {
        return false;
    }
    /* index is known to be in range only after the check above. */
    vq = &dev->vq[index];

    if (vq->kick_fd != -1) {
        dev->remove_watch(dev, vq->kick_fd);
        close(vq->kick_fd);
        vq->kick_fd = -1;
    }

    vq->kick_fd = nofd ? -1 : vmsg->fds[0];
    DPRINT("Got kick_fd: %d for vq: %d\n", vq->kick_fd, index);

    vq->started = true;
    if (dev->iface->queue_set_started) {
        dev->iface->queue_set_started(dev, index, true);
    }

    if (vq->kick_fd != -1 && vq->handler) {
        dev->set_watch(dev, vq->kick_fd, VU_WATCH_IN,
                       vu_kick_cb, (void *)(long)index);

        DPRINT("Waiting for kicks on fd: %d for vq: %d\n",
               vq->kick_fd, index);
    }

    if (vu_check_queue_inflights(dev, vq)) {
        vu_panic(dev, "Failed to check inflights for vq: %d\n", index);
    }

    return false;
}
/*
 * Set (or clear, with NULL) the callback run when a queue is kicked.
 *
 * If the queue already has a kick fd, the fd watch is updated to match.
 * It is watched with vu_kick_cb (queue index as opaque data) when a handler
 * is set, and unwatched when the handler is cleared.
 */
void vu_set_queue_handler(VuDev *dev, VuVirtq *vq,
                          vu_queue_handler_cb handler)
{
    int qidx = vq - dev->vq;

    vq->handler = handler;

    if (vq->kick_fd < 0) {
        /* No kick fd yet: vu_set_vring_kick_exec() will add the watch. */
        return;
    }

    if (!handler) {
        dev->remove_watch(dev, vq->kick_fd);
        return;
    }

    dev->set_watch(dev, vq->kick_fd, VU_WATCH_IN,
                   vu_kick_cb, (void *)(long)qidx);
}
/*
 * Ask the frontend to map (fd != -1) or unmap (fd == -1) a host notifier
 * area of @size bytes at @offset for queue @vq.
 *
 * Requires VHOST_USER_PROTOCOL_F_BACKEND_SEND_FD.  The request is sent on
 * the backend channel under backend_mutex and waits for the frontend's
 * ack.  Returns false if the feature is missing or the write fails,
 * otherwise the result of vu_process_message_reply().
 */
bool vu_set_queue_host_notifier(VuDev *dev, VuVirtq *vq, int fd,
                                int size, int offset)
{
    int qidx = vq - dev->vq;
    VhostUserMsg vmsg = {
        .request = VHOST_USER_BACKEND_VRING_HOST_NOTIFIER_MSG,
        .flags = VHOST_USER_VERSION | VHOST_USER_NEED_REPLY_MASK,
        .size = sizeof(vmsg.payload.area),
        .payload.area = {
            .u64 = qidx & VHOST_USER_VRING_IDX_MASK,
            .size = size,
            .offset = offset,
        },
    };

    if (!vu_has_protocol_feature(dev, VHOST_USER_PROTOCOL_F_BACKEND_SEND_FD)) {
        return false;
    }

    if (fd != -1) {
        vmsg.fds[0] = fd;
        vmsg.fd_num = 1;
    } else {
        vmsg.payload.area.u64 |= VHOST_USER_VRING_NOFD_MASK;
        vmsg.fd_num = 0;
    }

    pthread_mutex_lock(&dev->backend_mutex);
    if (vu_message_write(dev, dev->backend_fd, &vmsg)) {
        /* Also unlocks the backend_mutex */
        return vu_process_message_reply(dev, &vmsg);
    }

    pthread_mutex_unlock(&dev->backend_mutex);
    return false;
}
/*
 * Ask the frontend for the dma-buf fd backing the shared object @uuid.
 *
 * Requires VHOST_USER_PROTOCOL_F_SHARED_OBJECT.  The request/reply exchange
 * runs under backend_mutex.  On a well-formed reply (matching request type,
 * exactly one fd) the fd is stored in *dmabuf_fd.  The call returns true
 * only if that fd is valid and the reply payload is 0.  Descriptors
 * attached to a malformed reply are closed rather than leaked.
 */
bool
vu_lookup_shared_object(VuDev *dev, unsigned char uuid[UUID_LEN],
                        int *dmabuf_fd)
{
    bool result = false;
    VhostUserMsg msg_reply;
    VhostUserMsg msg = {
        .request = VHOST_USER_BACKEND_SHARED_OBJECT_LOOKUP,
        .size = sizeof(msg.payload.object),
        .flags = VHOST_USER_VERSION | VHOST_USER_NEED_REPLY_MASK,
    };

    memcpy(msg.payload.object.uuid, uuid, sizeof(uuid[0]) * UUID_LEN);

    if (!vu_has_protocol_feature(dev, VHOST_USER_PROTOCOL_F_SHARED_OBJECT)) {
        return false;
    }

    pthread_mutex_lock(&dev->backend_mutex);
    if (!vu_message_write(dev, dev->backend_fd, &msg)) {
        goto out;
    }

    if (!vu_message_read_default(dev, dev->backend_fd, &msg_reply)) {
        goto out;
    }

    if (msg_reply.request != msg.request) {
        DPRINT("Received unexpected msg type. Expected %d, received %d",
               msg.request, msg_reply.request);
        vmsg_close_fds(&msg_reply);
        goto out;
    }

    if (msg_reply.fd_num != 1) {
        DPRINT("Received unexpected number of fds. Expected 1, received %d",
               msg_reply.fd_num);
        vmsg_close_fds(&msg_reply);
        goto out;
    }

    *dmabuf_fd = msg_reply.fds[0];
    /* 0 is a valid descriptor number; only negative values are invalid. */
    result = *dmabuf_fd >= 0 && msg_reply.payload.u64 == 0;
out:
    pthread_mutex_unlock(&dev->backend_mutex);

    return result;
}
/*
 * Write @vmsg to the backend channel while holding backend_mutex.
 * Does not wait for any reply.  Returns whether the write succeeded.
 */
static bool
vu_send_message(VuDev *dev, VhostUserMsg *vmsg)
{
    bool sent;

    pthread_mutex_lock(&dev->backend_mutex);
    sent = vu_message_write(dev, dev->backend_fd, vmsg);
    pthread_mutex_unlock(&dev->backend_mutex);

    return sent;
}
/*
 * Tell the frontend that this backend exports the shared object @uuid.
 * Requires VHOST_USER_PROTOCOL_F_SHARED_OBJECT; no reply is requested.
 * Returns false if the feature is missing or the send fails.
 */
bool
vu_add_shared_object(VuDev *dev, unsigned char uuid[UUID_LEN])
{
    VhostUserMsg msg = {
        .request = VHOST_USER_BACKEND_SHARED_OBJECT_ADD,
        .size = sizeof(msg.payload.object),
        .flags = VHOST_USER_VERSION,
    };

    if (!vu_has_protocol_feature(dev, VHOST_USER_PROTOCOL_F_SHARED_OBJECT)) {
        return false;
    }

    memcpy(msg.payload.object.uuid, uuid, UUID_LEN * sizeof(*uuid));
    return vu_send_message(dev, &msg);
}
/*
 * Tell the frontend that the shared object @uuid is no longer exported.
 * Requires VHOST_USER_PROTOCOL_F_SHARED_OBJECT; no reply is requested.
 * Returns false if the feature is missing or the send fails.
 */
bool
vu_rm_shared_object(VuDev *dev, unsigned char uuid[UUID_LEN])
{
    VhostUserMsg msg = {
        .request = VHOST_USER_BACKEND_SHARED_OBJECT_REMOVE,
        .size = sizeof(msg.payload.object),
        .flags = VHOST_USER_VERSION,
    };

    if (!vu_has_protocol_feature(dev, VHOST_USER_PROTOCOL_F_SHARED_OBJECT)) {
        return false;
    }

    memcpy(msg.payload.object.uuid, uuid, UUID_LEN * sizeof(*uuid));
    return vu_send_message(dev, &msg);
}
/*
 * VHOST_USER_SET_VRING_CALL: install (or clear, with NOFD) a queue's call fd.
 *
 * Any previous call fd is closed.  A new fd is signalled once right away,
 * in case the driver was left waiting across a reconnect.  That signal is
 * best effort.  Handlers return "send a reply?", so the old `return -1` on
 * failure (which converts to true) made vu_dispatch() echo the request back
 * as a bogus reply.  A failure is now only logged.  Never sends a reply.
 */
static bool
vu_set_vring_call_exec(VuDev *dev, VhostUserMsg *vmsg)
{
    int index = vmsg->payload.u64 & VHOST_USER_VRING_IDX_MASK;
    bool nofd = vmsg->payload.u64 & VHOST_USER_VRING_NOFD_MASK;

    DPRINT("u64: 0x%016"PRIx64"\n", vmsg->payload.u64);

    if (!vu_check_queue_msg_file(dev, vmsg)) {
        return false;
    }

    if (dev->vq[index].call_fd != -1) {
        close(dev->vq[index].call_fd);
        dev->vq[index].call_fd = -1;
    }

    dev->vq[index].call_fd = nofd ? -1 : vmsg->fds[0];

    /* in case of I/O hang after reconnecting */
    if (dev->vq[index].call_fd != -1 &&
        eventfd_write(dev->vq[index].call_fd, 1)) {
        DPRINT("Failed to signal call_fd: %d for vq: %d: %s\n",
               dev->vq[index].call_fd, index, strerror(errno));
    }

    DPRINT("Got call_fd: %d for vq: %d\n", dev->vq[index].call_fd, index);

    return false;
}
/*
 * VHOST_USER_SET_VRING_ERR: install (or clear, with NOFD) a queue's error
 * fd.  Any previous error fd is closed first.  Never sends a reply.
 */
static bool
vu_set_vring_err_exec(VuDev *dev, VhostUserMsg *vmsg)
{
    const uint64_t u64 = vmsg->payload.u64;
    int index = u64 & VHOST_USER_VRING_IDX_MASK;
    bool nofd = u64 & VHOST_USER_VRING_NOFD_MASK;
    VuVirtq *vq;

    DPRINT("u64: 0x%016"PRIx64"\n", u64);

    if (!vu_check_queue_msg_file(dev, vmsg)) {
        return false;
    }

    vq = &dev->vq[index];
    if (vq->err_fd != -1) {
        close(vq->err_fd);
    }
    vq->err_fd = nofd ? -1 : vmsg->fds[0];

    return false;
}
/*
 * VHOST_USER_GET_PROTOCOL_FEATURES: reply with the protocol feature mask.
 *
 * A fixed base set is always offered.  PAGEFAULT is added when userfaultfd
 * is available, and CONFIG when the device implements both get_config and
 * set_config.  Whatever iface->get_protocol_features returns is OR-ed in.
 */
static bool
vu_get_protocol_features_exec(VuDev *dev, VhostUserMsg *vmsg)
{
    /*
     * Note that we support, but intentionally do not set,
     * VHOST_USER_PROTOCOL_F_INBAND_NOTIFICATIONS. This means that
     * a device implementation can return it in its callback
     * (get_protocol_features) if it wants to use this for
     * simulation, but it is otherwise not desirable (if even
     * implemented by the frontend.)
     */
    uint64_t features = 0;

    features |= 1ULL << VHOST_USER_PROTOCOL_F_MQ;
    features |= 1ULL << VHOST_USER_PROTOCOL_F_LOG_SHMFD;
    features |= 1ULL << VHOST_USER_PROTOCOL_F_BACKEND_REQ;
    features |= 1ULL << VHOST_USER_PROTOCOL_F_HOST_NOTIFIER;
    features |= 1ULL << VHOST_USER_PROTOCOL_F_BACKEND_SEND_FD;
    features |= 1ULL << VHOST_USER_PROTOCOL_F_REPLY_ACK;
    features |= 1ULL << VHOST_USER_PROTOCOL_F_CONFIGURE_MEM_SLOTS;

    if (have_userfault()) {
        features |= 1ULL << VHOST_USER_PROTOCOL_F_PAGEFAULT;
    }

    if (dev->iface->get_config && dev->iface->set_config) {
        features |= 1ULL << VHOST_USER_PROTOCOL_F_CONFIG;
    }

    if (dev->iface->get_protocol_features) {
        features |= dev->iface->get_protocol_features(dev);
    }

    vmsg_set_reply_u64(vmsg, features);
    return true;
}
/*
 * VHOST_USER_SET_PROTOCOL_FEATURES: record the negotiated protocol features.
 *
 * The mask is stored first.  INBAND_NOTIFICATIONS without both BACKEND_REQ
 * and REPLY_ACK is rejected with vu_panic().  Otherwise the device's
 * set_protocol_features callback, if any, is informed.  No reply.
 */
static bool
vu_set_protocol_features_exec(VuDev *dev, VhostUserMsg *vmsg)
{
    const uint64_t features = vmsg->payload.u64;
    bool inband, have_backend_req, have_reply_ack;

    DPRINT("u64: 0x%016"PRIx64"\n", features);

    dev->protocol_features = features;

    inband = vu_has_protocol_feature(dev,
                                     VHOST_USER_PROTOCOL_F_INBAND_NOTIFICATIONS);
    have_backend_req = vu_has_protocol_feature(dev,
                                               VHOST_USER_PROTOCOL_F_BACKEND_REQ);
    have_reply_ack = vu_has_protocol_feature(dev,
                                             VHOST_USER_PROTOCOL_F_REPLY_ACK);

    if (inband && !(have_backend_req && have_reply_ack)) {
        /*
         * The use case for using messages for kick/call is simulation, to make
         * the kick and call synchronous. To actually get that behaviour, both
         * of the other features are required.
         * Theoretically, one could use only kick messages, or do them without
         * having F_REPLY_ACK, but too many (possibly pending) messages on the
         * socket will eventually cause the frontend to hang, to avoid this in
         * scenarios where not desired enforce that the settings are in a way
         * that actually enables the simulation case.
         */
        vu_panic(dev,
                 "F_IN_BAND_NOTIFICATIONS requires F_BACKEND_REQ && F_REPLY_ACK");
        return false;
    }

    if (dev->iface->set_protocol_features) {
        dev->iface->set_protocol_features(dev, features);
    }

    return false;
}
/*
 * VHOST_USER_GET_QUEUE_NUM: reply with the number of queues the device was
 * initialised with (dev->max_queues).
 */
static bool
vu_get_queue_num_exec(VuDev *dev, VhostUserMsg *vmsg)
{
    uint64_t num_queues = dev->max_queues;

    vmsg_set_reply_u64(vmsg, num_queues);
    return true; /* send the reply */
}
/*
 * VHOST_USER_SET_VRING_ENABLE: set a queue's enable flag from
 * payload.state.num.  An out-of-range index panics the device.  No reply.
 */
static bool
vu_set_vring_enable_exec(VuDev *dev, VhostUserMsg *vmsg)
{
    const unsigned int index = vmsg->payload.state.index;
    const unsigned int enable = vmsg->payload.state.num;

    DPRINT("State.index: %u\n", index);
    DPRINT("State.enable: %u\n", enable);

    if (index < dev->max_queues) {
        dev->vq[index].enable = enable;
    } else {
        vu_panic(dev, "Invalid vring_enable index: %u", index);
    }

    return false;
}
/*
 * VHOST_USER_SET_BACKEND_REQ_FD: install the fd of the backend-initiated
 * request channel, closing any previous one.
 *
 * Exactly one fd must accompany the message.  Otherwise the device panics,
 * and whatever fds did arrive are closed instead of leaked (as
 * vu_check_queue_msg_file() does for the per-queue messages).  No reply.
 */
static bool
vu_set_backend_req_fd(VuDev *dev, VhostUserMsg *vmsg)
{
    if (vmsg->fd_num != 1) {
        vu_panic(dev, "Invalid backend_req_fd message (%d fd's)", vmsg->fd_num);
        vmsg_close_fds(vmsg);
        return false;
    }

    if (dev->backend_fd != -1) {
        close(dev->backend_fd);
    }
    dev->backend_fd = vmsg->fds[0];
    DPRINT("Got backend_fd: %d\n", vmsg->fds[0]);

    return false;
}
/*
 * VHOST_USER_GET_CONFIG: fill payload.config.region via iface->get_config.
 *
 * payload.config.size comes from the frontend.  A size larger than the
 * fixed region buffer is refused rather than handed to the device
 * callback.  On any failure (no callback, bad size, callback error) the
 * reply size is set to 0, which tells the frontend it failed.  Always
 * replies.
 */
static bool
vu_get_config(VuDev *dev, VhostUserMsg *vmsg)
{
    int ret = -1;

    if (dev->iface->get_config &&
        vmsg->payload.config.size <= sizeof(vmsg->payload.config.region)) {
        ret = dev->iface->get_config(dev, vmsg->payload.config.region,
                                     vmsg->payload.config.size);
    }

    if (ret) {
        /* resize to zero to indicate an error to frontend */
        vmsg->size = 0;
    }

    return true;
}
/*
 * VHOST_USER_SET_CONFIG: pass the frontend's config write to
 * iface->set_config when the device implements it.
 *
 * payload.config.size comes from the frontend.  A size larger than the
 * fixed region buffer panics the device instead of letting the callback
 * read past it.  A callback failure also panics.  No reply.
 */
static bool
vu_set_config(VuDev *dev, VhostUserMsg *vmsg)
{
    int ret = -1;

    if (dev->iface->set_config) {
        if (vmsg->payload.config.size > sizeof(vmsg->payload.config.region)) {
            vu_panic(dev, "Invalid config size: %u",
                     (unsigned int)vmsg->payload.config.size);
            return false;
        }

        ret = dev->iface->set_config(dev, vmsg->payload.config.region,
                                     vmsg->payload.config.offset,
                                     vmsg->payload.config.size,
                                     vmsg->payload.config.flags);
        if (ret) {
            vu_panic(dev, "Set virtio configuration space failed");
        }
    }

    return false;
}
/*
 * VHOST_USER_POSTCOPY_ADVISE: create a userfaultfd and hand it to the
 * frontend in the reply.
 *
 * The reply never carries a payload.  If the userfaultfd cannot be created
 * or its UFFDIO_API handshake fails, the device panics and the reply
 * carries no fd at all.  Attaching -1 as SCM_RIGHTS data makes sendmsg()
 * fail, so the frontend would never receive the reply; with fd_num == 0 it
 * gets a reply and can report the missing fd.
 */
static bool
vu_set_postcopy_advise(VuDev *dev, VhostUserMsg *vmsg)
{
#ifdef UFFDIO_API
    struct uffdio_api api_struct;

    dev->postcopy_ufd = syscall(__NR_userfaultfd, O_CLOEXEC | O_NONBLOCK);
#else
    dev->postcopy_ufd = -1;
#endif
    vmsg->size = 0;

    if (dev->postcopy_ufd == -1) {
        vu_panic(dev, "Userfaultfd not available: %s", strerror(errno));
        goto out;
    }

#ifdef UFFDIO_API
    api_struct.api = UFFD_API;
    api_struct.features = 0;
    if (ioctl(dev->postcopy_ufd, UFFDIO_API, &api_struct)) {
        vu_panic(dev, "Failed UFFDIO_API: %s", strerror(errno));
        close(dev->postcopy_ufd);
        dev->postcopy_ufd = -1;
        goto out;
    }
    /* TODO: Stash feature flags somewhere */
#endif

out:
    /* Return a ufd to the QEMU (only if we actually have one) */
    if (dev->postcopy_ufd >= 0) {
        vmsg->fd_num = 1;
        vmsg->fds[0] = dev->postcopy_ufd;
    } else {
        vmsg->fd_num = 0;
    }
    return true; /* = send a reply */
}
/*
 * VHOST_USER_POSTCOPY_LISTEN: enter postcopy listening mode.
 *
 * This is only allowed while no memory regions are registered.  In that
 * case the reply is u64 0; otherwise the device panics and the reply is
 * u64 -1.
 */
static bool
vu_set_postcopy_listen(VuDev *dev, VhostUserMsg *vmsg)
{
    uint64_t status = 0;

    if (dev->nregions) {
        vu_panic(dev, "Regions already registered at postcopy-listen");
        status = -1;
    } else {
        dev->postcopy_listening = true;
    }

    vmsg_set_reply_u64(vmsg, status);
    return true;
}
/*
 * VHOST_USER_POSTCOPY_END: leave postcopy mode, close the userfaultfd if
 * one is open, and reply with u64 0.
 */
static bool
vu_set_postcopy_end(VuDev *dev, VhostUserMsg *vmsg)
{
    int ufd = dev->postcopy_ufd;

    DPRINT("%s: Entry\n", __func__);
    dev->postcopy_listening = false;

    if (ufd > 0) {
        dev->postcopy_ufd = -1;
        close(ufd);
        DPRINT("%s: Done close\n", __func__);
    }

    vmsg_set_reply_u64(vmsg, 0);
    DPRINT("%s: exit\n", __func__);

    return true;
}
/*
 * Bytes of inflight area used by one queue of @queue_size descriptors:
 * one VuDescStateSplit per descriptor plus a uint16_t, rounded up to
 * INFLIGHT_ALIGNMENT.
 */
static inline uint64_t
vu_inflight_queue_size(uint16_t queue_size)
{
    size_t raw = sizeof(VuDescStateSplit) * queue_size + sizeof(uint16_t);

    return ALIGN_UP(raw, INFLIGHT_ALIGNMENT);
}
#ifdef MFD_ALLOW_SEALING
/*
 * Create a sealable memfd named @name of @size bytes, apply the F_SEAL_*
 * @flags, and map it shared read/write.
 *
 * On success returns the mapping, with the fd left in *fd.  On any failure
 * the fd (if created) is closed and NULL is returned.
 */
static void *
memfd_alloc(const char *name, size_t size, unsigned int flags, int *fd)
{
    void *ptr;

    *fd = memfd_create(name, MFD_ALLOW_SEALING);
    if (*fd < 0) {
        return NULL;
    }

    if (ftruncate(*fd, size) < 0 ||
        fcntl(*fd, F_ADD_SEALS, flags) < 0) {
        goto err_close;
    }

    ptr = mmap(0, size, PROT_READ | PROT_WRITE, MAP_SHARED, *fd, 0);
    if (ptr == MAP_FAILED) {
        goto err_close;
    }

    return ptr;

err_close:
    close(*fd);
    return NULL;
}
#endif
/*
 * VHOST_USER_GET_INFLIGHT_FD: allocate a zeroed, sealed memfd big enough
 * for num_queues inflight areas of queue_size descriptors each.  It is
 * returned to the frontend in the reply (fd + mmap_size, offset 0) and
 * remembered in dev->inflight_info.
 *
 * On a malformed request or allocation failure the device panics and the
 * reply carries mmap_size 0.  Always replies.
 *
 * NOTE(review): a previously allocated inflight area is overwritten in
 * dev->inflight_info without being released.
 */
static bool
vu_get_inflight_fd(VuDev *dev, VhostUserMsg *vmsg)
{
    int fd = -1;
    void *addr = NULL;
    uint64_t mmap_size;
    uint16_t num_queues, queue_size;

    if (vmsg->size != sizeof(vmsg->payload.inflight)) {
        vu_panic(dev, "Invalid get_inflight_fd message:%d", vmsg->size);
        vmsg->payload.inflight.mmap_size = 0;
        return true;
    }

    num_queues = vmsg->payload.inflight.num_queues;
    queue_size = vmsg->payload.inflight.queue_size;

    /* Both are uint16_t: print them unsigned. */
    DPRINT("get_inflight_fd num_queues: %"PRIu16"\n", num_queues);
    DPRINT("get_inflight_fd queue_size: %"PRIu16"\n", queue_size);

    mmap_size = vu_inflight_queue_size(queue_size) * num_queues;

#ifdef MFD_ALLOW_SEALING
    addr = memfd_alloc("vhost-inflight", mmap_size,
                       F_SEAL_GROW | F_SEAL_SHRINK | F_SEAL_SEAL,
                       &fd);
#else
    vu_panic(dev, "Not implemented: memfd support is missing");
#endif

    if (!addr) {
        vu_panic(dev, "Failed to alloc vhost inflight area");
        vmsg->payload.inflight.mmap_size = 0;
        return true;
    }

    memset(addr, 0, mmap_size);

    dev->inflight_info.addr = addr;
    dev->inflight_info.size = vmsg->payload.inflight.mmap_size = mmap_size;
    dev->inflight_info.fd = vmsg->fds[0] = fd;
    vmsg->fd_num = 1;
    vmsg->payload.inflight.mmap_offset = 0;

    DPRINT("send inflight mmap_size: %"PRIu64"\n",
           vmsg->payload.inflight.mmap_size);
    DPRINT("send inflight mmap offset: %"PRIu64"\n",
           vmsg->payload.inflight.mmap_offset);

    return true;
}
/*
 * VHOST_USER_SET_INFLIGHT_FD: map the frontend-provided inflight area and
 * point the first num_queues virtqueues at their per-queue slices.
 *
 * All sizes come from the frontend.  They are validated before use:
 * num_queues must not exceed dev->max_queues (dev->vq[] is indexed with
 * it), and mmap_size must cover every per-queue slice that is written
 * below.  The received fd is closed on every error path instead of being
 * leaked.  Any previously installed inflight fd/mapping is released.
 * No reply.
 *
 * NOTE(review): queues with index >= num_queues keep their old inflight
 * pointer, which refers to the mapping released here.
 */
static bool
vu_set_inflight_fd(VuDev *dev, VhostUserMsg *vmsg)
{
    int fd, i;
    uint64_t mmap_size, mmap_offset;
    uint16_t num_queues, queue_size;
    void *rc;

    if (vmsg->fd_num != 1 ||
        vmsg->size != sizeof(vmsg->payload.inflight)) {
        vu_panic(dev, "Invalid set_inflight_fd message size:%d fds:%d",
                 vmsg->size, vmsg->fd_num);
        vmsg_close_fds(vmsg);
        return false;
    }

    fd = vmsg->fds[0];
    mmap_size = vmsg->payload.inflight.mmap_size;
    mmap_offset = vmsg->payload.inflight.mmap_offset;
    num_queues = vmsg->payload.inflight.num_queues;
    queue_size = vmsg->payload.inflight.queue_size;

    DPRINT("set_inflight_fd mmap_size: %"PRIu64"\n", mmap_size);
    DPRINT("set_inflight_fd mmap_offset: %"PRIu64"\n", mmap_offset);
    DPRINT("set_inflight_fd num_queues: %"PRIu16"\n", num_queues);
    DPRINT("set_inflight_fd queue_size: %"PRIu16"\n", queue_size);

    if (num_queues > dev->max_queues) {
        vu_panic(dev, "Invalid set_inflight_fd num_queues: %u (max %u)",
                 (unsigned int)num_queues, (unsigned int)dev->max_queues);
        close(fd);
        return false;
    }

    if (mmap_size < vu_inflight_queue_size(queue_size) * num_queues) {
        vu_panic(dev, "Invalid set_inflight_fd mmap_size: %"PRIu64,
                 mmap_size);
        close(fd);
        return false;
    }

    rc = mmap(0, mmap_size, PROT_READ | PROT_WRITE, MAP_SHARED,
              fd, mmap_offset);

    if (rc == MAP_FAILED) {
        vu_panic(dev, "set_inflight_fd mmap error: %s", strerror(errno));
        close(fd);
        return false;
    }

    /* Same "valid inflight fd" test as vu_deinit(). */
    if (dev->inflight_info.fd > 0) {
        close(dev->inflight_info.fd);
    }

    if (dev->inflight_info.addr) {
        munmap(dev->inflight_info.addr, dev->inflight_info.size);
    }

    dev->inflight_info.fd = fd;
    dev->inflight_info.addr = rc;
    dev->inflight_info.size = mmap_size;

    for (i = 0; i < num_queues; i++) {
        dev->vq[i].inflight = (VuVirtqInflight *)rc;
        dev->vq[i].inflight->desc_num = queue_size;
        rc = (void *)((char *)rc + vu_inflight_queue_size(queue_size));
    }

    return false;
}
/*
 * VHOST_USER_VRING_KICK: a kick delivered as a message instead of through
 * the kick eventfd.
 *
 * A queue that is not yet started is marked started (telling
 * iface->queue_set_started when present).  Then the queue's handler, if
 * any, is run.  An out-of-range index panics the device.  No reply.
 */
static bool
vu_handle_vring_kick(VuDev *dev, VhostUserMsg *vmsg)
{
    unsigned int index = vmsg->payload.state.index;
    VuVirtq *vq;

    if (index >= dev->max_queues) {
        vu_panic(dev, "Invalid queue index: %u", index);
        return false;
    }
    vq = &dev->vq[index];

    DPRINT("Got kick message: handler:%p idx:%u\n",
           vq->handler, index);

    if (!vq->started) {
        vq->started = true;
        if (dev->iface->queue_set_started) {
            dev->iface->queue_set_started(dev, index, true);
        }
    }

    if (vq->handler) {
        vq->handler(dev, index);
    }

    return false;
}
/*
 * VHOST_USER_GET_MAX_MEM_SLOTS: reply with VHOST_USER_MAX_RAM_SLOTS.
 */
static bool vu_handle_get_max_memslots(VuDev *dev, VhostUserMsg *vmsg)
{
    const uint64_t max_slots = VHOST_USER_MAX_RAM_SLOTS;

    vmsg_set_reply_u64(vmsg, max_slots);

    DPRINT("u64: 0x%016"PRIx64"\n", max_slots);

    return true;
}
/*
 * Route one decoded frontend message to its handler.
 *
 * If the device provides iface->process_msg and it returns true, the
 * message is considered handled and the do_reply value it set is returned
 * unchanged.  Otherwise the built-in handler for vmsg->request runs.
 *
 * Returns true when the caller (vu_dispatch()) should send vmsg back as the
 * reply, false otherwise.  VHOST_USER_NONE terminates the process with
 * exit(0).  Unknown requests have their fds closed and trigger vu_panic().
 */
static bool
vu_process_message(VuDev *dev, VhostUserMsg *vmsg)
{
    int do_reply = 0;

    /* Print out generic part of the request. */
    DPRINT("================ Vhost user message ================\n");
    DPRINT("Request: %s (%d)\n", vu_request_to_string(vmsg->request),
           vmsg->request);
    DPRINT("Flags: 0x%x\n", vmsg->flags);
    DPRINT("Size: %u\n", vmsg->size);

    if (vmsg->fd_num) {
        int i;
        DPRINT("Fds:");
        for (i = 0; i < vmsg->fd_num; i++) {
            DPRINT(" %d", vmsg->fds[i]);
        }
        DPRINT("\n");
    }

    /* Device-specific override gets the first chance at every message. */
    if (dev->iface->process_msg &&
        dev->iface->process_msg(dev, vmsg, &do_reply)) {
        return do_reply;
    }

    switch (vmsg->request) {
    case VHOST_USER_GET_FEATURES:
        return vu_get_features_exec(dev, vmsg);
    case VHOST_USER_SET_FEATURES:
        return vu_set_features_exec(dev, vmsg);
    case VHOST_USER_GET_PROTOCOL_FEATURES:
        return vu_get_protocol_features_exec(dev, vmsg);
    case VHOST_USER_SET_PROTOCOL_FEATURES:
        return vu_set_protocol_features_exec(dev, vmsg);
    case VHOST_USER_SET_OWNER:
        return vu_set_owner_exec(dev, vmsg);
    case VHOST_USER_RESET_OWNER:
        return vu_reset_device_exec(dev, vmsg);
    case VHOST_USER_SET_MEM_TABLE:
        return vu_set_mem_table_exec(dev, vmsg);
    case VHOST_USER_SET_LOG_BASE:
        return vu_set_log_base_exec(dev, vmsg);
    case VHOST_USER_SET_LOG_FD:
        return vu_set_log_fd_exec(dev, vmsg);
    case VHOST_USER_SET_VRING_NUM:
        return vu_set_vring_num_exec(dev, vmsg);
    case VHOST_USER_SET_VRING_ADDR:
        return vu_set_vring_addr_exec(dev, vmsg);
    case VHOST_USER_SET_VRING_BASE:
        return vu_set_vring_base_exec(dev, vmsg);
    case VHOST_USER_GET_VRING_BASE:
        return vu_get_vring_base_exec(dev, vmsg);
    case VHOST_USER_SET_VRING_KICK:
        return vu_set_vring_kick_exec(dev, vmsg);
    case VHOST_USER_SET_VRING_CALL:
        return vu_set_vring_call_exec(dev, vmsg);
    case VHOST_USER_SET_VRING_ERR:
        return vu_set_vring_err_exec(dev, vmsg);
    case VHOST_USER_GET_QUEUE_NUM:
        return vu_get_queue_num_exec(dev, vmsg);
    case VHOST_USER_SET_VRING_ENABLE:
        return vu_set_vring_enable_exec(dev, vmsg);
    case VHOST_USER_SET_BACKEND_REQ_FD:
        return vu_set_backend_req_fd(dev, vmsg);
    case VHOST_USER_GET_CONFIG:
        return vu_get_config(dev, vmsg);
    case VHOST_USER_SET_CONFIG:
        return vu_set_config(dev, vmsg);
    case VHOST_USER_NONE:
        /* if you need processing before exit, override iface->process_msg */
        exit(0);
    case VHOST_USER_POSTCOPY_ADVISE:
        return vu_set_postcopy_advise(dev, vmsg);
    case VHOST_USER_POSTCOPY_LISTEN:
        return vu_set_postcopy_listen(dev, vmsg);
    case VHOST_USER_POSTCOPY_END:
        return vu_set_postcopy_end(dev, vmsg);
    case VHOST_USER_GET_INFLIGHT_FD:
        return vu_get_inflight_fd(dev, vmsg);
    case VHOST_USER_SET_INFLIGHT_FD:
        return vu_set_inflight_fd(dev, vmsg);
    case VHOST_USER_VRING_KICK:
        return vu_handle_vring_kick(dev, vmsg);
    case VHOST_USER_GET_MAX_MEM_SLOTS:
        return vu_handle_get_max_memslots(dev, vmsg);
    case VHOST_USER_ADD_MEM_REG:
        return vu_add_mem_reg(dev, vmsg);
    case VHOST_USER_REM_MEM_REG:
        return vu_rem_mem_reg(dev, vmsg);
    case VHOST_USER_GET_SHARED_OBJECT:
        return vu_get_shared_object(dev, vmsg);
    default:
        /* Unknown request: don't leak any fds that came with it. */
        vmsg_close_fds(vmsg);
        vu_panic(dev, "Unhandled request: %d", vmsg->request);
    }

    return false;
}
/*
 * Read one message from the frontend socket, process it, and send a reply
 * when one is due.
 *
 * A reply goes out when the handler asks for one, or when the frontend set
 * VHOST_USER_NEED_REPLY_MASK.  In the latter case, if the handler produced
 * no reply of its own, a u64 0 is sent.  Returns false if reading the
 * message or sending the reply failed, true otherwise.
 */
bool
vu_dispatch(VuDev *dev)
{
    VhostUserMsg vmsg = { 0, };
    bool ok = false;

    if (dev->read_msg(dev, dev->sock, &vmsg)) {
        /* Sampled before processing, since handlers may rewrite vmsg. */
        bool need_reply = vmsg.flags & VHOST_USER_NEED_REPLY_MASK;
        bool reply = vu_process_message(dev, &vmsg);

        if (!reply && need_reply) {
            vmsg_set_reply_u64(&vmsg, 0);
            reply = true;
        }

        ok = reply ? vu_send_reply(dev, dev->sock, &vmsg) : true;
    }

    free(vmsg.data);
    return ok;
}
/*
 * Release every resource held by @dev.
 *
 * The following are released, in this order:
 * - memory regions;
 * - each queue's call/kick/err fds (kick fds are unwatched before close)
 *   and resubmit list;
 * - the inflight mapping and fd;
 * - the dirty log;
 * - the backend channel fd and its mutex;
 * - the frontend socket;
 * - finally the vq and regions arrays.
 *
 * NOTE(review): dev->sock is closed but not reset to -1.
 */
void
vu_deinit(VuDev *dev)
{
    unsigned int i;

    vu_remove_all_mem_regs(dev);

    for (i = 0; i < dev->max_queues; i++) {
        VuVirtq *vq = &dev->vq[i];

        if (vq->call_fd != -1) {
            close(vq->call_fd);
            vq->call_fd = -1;
        }

        if (vq->kick_fd != -1) {
            /* Stop watching before the descriptor number can be reused. */
            dev->remove_watch(dev, vq->kick_fd);
            close(vq->kick_fd);
            vq->kick_fd = -1;
        }

        if (vq->err_fd != -1) {
            close(vq->err_fd);
            vq->err_fd = -1;
        }

        if (vq->resubmit_list) {
            free(vq->resubmit_list);
            vq->resubmit_list = NULL;
        }

        /* Points into inflight_info.addr, which is unmapped below. */
        vq->inflight = NULL;
    }

    if (dev->inflight_info.addr) {
        munmap(dev->inflight_info.addr, dev->inflight_info.size);
        dev->inflight_info.addr = NULL;
    }

    if (dev->inflight_info.fd > 0) {
        close(dev->inflight_info.fd);
        dev->inflight_info.fd = -1;
    }

    vu_close_log(dev);
    if (dev->backend_fd != -1) {
        close(dev->backend_fd);
        dev->backend_fd = -1;
    }
    pthread_mutex_destroy(&dev->backend_mutex);

    if (dev->sock != -1) {
        close(dev->sock);
    }

    free(dev->vq);
    dev->vq = NULL;
    free(dev->regions);
    dev->regions = NULL;
}
/*
 * Initialise @dev for a frontend connected on @socket with up to
 * @max_queues virtqueues.
 *
 * @read_msg may be NULL, in which case vu_message_read_default is used.
 * All other callbacks and @iface are required (asserted).  Every queue
 * starts with no fds (-1) and notifications enabled.
 *
 * Returns false if the region or queue arrays cannot be allocated; no
 * arrays remain allocated in that case.
 */
bool
vu_init(VuDev *dev,
        uint16_t max_queues,
        int socket,
        vu_panic_cb panic,
        vu_read_msg_cb read_msg,
        vu_set_watch_cb set_watch,
        vu_remove_watch_cb remove_watch,
        const VuDevIface *iface)
{
    uint16_t q;

    assert(max_queues > 0);
    assert(socket >= 0);
    assert(set_watch);
    assert(remove_watch);
    assert(iface);
    assert(panic);

    memset(dev, 0, sizeof(*dev));

    dev->sock = socket;
    dev->panic = panic;
    dev->read_msg = read_msg ? read_msg : vu_message_read_default;
    dev->set_watch = set_watch;
    dev->remove_watch = remove_watch;
    dev->iface = iface;
    dev->log_call_fd = -1;
    pthread_mutex_init(&dev->backend_mutex, NULL);
    dev->backend_fd = -1;
    dev->max_queues = max_queues;

    dev->regions = malloc(VHOST_USER_MAX_RAM_SLOTS * sizeof(dev->regions[0]));
    if (dev->regions == NULL) {
        DPRINT("%s: failed to malloc mem regions\n", __func__);
        return false;
    }

    dev->vq = malloc(max_queues * sizeof(dev->vq[0]));
    if (dev->vq == NULL) {
        DPRINT("%s: failed to malloc virtqueues\n", __func__);
        free(dev->regions);
        dev->regions = NULL;
        return false;
    }

    for (q = 0; q < max_queues; q++) {
        VuVirtq *vq = &dev->vq[q];

        *vq = (VuVirtq) {
            .call_fd = -1, .kick_fd = -1, .err_fd = -1,
            .notification = true,
        };
    }

    return true;
}
/* Return the virtqueue at index @qidx (asserted to be < dev->max_queues). */
VuVirtq *
vu_get_queue(VuDev *dev, int qidx)
{
    VuVirtq *vq;

    assert(qidx < dev->max_queues);
    vq = dev->vq + qidx;

    return vq;
}
/* Whether the frontend enabled @vq (VHOST_USER_SET_VRING_ENABLE). */
bool
vu_queue_enabled(VuDev *dev, VuVirtq *vq)
{
    bool enabled = vq->enable;

    return enabled;
}
/* Whether @vq is currently marked started. */
bool
vu_queue_started(const VuDev *dev, const VuVirtq *vq)
{
    bool started = vq->started;

    return started;
}
/* Read the avail ring's flags field, converted from little endian. */
static inline uint16_t
vring_avail_flags(VuVirtq *vq)
{
    uint16_t raw = vq->vring.avail->flags;

    return le16toh(raw);
}
/*
 * Read the driver's current avail index, cache it in shadow_avail_idx, and
 * return it.
 */
static inline uint16_t
vring_avail_idx(VuVirtq *vq)
{
    uint16_t idx = le16toh(vq->vring.avail->idx);

    vq->shadow_avail_idx = idx;
    return idx;
}
/* Read entry @i of the avail ring, converted from little endian. */
static inline uint16_t
vring_avail_ring(VuVirtq *vq, int i)
{
    uint16_t raw = vq->vring.avail->ring[i];

    return le16toh(raw);
}
/*
 * Read the driver's used_event value, which is stored just past the last
 * avail ring entry (avail->ring[vring.num]).
 */
static inline uint16_t
vring_get_used_event(VuVirtq *vq)
{
    unsigned int slot = vq->vring.num;

    return vring_avail_ring(vq, slot);
}
/*
 * Number of new avail entries between @idx and the driver's avail index.
 *
 * Panics and returns -1 if that distance exceeds the ring size.  When
 * entries are available, a read barrier orders the avail index read before
 * the callers' descriptor reads.
 */
static int
virtqueue_num_heads(VuDev *dev, VuVirtq *vq, unsigned int idx)
{
    uint16_t avail = vring_avail_idx(vq);
    uint16_t pending = avail - idx;

    /* Check it isn't doing very strange things with descriptor numbers. */
    if (pending > vq->vring.num) {
        vu_panic(dev, "Guest moved used index from %u to %u",
                 idx, vq->shadow_avail_idx);
        return -1;
    }

    if (pending != 0) {
        /* On success, callers read a descriptor at vq->last_avail_idx.
         * Make sure descriptor read does not bypass avail index read. */
        smp_rmb();
    }

    return pending;
}
/*
 * Fetch the head descriptor index advertised at avail position @idx
 * (taken modulo the ring size) into *head.  Returns false and panics if
 * the driver gave an out-of-range descriptor index.
 */
static bool
virtqueue_get_head(VuDev *dev, VuVirtq *vq,
                   unsigned int idx, unsigned int *head)
{
    unsigned int slot = idx % vq->vring.num;

    *head = vring_avail_ring(vq, slot);

    if (*head < vq->vring.num) {
        return true;
    }

    /* If their number is silly, that's a fatal mistake. */
    vu_panic(dev, "Guest says index %u is available", *head);
    return false;
}
/*
 * Copy an indirect descriptor table of @len bytes at guest address @addr
 * into @desc.  This is used when the table is not contiguous in our
 * mapping.  The copy proceeds chunk by chunk, one mapped region at a time.
 *
 * Returns 0 on success, or -1 if @len is 0, exceeds
 * VIRTQUEUE_MAX_SIZE descriptors, or touches unmapped guest memory.
 */
static int
virtqueue_read_indirect_desc(VuDev *dev, struct vring_desc *desc,
                             uint64_t addr, size_t len)
{
    struct vring_desc *ori_desc;
    uint64_t read_len;

    if (len > (VIRTQUEUE_MAX_SIZE * sizeof(struct vring_desc))) {
        return -1;
    }

    if (len == 0) {
        return -1;
    }

    while (len) {
        read_len = len;
        ori_desc = vu_gpa_to_va(dev, &read_len, addr);
        if (!ori_desc) {
            return -1;
        }

        memcpy(desc, ori_desc, read_len);
        len -= read_len;
        addr += read_len;
        /*
         * read_len is a byte count: advance the destination by bytes.  The
         * old `desc += read_len` stepped read_len *descriptors* and wrote
         * past the end of the caller's buffer.
         */
        desc = (struct vring_desc *)((char *)desc + read_len);
    }

    return 0;
}
/* Results of virtqueue_read_next_desc(). */
enum {
    VIRTQUEUE_READ_DESC_ERROR = -1, /* bad chain; vu_panic() already called */
    VIRTQUEUE_READ_DESC_DONE  =  0, /* end of chain */
    VIRTQUEUE_READ_DESC_MORE  =  1, /* more buffers in chain */
};
/*
 * Follow the chain link of descriptor @i in @desc.
 *
 * Returns VIRTQUEUE_READ_DESC_DONE when @i has no NEXT flag.  Returns
 * VIRTQUEUE_READ_DESC_MORE with *next set to the following index.
 * Returns VIRTQUEUE_READ_DESC_ERROR (after vu_panic()) when that index is
 * not below @max.
 */
static int
virtqueue_read_next_desc(VuDev *dev, struct vring_desc *desc,
                         int i, unsigned int max, unsigned int *next)
{
    const struct vring_desc *cur = &desc[i];

    if ((le16toh(cur->flags) & VRING_DESC_F_NEXT) == 0) {
        return VIRTQUEUE_READ_DESC_DONE;
    }

    /* Check they're not leading us off end of descriptors. */
    *next = le16toh(cur->next);
    /* Make sure compiler knows to grab that: we don't want it changing! */
    smp_wmb();

    if (*next < max) {
        return VIRTQUEUE_READ_DESC_MORE;
    }

    vu_panic(dev, "Desc next is %u", *next);
    return VIRTQUEUE_READ_DESC_ERROR;
}
/*
 * Sum the buffer space currently offered by the driver on @vq without
 * consuming anything (only a local copy of last_avail_idx is advanced).
 *
 * Each available chain is walked, including chains that go through an
 * indirect table.  Lengths of device-writable descriptors are added to
 * *in_bytes and read-only ones to *out_bytes.  The walk stops early once
 * both totals reach max_in_bytes / max_out_bytes.  On any malformed
 * chain, or if the queue is unusable, the totals reported are 0.  Either
 * output pointer may be NULL.
 */
void
vu_queue_get_avail_bytes(VuDev *dev, VuVirtq *vq, unsigned int *in_bytes,
                         unsigned int *out_bytes,
                         unsigned max_in_bytes, unsigned max_out_bytes)
{
    unsigned int idx;
    unsigned int total_bufs, in_total, out_total;
    int rc;

    idx = vq->last_avail_idx;

    total_bufs = in_total = out_total = 0;
    if (!vu_is_vq_usable(dev, vq)) {
        goto done;
    }

    while ((rc = virtqueue_num_heads(dev, vq, idx)) > 0) {
        unsigned int max, desc_len, num_bufs, indirect = 0;
        uint64_t desc_addr, read_len;
        struct vring_desc *desc;
        struct vring_desc desc_buf[VIRTQUEUE_MAX_SIZE];
        unsigned int i;

        max = vq->vring.num;
        num_bufs = total_bufs;
        if (!virtqueue_get_head(dev, vq, idx++, &i)) {
            goto err;
        }
        desc = vq->vring.desc;

        if (le16toh(desc[i].flags) & VRING_DESC_F_INDIRECT) {
            if (le32toh(desc[i].len) % sizeof(struct vring_desc)) {
                vu_panic(dev, "Invalid size for indirect buffer table");
                goto err;
            }

            /* If we've got too many, that implies a descriptor loop. */
            if (num_bufs >= max) {
                vu_panic(dev, "Looped descriptor");
                goto err;
            }

            /* loop over the indirect descriptor table */
            indirect = 1;
            desc_addr = le64toh(desc[i].addr);
            desc_len = le32toh(desc[i].len);
            max = desc_len / sizeof(struct vring_desc);
            read_len = desc_len;
            desc = vu_gpa_to_va(dev, &read_len, desc_addr);
            if (unlikely(desc && read_len != desc_len)) {
                /* Failed to use zero copy */
                desc = NULL;
                /* Table spans mapped regions: copy it into desc_buf. */
                if (!virtqueue_read_indirect_desc(dev, desc_buf,
                                                  desc_addr,
                                                  desc_len)) {
                    desc = desc_buf;
                }
            }
            if (!desc) {
                vu_panic(dev, "Invalid indirect buffer table");
                goto err;
            }
            num_bufs = i = 0;
        }

        do {
            /* If we've got too many, that implies a descriptor loop. */
            if (++num_bufs > max) {
                vu_panic(dev, "Looped descriptor");
                goto err;
            }

            if (le16toh(desc[i].flags) & VRING_DESC_F_WRITE) {
                in_total += le32toh(desc[i].len);
            } else {
                out_total += le32toh(desc[i].len);
            }
            /* Caller only needs to know that at least this much exists. */
            if (in_total >= max_in_bytes && out_total >= max_out_bytes) {
                goto done;
            }
            rc = virtqueue_read_next_desc(dev, desc, i, max, &i);
        } while (rc == VIRTQUEUE_READ_DESC_MORE);

        if (rc == VIRTQUEUE_READ_DESC_ERROR) {
            goto err;
        }

        /* An indirect chain occupies a single slot in the main ring. */
        if (!indirect) {
            total_bufs = num_bufs;
        } else {
            total_bufs++;
        }
    }
    if (rc < 0) {
        goto err;
    }
done:
    if (in_bytes) {
        *in_bytes = in_total;
    }
    if (out_bytes) {
        *out_bytes = out_total;
    }
    return;

err:
    in_total = out_total = 0;
    goto done;
}
/*
 * Whether the driver currently offers at least @in_bytes of writable and
 * @out_bytes of readable buffer space on @vq.
 */
bool
vu_queue_avail_bytes(VuDev *dev, VuVirtq *vq, unsigned int in_bytes,
                     unsigned int out_bytes)
{
    unsigned int avail_in, avail_out;

    vu_queue_get_avail_bytes(dev, vq, &avail_in, &avail_out,
                             in_bytes, out_bytes);

    if (in_bytes > avail_in) {
        return false;
    }
    return out_bytes <= avail_out;
}
/*
 * Report whether @vq has no unconsumed avail entries (an unusable queue
 * counts as empty).  The avail index is only re-read from VQ memory when
 * the cached shadow copy suggests the guest has not added buffers.
 */
bool
vu_queue_empty(VuDev *dev, VuVirtq *vq)
{
    if (!vu_is_vq_usable(dev, vq)) {
        return true;
    }

    if (vq->shadow_avail_idx == vq->last_avail_idx) {
        return vring_avail_idx(vq) == vq->last_avail_idx;
    }

    return false;
}
/*
 * Decide whether the driver should be notified about new used entries.
 *
 * - With VIRTIO_F_NOTIFY_ON_EMPTY: always notify once nothing is in use
 *   and the avail ring is empty.
 * - Without VIRTIO_RING_F_EVENT_IDX: honour the driver's
 *   VRING_AVAIL_F_NO_INTERRUPT flag.
 * - With EVENT_IDX: notify if signalled_used_valid was false, otherwise
 *   ask vring_need_event() whether the driver's used_event lies between
 *   the last signalled and the current used index.  Records used_idx as
 *   signalled_used.
 */
static bool
vring_notify(VuDev *dev, VuVirtq *vq)
{
    uint16_t old, new;
    bool v;

    /* We need to expose used array entries before checking used event. */
    smp_mb();

    /* Always notify when queue is empty (when feature acknowledge) */
    if (vu_has_feature(dev, VIRTIO_F_NOTIFY_ON_EMPTY) &&
        !vq->inuse && vu_queue_empty(dev, vq)) {
        return true;
    }

    if (!vu_has_feature(dev, VIRTIO_RING_F_EVENT_IDX)) {
        return !(vring_avail_flags(vq) & VRING_AVAIL_F_NO_INTERRUPT);
    }

    v = vq->signalled_used_valid;
    vq->signalled_used_valid = true;
    old = vq->signalled_used;
    new = vq->signalled_used = vq->used_idx;
    return !v || vring_need_event(vring_get_used_event(vq), new, old);
}
/*
 * Notify the driver that @vq has new used entries, if vring_notify() says
 * it needs one.
 *
 * Normally this signals the queue's call eventfd.  When no call fd is
 * installed and in-band notifications are negotiated (with BACKEND_REQ), a
 * VHOST_USER_BACKEND_VRING_CALL message is sent on the backend channel
 * instead.  With @sync and REPLY_ACK, it waits for the frontend's ack.
 *
 * The backend channel is shared with the other backend requests in this
 * file, which all hold backend_mutex while using it, so this path does
 * too.  The ack is only awaited if the request was actually written;
 * otherwise the read could block on a reply that never comes.
 */
static void _vu_queue_notify(VuDev *dev, VuVirtq *vq, bool sync)
{
    if (!vu_is_vq_usable(dev, vq)) {
        return;
    }

    if (!vring_notify(dev, vq)) {
        DPRINT("skipped notify...\n");
        return;
    }

    if (vq->call_fd < 0 &&
        vu_has_protocol_feature(dev,
                                VHOST_USER_PROTOCOL_F_INBAND_NOTIFICATIONS) &&
        vu_has_protocol_feature(dev, VHOST_USER_PROTOCOL_F_BACKEND_REQ)) {
        VhostUserMsg vmsg = {
            .request = VHOST_USER_BACKEND_VRING_CALL,
            .flags = VHOST_USER_VERSION,
            .size = sizeof(vmsg.payload.state),
            .payload.state = {
                .index = vq - dev->vq,
            },
        };
        bool ack = sync &&
                   vu_has_protocol_feature(dev,
                                           VHOST_USER_PROTOCOL_F_REPLY_ACK);

        if (ack) {
            vmsg.flags |= VHOST_USER_NEED_REPLY_MASK;
        }

        pthread_mutex_lock(&dev->backend_mutex);
        if (vu_message_write(dev, dev->backend_fd, &vmsg) && ack) {
            vu_message_read_default(dev, dev->backend_fd, &vmsg);
        }
        pthread_mutex_unlock(&dev->backend_mutex);

        return;
    }

    if (eventfd_write(vq->call_fd, 1) < 0) {
        vu_panic(dev, "Error writing eventfd: %s", strerror(errno));
    }
}
/* Notify the driver about @vq without waiting for any REPLY_ACK. */
void vu_queue_notify(VuDev *dev, VuVirtq *vq)
{
    const bool wait_for_ack = false;

    _vu_queue_notify(dev, vq, wait_for_ack);
}
/*
 * Notify the driver about @vq.  When the notification goes out as a
 * backend message and REPLY_ACK is negotiated, wait for the frontend's ack.
 */
void vu_queue_notify_sync(VuDev *dev, VuVirtq *vq)
{
    const bool wait_for_ack = true;

    _vu_queue_notify(dev, vq, wait_for_ack);
}
/*
 * Tell the frontend that the device's config space changed
 * (VHOST_USER_BACKEND_CONFIG_CHANGE_MSG, no reply requested).
 *
 * The message goes through vu_send_message(), so the backend-channel write
 * is serialized under backend_mutex like the other backend requests.  The
 * previous version wrote to backend_fd without the lock.
 */
void vu_config_change_msg(VuDev *dev)
{
    VhostUserMsg vmsg = {
        .request = VHOST_USER_BACKEND_CONFIG_CHANGE_MSG,
        .flags = VHOST_USER_VERSION,
    };

    vu_send_message(dev, &vmsg);
}
/* Set the bits in @mask in the used ring's (little-endian) flags field. */
static inline void
vring_used_flags_set_bit(VuVirtq *vq, int mask)
{
    uint16_t *flags_p = (uint16_t *)((char *)vq->vring.used +
                                     offsetof(struct vring_used, flags));
    uint16_t cur = le16toh(*flags_p);

    cur |= mask;
    *flags_p = htole16(cur);
}
/* Clear the bits in @mask in the used ring's (little-endian) flags field. */
static inline void
vring_used_flags_unset_bit(VuVirtq *vq, int mask)
{
    uint16_t *flags_p = (uint16_t *)((char *)vq->vring.used +
                                     offsetof(struct vring_used, flags));
    uint16_t cur = le16toh(*flags_p);

    cur &= ~mask;
    *flags_p = htole16(cur);
}
/*
 * Publish @val (little endian) as avail_event, stored just past the last
 * used ring entry.  This does nothing while notifications are disabled
 * for @vq.
 */
static inline void
vring_set_avail_event(VuVirtq *vq, uint16_t val)
{
    uint16_t le_val;

    if (!vq->notification) {
        return;
    }

    le_val = htole16(val);
    memcpy(&vq->vring.used->ring[vq->vring.num], &le_val, sizeof(le_val));
}
/*
 * Enable or disable driver-to-device notifications (kicks) for @vq.
 *
 * With EVENT_IDX this republishes avail_event at the current avail index.
 * Otherwise it toggles VRING_USED_F_NO_NOTIFY.  When enabling, a full
 * barrier makes the change visible before the caller re-checks the avail
 * index.
 */
void
vu_queue_set_notification(VuDev *dev, VuVirtq *vq, int enable)
{
    vq->notification = enable;

    if (!vu_has_feature(dev, VIRTIO_RING_F_EVENT_IDX)) {
        if (enable) {
            vring_used_flags_unset_bit(vq, VRING_USED_F_NO_NOTIFY);
        } else {
            vring_used_flags_set_bit(vq, VRING_USED_F_NO_NOTIFY);
        }
    } else {
        vring_set_avail_event(vq, vring_avail_idx(vq));
    }

    if (!enable) {
        return;
    }

    /* Expose avail event/used flags before caller checks the avail idx. */
    smp_mb();
}
/*
 * Translate the guest buffer [pa, pa + sz) into host iovecs, appending them
 * at iov[*p_num_sg] and updating *p_num_sg.
 *
 * A buffer that spans several mapped regions produces several iovecs.
 * Returns false (after vu_panic()) for a zero-sized buffer, when more than
 * @max_num_sg iovecs would be needed, or for an unmapped address.
 * @is_write is not used here.
 */
static bool
virtqueue_map_desc(VuDev *dev,
                   unsigned int *p_num_sg, struct iovec *iov,
                   unsigned int max_num_sg, bool is_write,
                   uint64_t pa, size_t sz)
{
    unsigned int n = *p_num_sg;

    assert(n <= max_num_sg);

    if (sz == 0) {
        vu_panic(dev, "virtio: zero sized buffers are not allowed");
        return false;
    }

    for (; sz != 0; n++) {
        uint64_t chunk = sz;

        if (n == max_num_sg) {
            vu_panic(dev, "virtio: too many descriptors in indirect table");
            return false;
        }

        iov[n].iov_base = vu_gpa_to_va(dev, &chunk, pa);
        if (!iov[n].iov_base) {
            vu_panic(dev, "virtio: invalid address for buffers");
            return false;
        }
        iov[n].iov_len = chunk;

        sz -= chunk;
        pa += chunk;
    }

    *p_num_sg = n;
    return true;
}
/*
 * Allocate one block holding a caller-defined element of @sz bytes (which
 * starts with a VuVirtqElement) followed by its in_sg[in_num] and
 * out_sg[out_num] arrays.  Counts and sg pointers are filled in.
 *
 * Returns NULL on allocation failure.  The caller frees the whole block
 * with free().
 */
static void *
virtqueue_alloc_element(size_t sz,
                                     unsigned out_num, unsigned in_num)
{
    VuVirtqElement *elem;
    size_t in_sg_ofs = ALIGN_UP(sz, __alignof__(elem->in_sg[0]));
    size_t out_sg_ofs = in_sg_ofs + in_num * sizeof(elem->in_sg[0]);
    size_t out_sg_end = out_sg_ofs + out_num * sizeof(elem->out_sg[0]);

    assert(sz >= sizeof(VuVirtqElement));
    elem = malloc(out_sg_end);
    if (!elem) {
        DPRINT("%s: failed to malloc virtqueue element\n", __func__);
        return NULL;
    }
    elem->out_num = out_num;
    elem->in_num = in_num;
    /*
     * Arithmetic on void * is a GNU extension, not ISO C: compute the
     * offsets through char * instead.
     */
    elem->in_sg = (void *)((char *)elem + in_sg_ofs);
    elem->out_sg = (void *)((char *)elem + out_sg_ofs);
    return elem;
}
/*
 * Map the descriptor chain whose head is @idx into a newly allocated
 * element of @sz bytes (see virtqueue_alloc_element()).
 *
 * Direct chains and chains behind an indirect table are both handled; an
 * indirect table that is not contiguous in our mapping is copied into a
 * stack buffer.  Device-readable ("out") descriptors must all precede
 * device-writable ("in") ones.
 *
 * Returns the element with out_sg/in_sg filled, or NULL on error.  Every
 * error path except an element allocation failure calls vu_panic().
 */
static void *
vu_queue_map_desc(VuDev *dev, VuVirtq *vq, unsigned int idx, size_t sz)
{
    struct vring_desc *desc = vq->vring.desc;
    uint64_t desc_addr, read_len;
    unsigned int desc_len;
    unsigned int max = vq->vring.num;
    unsigned int i = idx;
    VuVirtqElement *elem;
    unsigned int out_num = 0, in_num = 0;
    struct iovec iov[VIRTQUEUE_MAX_SIZE];
    struct vring_desc desc_buf[VIRTQUEUE_MAX_SIZE];
    int rc;

    if (le16toh(desc[i].flags) & VRING_DESC_F_INDIRECT) {
        if (le32toh(desc[i].len) % sizeof(struct vring_desc)) {
            vu_panic(dev, "Invalid size for indirect buffer table");
            return NULL;
        }

        /* loop over the indirect descriptor table */
        desc_addr = le64toh(desc[i].addr);
        desc_len = le32toh(desc[i].len);
        max = desc_len / sizeof(struct vring_desc);
        read_len = desc_len;
        desc = vu_gpa_to_va(dev, &read_len, desc_addr);
        if (unlikely(desc && read_len != desc_len)) {
            /* Failed to use zero copy */
            desc = NULL;
            if (!virtqueue_read_indirect_desc(dev, desc_buf,
                                              desc_addr,
                                              desc_len)) {
                desc = desc_buf;
            }
        }
        if (!desc) {
            vu_panic(dev, "Invalid indirect buffer table");
            return NULL;
        }
        i = 0;
    }

    /* Collect all the descriptors */
    do {
        if (le16toh(desc[i].flags) & VRING_DESC_F_WRITE) {
            /* "in" iovecs are stored right after the "out" ones in iov[]. */
            if (!virtqueue_map_desc(dev, &in_num, iov + out_num,
                               VIRTQUEUE_MAX_SIZE - out_num, true,
                               le64toh(desc[i].addr),
                               le32toh(desc[i].len))) {
                return NULL;
            }
        } else {
            if (in_num) {
                vu_panic(dev, "Incorrect order for descriptors");
                return NULL;
            }
            if (!virtqueue_map_desc(dev, &out_num, iov,
                               VIRTQUEUE_MAX_SIZE, false,
                               le64toh(desc[i].addr),
                               le32toh(desc[i].len))) {
                return NULL;
            }
        }

        /* If we've got too many, that implies a descriptor loop. */
        if ((in_num + out_num) > max) {
            vu_panic(dev, "Looped descriptor");
            return NULL;
        }
        rc = virtqueue_read_next_desc(dev, desc, i, max, &i);
    } while (rc == VIRTQUEUE_READ_DESC_MORE);

    if (rc == VIRTQUEUE_READ_DESC_ERROR) {
        vu_panic(dev, "read descriptor error");
        return NULL;
    }

    /* Now copy what we have collected and mapped */
    elem = virtqueue_alloc_element(sz, out_num, in_num);
    if (!elem) {
        return NULL;
    }
    elem->index = idx;
    for (i = 0; i < out_num; i++) {
        elem->out_sg[i] = iov[i];
    }
    for (i = 0; i < in_num; i++) {
        elem->in_sg[i] = iov[out_num + i];
    }

    return elem;
}
/*
 * Mark descriptor @desc_idx as in flight in vq->inflight and stamp it with
 * the next value of vq->counter.
 *
 * Returns 0 without doing anything when VHOST_USER_PROTOCOL_F_INFLIGHT_SHMFD
 * has not been negotiated.  Returns -1 when it has been negotiated but this
 * queue has no inflight region.
 */
static int
vu_queue_inflight_get(VuDev *dev, VuVirtq *vq, int desc_idx)
{
    if (vu_has_protocol_feature(dev, VHOST_USER_PROTOCOL_F_INFLIGHT_SHMFD)) {
        if (unlikely(vq->inflight == NULL)) {
            return -1;
        }

        vq->inflight->desc[desc_idx].counter = vq->counter;
        vq->counter++;
        vq->inflight->desc[desc_idx].inflight = 1;
    }

    return 0;
}
/*
 * Before a used-ring update, remember @desc_idx as the head of the batch
 * being completed (vq->inflight->last_batch_head).
 *
 * Returns 0 (no-op) when VHOST_USER_PROTOCOL_F_INFLIGHT_SHMFD is not
 * negotiated, -1 if it is but the queue has no inflight region, 0 otherwise.
 */
static int
vu_queue_inflight_pre_put(VuDev *dev, VuVirtq *vq, int desc_idx)
{
    if (vu_has_protocol_feature(dev, VHOST_USER_PROTOCOL_F_INFLIGHT_SHMFD)) {
        if (unlikely(vq->inflight == NULL)) {
            return -1;
        }

        vq->inflight->last_batch_head = desc_idx;
    }

    return 0;
}
/*
 * After a used-ring update, clear the in-flight flag of @desc_idx and copy
 * vq->used_idx into the inflight region.
 *
 * Returns 0 (no-op) when VHOST_USER_PROTOCOL_F_INFLIGHT_SHMFD is not
 * negotiated, -1 if it is but the queue has no inflight region, 0 otherwise.
 */
static int
vu_queue_inflight_post_put(VuDev *dev, VuVirtq *vq, int desc_idx)
{
    if (vu_has_protocol_feature(dev, VHOST_USER_PROTOCOL_F_INFLIGHT_SHMFD)) {
        if (unlikely(vq->inflight == NULL)) {
            return -1;
        }

        /*
         * The barrier() calls keep the flag clear and the used_idx update
         * ordered with respect to the surrounding stores.
         */
        barrier();
        vq->inflight->desc[desc_idx].inflight = 0;
        barrier();
        vq->inflight->used_idx = vq->used_idx;
    }

    return 0;
}
void *
vu_queue_pop(VuDev *dev, VuVirtq *vq, size_t sz)
{
int i;
unsigned int head;
VuVirtqElement *elem;
if (!vu_is_vq_usable(dev, vq)) {
return NULL;
}
if (unlikely(vq->resubmit_list && vq->resubmit_num > 0)) {
i = (--vq->resubmit_num);
elem = vu_queue_map_desc(dev, vq, vq->resubmit_list[i].index, sz);
if (!vq->resubmit_num) {
free(vq->resubmit_list);
vq->resubmit_list = NULL;
}
return elem;
}
if (vu_queue_empty(dev, vq)) {
return NULL;
}
/*
* Needed after virtio_queue_empty(), see comment in
* virtqueue_num_heads().
*/
smp_rmb();
if (vq->inuse >= vq->vring.num) {
vu_panic(dev, "Virtqueue size exceeded");
return NULL;
}
if (!virtqueue_get_head(dev, vq, vq->last_avail_idx++, &head)) {
return NULL;
}
if (vu_has_feature(dev, VIRTIO_RING_F_EVENT_IDX)) {
vring_set_avail_event(vq, vq->last_avail_idx);
}
elem = vu_queue_map_desc(dev, vq, head, sz);
if (!elem) {
return NULL;
}
vq->inuse++;
vu_queue_inflight_get(dev, vq, head);
return elem;
}
/*
 * Release one in-use slot for @elem without writing it to the used ring.
 * @dev, @elem and @len are currently unused.
 */
static void
vu_queue_detach_element(VuDev *dev, VuVirtq *vq, VuVirtqElement *elem,
                        size_t len)
{
    /* unmap, when DMA support is added */
    vq->inuse -= 1;
}
/*
 * Give back an element obtained from the avail ring.  last_avail_idx is
 * stepped back by one, so the next pop reads that avail slot again, and
 * the element's in-use slot is released.
 */
void
vu_queue_unpop(VuDev *dev, VuVirtq *vq, VuVirtqElement *elem,
               size_t len)
{
    vq->last_avail_idx -= 1;
    vu_queue_detach_element(dev, vq, elem, len);
}
/*
 * Push back the last @num popped elements so they are fetched again.
 * Returns false and changes nothing when @num exceeds the number of
 * elements currently in use; otherwise rewinds and returns true.
 */
bool
vu_queue_rewind(VuDev *dev, VuVirtq *vq, unsigned int num)
{
    bool ok = num <= vq->inuse;

    if (ok) {
        vq->last_avail_idx -= num;
        vq->inuse -= num;
    }

    return ok;
}
/*
 * Store @uelem into slot @i of the used ring and report the written bytes
 * (guest address log_guest_addr + offset of ring[i]) to vu_log_write().
 *
 * @i must be a valid, non-negative ring index.
 *
 * The log offset is computed as offsetof(ring) + i * element size rather
 * than offsetof(struct vring_used, ring[i]).  ISO C only defines offsetof
 * for member designators whose address is an address constant, so a
 * run-time index relies on a GCC/Clang extension.  The resulting value is
 * the same.
 */
static inline
void vring_used_write(VuDev *dev, VuVirtq *vq,
                      struct vring_used_elem *uelem, int i)
{
    struct vring_used *used = vq->vring.used;

    used->ring[i] = *uelem;
    vu_log_write(dev, vq->vring.log_guest_addr +
                 offsetof(struct vring_used, ring) +
                 (size_t)i * sizeof(used->ring[0]),
                 sizeof(used->ring[i]));
}
/*
 * Report to vu_log_write() the guest memory written by the device for
 * @elem, up to @len bytes.
 *
 * Walks the descriptor chain starting at elem->index, following an indirect
 * table the same way vu_queue_map_desc() does (mapped in place, or copied
 * into desc_buf when only partly mappable).  For each device-writable
 * descriptor it logs min(descriptor length, remaining len) bytes from the
 * descriptor's address.  Read-only descriptors are skipped but still count
 * towards the loop limit.
 *
 * The walk stops once @len bytes have been logged or
 * virtqueue_read_next_desc() returns anything other than
 * VIRTQUEUE_READ_DESC_MORE (errors there are not reported).  The first
 * descriptor is always visited, even when @len is 0.  A malformed indirect
 * table or a chain longer than @max triggers vu_panic().
 */
static void
vu_log_queue_fill(VuDev *dev, VuVirtq *vq,
                  const VuVirtqElement *elem,
                  unsigned int len)
{
    struct vring_desc *desc = vq->vring.desc;
    unsigned int i, max, min, desc_len;
    uint64_t desc_addr, read_len;
    /* Fallback copy of an indirect table that cannot be mapped in one go. */
    struct vring_desc desc_buf[VIRTQUEUE_MAX_SIZE];
    /* Descriptors visited so far; used to detect looped chains. */
    unsigned num_bufs = 0;

    max = vq->vring.num;
    i = elem->index;

    if (le16toh(desc[i].flags) & VRING_DESC_F_INDIRECT) {
        if (le32toh(desc[i].len) % sizeof(struct vring_desc)) {
            vu_panic(dev, "Invalid size for indirect buffer table");
            return;
        }

        /* loop over the indirect descriptor table */
        desc_addr = le64toh(desc[i].addr);
        desc_len = le32toh(desc[i].len);
        max = desc_len / sizeof(struct vring_desc);
        read_len = desc_len;
        desc = vu_gpa_to_va(dev, &read_len, desc_addr);
        if (unlikely(desc && read_len != desc_len)) {
            /*
             * Failed to use zero copy: only part of the table was mapped, so
             * copy it instead.  A 0 return means desc_buf holds the table.
             */
            desc = NULL;
            if (!virtqueue_read_indirect_desc(dev, desc_buf,
                                              desc_addr,
                                              desc_len)) {
                desc = desc_buf;
            }
        }
        if (!desc) {
            vu_panic(dev, "Invalid indirect buffer table");
            return;
        }
        /* Indirect chains start at entry 0 of the table. */
        i = 0;
    }

    do {
        /* Checked before desc[i] is read, so an empty table never is. */
        if (++num_bufs > max) {
            vu_panic(dev, "Looped descriptor");
            return;
        }

        if (le16toh(desc[i].flags) & VRING_DESC_F_WRITE) {
            min = MIN(le32toh(desc[i].len), len);
            vu_log_write(dev, le64toh(desc[i].addr), min);
            len -= min;
        }

    } while (len > 0 &&
             (virtqueue_read_next_desc(dev, desc, i, max, &i)
              == VIRTQUEUE_READ_DESC_MORE));
}
/*
 * Write the used-ring entry for @elem (id = elem->index, @len bytes written)
 * at position @idx past vq->used_idx, modulo the ring size.  The written
 * guest buffers are reported via vu_log_queue_fill() first.  vq->used_idx
 * itself is not advanced here; see vu_queue_flush().  Does nothing if the
 * queue is not usable.
 */
void
vu_queue_fill(VuDev *dev, VuVirtq *vq,
              const VuVirtqElement *elem,
              unsigned int len, unsigned int idx)
{
    struct vring_used_elem entry;
    unsigned int slot;

    if (!vu_is_vq_usable(dev, vq)) {
        return;
    }

    vu_log_queue_fill(dev, vq, elem, len);

    slot = (idx + vq->used_idx) % vq->vring.num;

    entry.id = htole32(elem->index);
    entry.len = htole32(len);
    vring_used_write(dev, vq, &entry, slot);
}
/*
 * Publish @val as the used ring's idx (stored little-endian), report that
 * field to vu_log_write(), and mirror it in vq->used_idx.
 */
static inline
void vring_used_idx_set(VuDev *dev, VuVirtq *vq, uint16_t val)
{
    struct vring_used *used = vq->vring.used;
    uint64_t idx_gpa = vq->vring.log_guest_addr +
                       offsetof(struct vring_used, idx);

    used->idx = htole16(val);
    vu_log_write(dev, idx_gpa, sizeof(used->idx));

    vq->used_idx = val;
}
/*
 * Make @count entries previously written with vu_queue_fill() visible by
 * advancing the used index, and release their in-use slots.  Does nothing
 * if the queue is not usable.
 */
void
vu_queue_flush(VuDev *dev, VuVirtq *vq, unsigned int count)
{
    uint16_t prev_idx, next_idx;

    if (!vu_is_vq_usable(dev, vq)) {
        return;
    }

    /* Make sure buffer is written before we update index. */
    smp_wmb();

    prev_idx = vq->used_idx;
    next_idx = prev_idx + count;
    vring_used_idx_set(dev, vq, next_idx);
    vq->inuse -= count;

    /*
     * Wrap-aware 16-bit window check: if signalled_used lies in
     * (prev_idx, next_idx] or appears ahead of next_idx, clear
     * signalled_used_valid.
     */
    if (unlikely((int16_t)(next_idx - vq->signalled_used) <
                 (uint16_t)(next_idx - prev_idx))) {
        vq->signalled_used_valid = false;
    }
}
/*
 * Complete a single element: fill its used-ring entry, flush it, and update
 * inflight tracking around the flush (batch head before, flag clear after).
 */
void
vu_queue_push(VuDev *dev, VuVirtq *vq,
              const VuVirtqElement *elem, unsigned int len)
{
    unsigned int head = elem->index;

    vu_queue_fill(dev, vq, elem, len, 0);

    vu_queue_inflight_pre_put(dev, vq, head);
    vu_queue_flush(dev, vq, 1);
    vu_queue_inflight_post_put(dev, vq, head);
}