diff --git a/libfreerdp-core/activation.c b/libfreerdp-core/activation.c index 95cf9d9da..5121dff7f 100644 --- a/libfreerdp-core/activation.c +++ b/libfreerdp-core/activation.c @@ -170,7 +170,6 @@ void rdp_send_client_font_list_pdu(rdpRdp* rdp, uint16 flags) void rdp_recv_server_font_map_pdu(rdpRdp* rdp, STREAM* s, rdpSettings* settings) { rdp->activated = True; - rdp->transport->tcp->set_blocking_mode(rdp->transport->tcp, False); } void rdp_recv_deactivate_all(rdpRdp* rdp, STREAM* s) @@ -184,6 +183,5 @@ void rdp_recv_deactivate_all(rdpRdp* rdp, STREAM* s) stream_seek(s, lengthSourceDescriptor); /* sourceDescriptor (should be 0x00) */ rdp->activated = False; - rdp->transport->tcp->set_blocking_mode(rdp->transport->tcp, True); } diff --git a/libfreerdp-core/connection.c b/libfreerdp-core/connection.c index 5b027af9e..e8934077f 100644 --- a/libfreerdp-core/connection.c +++ b/libfreerdp-core/connection.c @@ -95,6 +95,7 @@ boolean rdp_client_connect(rdpRdp* rdp) rdp->licensed = True; rdp_client_activate(rdp); + rdp_set_blocking_mode(rdp, False); return True; } diff --git a/libfreerdp-core/freerdp.c b/libfreerdp-core/freerdp.c index ce75572dc..d46c9be7b 100644 --- a/libfreerdp-core/freerdp.c +++ b/libfreerdp-core/freerdp.c @@ -54,10 +54,13 @@ boolean freerdp_get_fds(freerdp* instance, void** rfds, int* rcount, void** wfds boolean freerdp_check_fds(freerdp* instance) { rdpRdp* rdp; + int status; rdp = (rdpRdp*) instance->rdp; - rdp_recv(rdp); + status = rdp_check_fds(rdp); + if (status < 0) + return False; return True; } diff --git a/libfreerdp-core/rdp.c b/libfreerdp-core/rdp.c index 3729ef9a0..4948f964e 100644 --- a/libfreerdp-core/rdp.c +++ b/libfreerdp-core/rdp.c @@ -344,13 +344,13 @@ void rdp_read_data_pdu(rdpRdp* rdp, STREAM* s) } /** - * Receive an RDP packet.\n + * Process an RDP packet.\n * @param rdp RDP module + * @param s stream */ -void rdp_recv(rdpRdp* rdp) +void rdp_process_pdu(rdpRdp* rdp, STREAM* s) { - STREAM* s; int length; uint16 pduType; uint16 pduLength; @@ -359,9 +359,6 @@ void rdp_recv(rdpRdp* rdp) uint16 sec_flags; enum DomainMCSPDU MCSPDU; - s = transport_recv_stream_init(rdp->transport, 4096); - transport_read(rdp->transport, s); - MCSPDU = DomainMCSPDU_SendDataIndication; mcs_read_domain_mcspdu_header(s, &MCSPDU, &length); @@ -421,6 +418,47 @@ void rdp_recv(rdpRdp* rdp) } } +/** + * Receive an RDP packet.\n + * @param rdp RDP module + */ + +void rdp_recv(rdpRdp* rdp) +{ + STREAM* s; + + s = transport_recv_stream_init(rdp->transport, 4096); + transport_read(rdp->transport, s); + + rdp_process_pdu(rdp, s); +} + +static int rdp_recv_callback(rdpTransport* transport, STREAM* s, void* extra) +{ + rdpRdp* rdp = (rdpRdp*) extra; + + rdp_process_pdu(rdp, s); + + return 1; +} + +/** + * Set non-blocking mode information. + * @param rdp RDP module + * @param blocking blocking mode + */ +void rdp_set_blocking_mode(rdpRdp* rdp, boolean blocking) +{ + rdp->transport->recv_callback = rdp_recv_callback; + rdp->transport->recv_extra = rdp; + transport_set_blocking_mode(rdp->transport, blocking); +} + +int rdp_check_fds(rdpRdp* rdp) +{ + return transport_check_fds(rdp->transport); +} + /** * Instantiate new RDP module. * @return new RDP module diff --git a/libfreerdp-core/rdp.h b/libfreerdp-core/rdp.h index 162c40ad7..52b0b38a7 100644 --- a/libfreerdp-core/rdp.h +++ b/libfreerdp-core/rdp.h @@ -236,6 +236,9 @@ void rdp_send_data_pdu(rdpRdp* rdp, STREAM* s, uint16 type, uint16 channel_id); void rdp_send(rdpRdp* rdp, STREAM* s); void rdp_recv(rdpRdp* rdp); +void rdp_set_blocking_mode(rdpRdp* rdp, boolean blocking); +int rdp_check_fds(rdpRdp* rdp); + rdpRdp* rdp_new(); void rdp_free(rdpRdp* rdp); diff --git a/libfreerdp-core/tcp.c b/libfreerdp-core/tcp.c index 7ac99ec54..090ca7f55 100644 --- a/libfreerdp-core/tcp.c +++ b/libfreerdp-core/tcp.c @@ -164,28 +164,18 @@ int tcp_read(rdpTcp* tcp, uint8* data, int length) int tcp_write(rdpTcp* tcp, uint8* data, int length) { int status; - int sent = 0; - while (sent < length) + status = send(tcp->sockfd, data, length, MSG_NOSIGNAL); + + if (status < 0) { - status = send(tcp->sockfd, data, (length - sent), MSG_NOSIGNAL); - - if (status < 0) - { - if (errno == EAGAIN || errno == EWOULDBLOCK) - continue; - - perror("send"); - return -1; - } + if (errno == EAGAIN || errno == EWOULDBLOCK) + status = 0; else - { - sent += status; - data += status; - } + perror("send"); } - return sent; + return status; } boolean tcp_disconnect(rdpTcp * tcp) diff --git a/libfreerdp-core/tls.c b/libfreerdp-core/tls.c index 241874c1d..9f6fb3f4d 100644 --- a/libfreerdp-core/tls.c +++ b/libfreerdp-core/tls.c @@ -77,58 +77,46 @@ int tls_read(rdpTls* tls, uint8* data, int length) { int status; - while (True) + status = SSL_read(tls->ssl, data, length); + + switch (SSL_get_error(tls->ssl, status)) { - status = SSL_read(tls->ssl, data, length); + case SSL_ERROR_NONE: + break; - switch (SSL_get_error(tls->ssl, status)) - { - case SSL_ERROR_NONE: - return status; - break; + case SSL_ERROR_WANT_READ: + status = 0; + break; - case SSL_ERROR_WANT_READ: - nanosleep(&tls->ts, NULL); - break; - - default: - //tls_print_error("SSL_read", tls->ssl, status); - return -1; - break; - } + default: + status = -1; + break; } - return 0; + return status; } int tls_write(rdpTls* tls, uint8* data, int length) { int status; - int sent = 0; - while (sent < length) + status = SSL_write(tls->ssl, data, length); + + switch (SSL_get_error(tls->ssl, status)) { - status = SSL_write(tls->ssl, data, length); + case SSL_ERROR_NONE: + break; - switch (SSL_get_error(tls->ssl, status)) - { - case SSL_ERROR_NONE: - sent += status; - data += status; - break; + case SSL_ERROR_WANT_WRITE: + status = 0; + break; - case SSL_ERROR_WANT_WRITE: - nanosleep(&tls->ts, NULL); - break; - - default: - tls_print_error("SSL_write", tls->ssl, status); - return -1; - break; - } + default: + tls_print_error("SSL_write", tls->ssl, status); + status = -1; } - return sent; + return status; } boolean tls_print_error(char *func, SSL *connection, int value) @@ -214,10 +202,6 @@ rdpTls* tls_new() */ SSL_CTX_set_options(tls->ctx, SSL_OP_ALL); - - /* a small 0.1ms delay when network blocking happens. */ - tls->ts.tv_sec = 0; - tls->ts.tv_nsec = 100000; } return tls; diff --git a/libfreerdp-core/transport.c b/libfreerdp-core/transport.c index 45cebd77c..6e329e240 100644 --- a/libfreerdp-core/transport.c +++ b/libfreerdp-core/transport.c @@ -127,10 +127,36 @@ int transport_read(rdpTransport* transport, STREAM* s) { int status = -1; - if (transport->layer == TRANSPORT_LAYER_TLS) - status = tls_read(transport->tls, s->data, s->size); - else if (transport->layer == TRANSPORT_LAYER_TCP) - status = tcp_read(transport->tcp, s->data, s->size); + while (True) + { + if (transport->layer == TRANSPORT_LAYER_TLS) + status = tls_read(transport->tls, s->data, s->size); + else if (transport->layer == TRANSPORT_LAYER_TCP) + status = tcp_read(transport->tcp, s->data, s->size); + + if (status == 0 && transport->blocking) + { + nanosleep(&transport->ts, NULL); + continue; + } + + break; + } + + return status; +} + +static int transport_read_nonblocking(rdpTransport* transport) +{ + int status; + + stream_check_size(transport->recv_buffer, 4096); + status = transport_read(transport, transport->recv_buffer); + + if (status <= 0) + return status; + + stream_seek(transport->recv_buffer, status); return status; } @@ -138,17 +164,100 @@ int transport_read(rdpTransport* transport, STREAM* s) int transport_write(rdpTransport* transport, STREAM* s) { int status = -1; + int length; + int sent = 0; - if (transport->layer == TRANSPORT_LAYER_TLS) - status = tls_write(transport->tls, s->data, stream_get_length(s)); - else if (transport->layer == TRANSPORT_LAYER_TCP) - status = tcp_write(transport->tcp, s->data, stream_get_length(s)); + length = stream_get_length(s); + stream_set_pos(s, 0); + while (sent < length) + { + if (transport->layer == TRANSPORT_LAYER_TLS) + status = tls_write(transport->tls, stream_get_tail(s), length); + else if (transport->layer == TRANSPORT_LAYER_TCP) + status = tcp_write(transport->tcp, stream_get_tail(s), length); + + if (status < 0) + break; /* error occurred */ + + if (status == 0) + { + /* blocking while sending */ + nanosleep(&transport->ts, NULL); + + /* when sending is blocked in nonblocking mode, the receiving buffer should be checked */ + if (!transport->blocking) + transport_read_nonblocking(transport); + } + + sent += status; + stream_seek(s, status); + } + + if (!transport->blocking) + transport_check_fds(transport); return status; } int transport_check_fds(rdpTransport* transport) { + int pos; + int status; + uint8 header; + uint16 length; + STREAM* received; + + status = transport_read_nonblocking(transport); + if (status <= 0) + return status; + + while ((pos = stream_get_pos(transport->recv_buffer)) > 0) + { + /* Ensure the TPKT or Fast Path header is available. */ + if (pos <= 4) + return 0; + + stream_set_pos(transport->recv_buffer, 0); + stream_peek_uint8(transport->recv_buffer, header); + if (header == 0x03) /* TPKT */ + length = tpkt_read_header(transport->recv_buffer); + else /* TODO: Fast Path */ + length = 0; + + if (length == 0) + { + printf("transport_check_fds: protocol error, not a TPKT header (%d).\n", header); + return -1; + } + + if (pos < length) + { + stream_set_pos(transport->recv_buffer, pos); + return 0; /* Packet is not yet completely received. */ + } + + /* + * A complete packet has been received. In case there are trailing data + * for the next packet, we copy it to the new receive buffer. + */ + received = transport->recv_buffer; + transport->recv_buffer = stream_new(BUFFER_SIZE); + + if (pos > length) + { + stream_set_pos(received, length); + stream_check_size(transport->recv_buffer, pos - length); + stream_copy(transport->recv_buffer, received, pos - length); + } + + stream_set_pos(received, 0); + status = transport->recv_callback(transport, received, transport->recv_extra); + stream_free(received); + + if (status < 0) + return status; + } + return 0; } @@ -158,6 +267,12 @@ void transport_init(rdpTransport* transport) transport->state = TRANSPORT_STATE_NEGO; } +boolean transport_set_blocking_mode(rdpTransport* transport, boolean blocking) +{ + transport->blocking = blocking; + return transport->tcp->set_blocking_mode(transport->tcp, blocking); +} + rdpTransport* transport_new(rdpSettings* settings) { rdpTransport* transport; @@ -179,6 +294,8 @@ rdpTransport* transport_new(rdpSettings* settings) /* buffers for blocking read/write */ transport->recv_stream = stream_new(BUFFER_SIZE); transport->send_stream = stream_new(BUFFER_SIZE); + + transport->blocking = True; } return transport; diff --git a/libfreerdp-core/transport.h b/libfreerdp-core/transport.h index 23c97cb7b..0f591abcd 100644 --- a/libfreerdp-core/transport.h +++ b/libfreerdp-core/transport.h @@ -63,6 +63,7 @@ struct rdp_transport void* recv_extra; STREAM* recv_buffer; TransportRecv recv_callback; + boolean blocking; }; STREAM* transport_recv_stream_init(rdpTransport* transport, int size); @@ -75,6 +76,7 @@ boolean transport_connect_nla(rdpTransport* transport); int transport_read(rdpTransport* transport, STREAM* s); int transport_write(rdpTransport* transport, STREAM* s); int transport_check_fds(rdpTransport* transport); +boolean transport_set_blocking_mode(rdpTransport* transport, boolean blocking); rdpTransport* transport_new(rdpSettings* settings); void transport_free(rdpTransport* transport);