diff --git a/hw/virtio/vhost-user.c b/hw/virtio/vhost-user.c index 9334a8ae22..32a95a8c69 100644 --- a/hw/virtio/vhost-user.c +++ b/hw/virtio/vhost-user.c @@ -163,22 +163,26 @@ fail: } static int process_message_reply(struct vhost_dev *dev, - VhostUserRequest request) + VhostUserMsg msg) { - VhostUserMsg msg; + VhostUserMsg msg_reply; - if (vhost_user_read(dev, &msg) < 0) { + if ((msg.flags & VHOST_USER_NEED_REPLY_MASK) == 0) { + return 0; + } + + if (vhost_user_read(dev, &msg_reply) < 0) { return -1; } - if (msg.request != request) { + if (msg_reply.request != msg.request) { error_report("Received unexpected msg type." "Expected %d received %d", - request, msg.request); + msg.request, msg_reply.request); return -1; } - return msg.payload.u64 ? -1 : 0; + return msg_reply.payload.u64 ? -1 : 0; } static bool vhost_user_one_time_request(VhostUserRequest request) @@ -208,6 +212,7 @@ static int vhost_user_write(struct vhost_dev *dev, VhostUserMsg *msg, * request, we just ignore it. */ if (vhost_user_one_time_request(msg->request) && dev->vq_index != 0) { + msg->flags &= ~VHOST_USER_NEED_REPLY_MASK; return 0; } @@ -320,7 +325,7 @@ static int vhost_user_set_mem_table(struct vhost_dev *dev, } if (reply_supported) { - return process_message_reply(dev, msg.request); + return process_message_reply(dev, msg); } return 0; @@ -712,7 +717,7 @@ static int vhost_user_net_set_mtu(struct vhost_dev *dev, uint16_t mtu) /* If reply_ack supported, slave has to ack specified MTU is valid */ if (reply_supported) { - return process_message_reply(dev, msg.request); + return process_message_reply(dev, msg); } return 0;