diff --git a/nbd/server.c b/nbd/server.c index 431e543b47..bfa8d47dad 100644 --- a/nbd/server.c +++ b/nbd/server.c @@ -2750,22 +2750,48 @@ static void nbd_client_receive_next_request(NBDClient *client) } } +static void nbd_handshake_timer_cb(void *opaque) +{ + QIOChannel *ioc = opaque; + + trace_nbd_handshake_timer_cb(); + qio_channel_shutdown(ioc, QIO_CHANNEL_SHUTDOWN_BOTH, NULL); +} + static coroutine_fn void nbd_co_client_start(void *opaque) { NBDClient *client = opaque; Error *local_err = NULL; + QEMUTimer *handshake_timer = NULL; qemu_co_mutex_init(&client->send_lock); - /* TODO - utilize client->handshake_max_secs */ + /* + * Create a timer to bound the time spent in negotiation. If the + * timer expires, it is likely nbd_negotiate will fail because the + * socket was shutdown. + */ + if (client->handshake_max_secs > 0) { + handshake_timer = aio_timer_new(qemu_get_aio_context(), + QEMU_CLOCK_REALTIME, + SCALE_NS, + nbd_handshake_timer_cb, + client->sioc); + timer_mod(handshake_timer, + qemu_clock_get_ns(QEMU_CLOCK_REALTIME) + + client->handshake_max_secs * NANOSECONDS_PER_SECOND); + } + if (nbd_negotiate(client, &local_err)) { if (local_err) { error_report_err(local_err); } + timer_free(handshake_timer); client_close(client, false); return; } + timer_free(handshake_timer); nbd_client_receive_next_request(client); } diff --git a/nbd/trace-events b/nbd/trace-events index b7032ca277..675f880fa1 100644 --- a/nbd/trace-events +++ b/nbd/trace-events @@ -73,6 +73,7 @@ nbd_co_receive_request_decode_type(uint64_t handle, uint16_t type, const char *n nbd_co_receive_request_payload_received(uint64_t handle, uint32_t len) "Payload received: handle = %" PRIu64 ", len = %" PRIu32 nbd_co_receive_align_compliance(const char *op, uint64_t from, uint32_t len, uint32_t align) "client sent non-compliant unaligned %s request: from=0x%" PRIx64 ", len=0x%" PRIx32 ", align=0x%" PRIx32 nbd_trip(void) "Reading request" +nbd_handshake_timer_cb(void) "client took too long to negotiate" # client-connection.c nbd_connect_thread_sleep(uint64_t timeout) "timeout %" PRIu64