xrdp/common/trans.c
matt335672 b1147f5faa CVE-2022-23479
Detect attempts to overflow input buffer

If application code hasn't properly sanitised the header_size
for a transport, it is possible for read requests to be issued
which overflow the input buffer. This change detects this
at a low level and bounces the read request.
2022-12-09 17:34:25 +00:00

1073 lines
27 KiB
C

/**
* xrdp: A Remote Desktop Protocol server.
*
* Copyright (C) Jay Sorg 2004-2014
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
* generic transport
*/
#if defined(HAVE_CONFIG_H)
#include <config_ac.h>
#endif
#include "os_calls.h"
#include "string_calls.h"
#include "trans.h"
#include "arch.h"
#include "parse.h"
#include "ssl_calls.h"
#include "log.h"
#define MAX_SBYTES 0
/** Time between polls of is_term when connecting */
#define CONNECT_TERM_POLL_MS 3000
/** Time we wait before another connect() attempt if one fails immediately */
#define CONNECT_DELAY_ON_FAIL_MS 2000
/*****************************************************************************/
/**
 * Receive callback installed by trans_set_tls_mode(): reads up to len
 * bytes of decrypted data into ptr from the transport's TLS session.
 * @return the result of ssl_tls_read(), or 1 when no TLS session is attached
 *         (NOTE(review): 1 is indistinguishable from "one byte read")
 */
int
trans_tls_recv(struct trans *self, char *ptr, int len)
{
    return (self->tls != NULL) ? ssl_tls_read(self->tls, ptr, len) : 1;
}
/*****************************************************************************/
/**
 * Send callback installed by trans_set_tls_mode(): encrypts and writes len
 * bytes from data through the transport's TLS session.
 * @return the result of ssl_tls_write(), or 1 when no TLS session is attached
 */
int
trans_tls_send(struct trans *self, const char *data, int len)
{
    return (self->tls != NULL) ? ssl_tls_write(self->tls, data, len) : 1;
}
/*****************************************************************************/
/**
 * Readability callback installed by trans_set_tls_mode(): waits up to
 * millis for the TLS session on sck to have data available.
 * @return the result of ssl_tls_can_recv(), or 1 when no TLS session exists
 */
int
trans_tls_can_recv(struct trans *self, int sck, int millis)
{
    return (self->tls != NULL) ? ssl_tls_can_recv(self->tls, sck, millis) : 1;
}
/*****************************************************************************/
/**
 * Default (plain TCP) receive callback: reads up to len bytes from the
 * transport's socket into ptr, with no recv flags.
 * @return whatever g_tcp_recv() returns
 */
int
trans_tcp_recv(struct trans *self, char *ptr, int len)
{
    const int flags = 0;

    return g_tcp_recv(self->sck, ptr, len, flags);
}
/*****************************************************************************/
/**
 * Default (plain TCP) send callback: writes up to len bytes from data to
 * the transport's socket, with no send flags.
 * @return whatever g_tcp_send() returns
 */
int
trans_tcp_send(struct trans *self, const char *data, int len)
{
    const int flags = 0;

    return g_tcp_send(self->sck, data, len, flags);
}
/*****************************************************************************/
/**
 * Default (plain TCP) readability callback: waits up to millis for sck
 * to become readable. The transport itself is not consulted.
 * @return whatever g_sck_can_recv() returns
 */
int
trans_tcp_can_recv(struct trans *self, int sck, int millis)
{
    (void) self; /* unused: the socket is passed explicitly */
    return g_sck_can_recv(sck, millis);
}
/*****************************************************************************/
/**
 * Allocates a zero-filled transport of the given mode, with an input stream
 * of in_size bytes and an output stream of out_size bytes. The socket is
 * left unset (-1) and the plain TCP callbacks are installed by default.
 * @return the new transport, or NULL if g_malloc() fails
 */
struct trans *
trans_create(int mode, int in_size, int out_size)
{
    struct trans *self;

    self = (struct trans *) g_malloc(sizeof(struct trans), 1);
    if (self == NULL)
    {
        return NULL;
    }

    self->mode = mode;
    self->sck = -1;
    self->tls = NULL;

    make_stream(self->in_s);
    init_stream(self->in_s, in_size);
    make_stream(self->out_s);
    init_stream(self->out_s, out_size);

    /* Plain TCP until trans_set_tls_mode() swaps in the TLS callbacks */
    self->trans_recv = trans_tcp_recv;
    self->trans_send = trans_tcp_send;
    self->trans_can_recv = trans_tcp_can_recv;

    return self;
}
/*****************************************************************************/
void
trans_delete(struct trans *self)
{
if (self == 0)
{
return;
}
/* Call the user-specified destructor if one exists */
if (self->extra_destructor != NULL)
{
self->extra_destructor(self);
}
free_stream(self->in_s);
free_stream(self->out_s);
if (self->sck >= 0)
{
g_tcp_close(self->sck);
}
self->sck = -1;
if (self->listen_filename != 0)
{
g_file_delete(self->listen_filename);
g_free(self->listen_filename);
}
if (self->tls != 0)
{
ssl_tls_delete(self->tls);
}
g_free(self);
}
/*****************************************************************************/
void
trans_delete_from_child(struct trans *self)
{
if (self == 0)
{
return;
}
if (self->listen_filename != 0)
{
g_free(self->listen_filename);
self->listen_filename = 0;
}
trans_delete(self);
}
/*****************************************************************************/
/**
 * Appends this transport's readable wait objects to objs: always the
 * socket, plus the TLS read/write object when a TLS session provides a
 * non-zero one. *count is advanced by the number of entries added.
 * @return 0 on success, 1 if self is NULL or the transport is not up
 */
int
trans_get_wait_objs(struct trans *self, tbus *objs, int *count)
{
    tbus rwo;

    if (self == NULL || self->status != TRANS_STATUS_UP)
    {
        return 1;
    }

    objs[*count] = self->sck;
    (*count)++;

    if (self->tls != NULL)
    {
        /* Fetch into a local first so we never write a slot of objs that
         * is not then counted - the caller's array may be exactly full. */
        rwo = ssl_get_rwo(self->tls);
        if (rwo != 0)
        {
            objs[*count] = rwo;
            (*count)++;
        }
    }

    return 0;
}
/*****************************************************************************/
/**
 * Adds this transport's wait objects for a select-style loop.
 * Read objects are added (via trans_get_wait_objs()) unless bytes read
 * from this transport are still queued for output on some other transport,
 * as tracked in si->source[my_source]. The socket is added as a write
 * object while this transport has queued output of its own (wait_s).
 * timeout is not used here.
 * @return 0 on success, 1 if self is NULL, not up, or adding reads failed
 */
int
trans_get_wait_objs_rw(struct trans *self, tbus *robjs, int *rcount,
                       tbus *wobjs, int *wcount, int *timeout)
{
    int backlogged;

    if (self == NULL || self->status != TRANS_STATUS_UP)
    {
        return 1;
    }

    backlogged = (self->si != 0) &&
                 (self->si->source[self->my_source] > MAX_SBYTES);

    if (!backlogged && trans_get_wait_objs(self, robjs, rcount) != 0)
    {
        return 1;
    }

    if (self->wait_s != NULL)
    {
        wobjs[*wcount] = self->sck;
        (*wcount)++;
    }

    return 0;
}
/*****************************************************************************/
int
trans_send_waiting(struct trans *self, int block)
{
struct stream *temp_s;
int bytes;
int sent;
int timeout;
int cont;
timeout = block ? 100 : 0;
cont = 1;
while (cont)
{
if (self->wait_s != 0)
{
temp_s = self->wait_s;
if (g_tcp_can_send(self->sck, timeout))
{
bytes = (int) (temp_s->end - temp_s->p);
sent = self->trans_send(self, temp_s->p, bytes);
if (sent > 0)
{
temp_s->p += sent;
if (temp_s->source != 0)
{
temp_s->source[0] -= sent;
}
if (temp_s->p >= temp_s->end)
{
self->wait_s = temp_s->next;
free_stream(temp_s);
}
}
else if (sent == 0)
{
return 1;
}
else
{
if (!g_tcp_last_error_would_block(self->sck))
{
return 1;
}
}
}
else if (block)
{
/* check for term here */
if (self->is_term != 0)
{
if (self->is_term())
{
/* term */
return 1;
}
}
}
}
else
{
break;
}
cont = block;
}
return 0;
}
/*****************************************************************************/
/**
 * Listener half of trans_check_wait_objs(): accepts at most one pending
 * connection and offers it to the trans_conn_in callback. The new transport
 * inherits the listener's mode, stream sizes and is_term hook, and is
 * deleted if the callback rejects it (non-zero return).
 * @return 0 normally, 1 if accept failed with a hard error (the listener
 *         is then marked TRANS_STATUS_DOWN)
 */
static int
trans_check_wait_objs_listener(struct trans *self)
{
    tbus in_sck;
    struct trans *in_trans;

    if (!g_sck_can_recv(self->sck, 0))
    {
        return 0;
    }

    in_sck = g_sck_accept(self->sck);
    if (in_sck == -1)
    {
        if (g_tcp_last_error_would_block(self->sck))
        {
            /* ok, but shouldn't happen */
            return 0;
        }
        /* error */
        self->status = TRANS_STATUS_DOWN;
        return 1;
    }

    if (self->trans_conn_in == 0)
    {
        /* nobody to hand the connection to - refuse it */
        g_tcp_close(in_sck);
        return 0;
    }

    in_trans = trans_create(self->mode, self->in_s->size, self->out_s->size);
    if (in_trans == NULL)
    {
        /* trans_create() returns NULL when allocation fails; drop this
         * connection rather than dereference NULL */
        LOG(LOG_LEVEL_ERROR,
            "trans_check_wait_objs: out of memory accepting connection");
        g_tcp_close(in_sck);
        return 0;
    }

    in_trans->sck = in_sck;
    in_trans->type1 = TRANS_TYPE_SERVER;
    in_trans->status = TRANS_STATUS_UP;
    in_trans->is_term = self->is_term;
    g_sck_set_non_blocking(in_sck);
    if (self->trans_conn_in(self, in_trans) != 0)
    {
        trans_delete(in_trans);
    }
    return 0;
}

/**
 * Connected (server or client) half of trans_check_wait_objs().
 * Unless output originating from this transport is still queued elsewhere,
 * reads towards header_size buffered bytes and calls trans_data_in once
 * exactly that many are present. While reading and dispatching,
 * si->cur_source is set to this transport's source and restored afterwards.
 * Finally, it tries to flush this transport's own queued output.
 * @return trans_data_in's result (0 if not called), or 1 on error (the
 *         transport is then marked TRANS_STATUS_DOWN)
 */
static int
trans_check_wait_objs_connected(struct trans *self)
{
    int rv = 0;
    int read_bytes;
    int failed = 0;
    int backlogged;
    unsigned int to_read;
    unsigned int read_so_far;
    enum xrdp_source cur_source;

    backlogged = (self->si != 0 &&
                  self->si->source[self->my_source] > MAX_SBYTES);

    if (!backlogged && self->trans_can_recv(self, self->sck, 0))
    {
        /* CVE-2022-23479 - check a malicious caller hasn't managed
         * to set the header_size to an unreasonable value */
        if (self->header_size > (unsigned int)self->in_s->size)
        {
            LOG(LOG_LEVEL_ERROR,
                "trans_check_wait_objs: Reading %u bytes beyond buffer",
                self->header_size - (unsigned int)self->in_s->size);
            self->status = TRANS_STATUS_DOWN;
            return 1;
        }

        /* If more than header_size bytes are already buffered, the
         * unsigned subtraction below would wrap to a huge read length and
         * overflow in_s. Treat it like the check above. */
        read_so_far = self->in_s->end - self->in_s->data;
        if (read_so_far > (unsigned int)self->header_size)
        {
            LOG(LOG_LEVEL_ERROR,
                "trans_check_wait_objs: %u bytes buffered exceeds "
                "header_size %u",
                read_so_far, (unsigned int)self->header_size);
            self->status = TRANS_STATUS_DOWN;
            return 1;
        }

        cur_source = XRDP_SOURCE_NONE;
        if (self->si != 0)
        {
            cur_source = self->si->cur_source;
            self->si->cur_source = self->my_source;
        }

        to_read = self->header_size - read_so_far;
        if (to_read > 0)
        {
            read_bytes = self->trans_recv(self, self->in_s->end, (int)to_read);
            if (read_bytes > 0)
            {
                self->in_s->end += read_bytes;
            }
            else if (read_bytes == 0)
            {
                /* peer closed the connection */
                failed = 1;
            }
            else if (!g_tcp_last_error_would_block(self->sck))
            {
                /* hard receive error */
                failed = 1;
            }
            /* else would block: ok, but shouldn't happen */
        }

        if (!failed)
        {
            read_so_far = self->in_s->end - self->in_s->data;
            if (read_so_far == (unsigned int)self->header_size &&
                    self->trans_data_in != 0)
            {
                rv = self->trans_data_in(self);
                if (self->no_stream_init_on_data_in == 0)
                {
                    init_stream(self->in_s, 0);
                }
            }
        }

        if (self->si != 0)
        {
            self->si->cur_source = cur_source;
        }

        if (failed)
        {
            self->status = TRANS_STATUS_DOWN;
            return 1;
        }
    }

    if (trans_send_waiting(self, 0) != 0)
    {
        /* error */
        self->status = TRANS_STATUS_DOWN;
        return 1;
    }

    return rv;
}

/**
 * Services a transport after its wait objects have fired. A listener
 * accepts a pending connection. A connected transport reads input,
 * dispatches complete headers to trans_data_in, and flushes queued output.
 * @return 1 if self is NULL, not up, or an error occurred; otherwise 0,
 *         or for connected transports the value returned by trans_data_in
 */
int
trans_check_wait_objs(struct trans *self)
{
    if (self == NULL || self->status != TRANS_STATUS_UP)
    {
        return 1;
    }

    if (self->type1 == TRANS_TYPE_LISTENER) /* listening */
    {
        return trans_check_wait_objs_listener(self);
    }

    /* connected server or client (2 or 3) */
    return trans_check_wait_objs_connected(self);
}
/*****************************************************************************/
/**
 * Blocks until exactly size more bytes have been appended to in_s (at
 * in_s->end). While the socket would block, it polls in 100 ms steps and
 * gives up if is_term() reports termination.
 * @return 0 on success; 1 if the transport is not up, size is negative or
 *         does not fit in in_s, the peer closed, a receive error occurred,
 *         or termination was requested (the last three mark the transport
 *         TRANS_STATUS_DOWN)
 */
int
trans_force_read_s(struct trans *self, struct stream *in_s, int size)
{
    int rcvd;

    if (self->status != TRANS_STATUS_UP ||
            size < 0 || !s_check_rem_out(in_s, size))
    {
        return 1;
    }

    /* Received bytes are written at in_s->end (see trans_recv() below), so
     * the free space must also be measured from there, not only from the
     * parse pointer. */
    if (size > in_s->size - (int) (in_s->end - in_s->data))
    {
        return 1;
    }

    while (size > 0)
    {
        rcvd = self->trans_recv(self, in_s->end, size);
        if (rcvd > 0)
        {
            in_s->end += rcvd;
            size -= rcvd;
        }
        else if (rcvd == 0)
        {
            /* peer closed the connection */
            self->status = TRANS_STATUS_DOWN;
            return 1;
        }
        else if (!g_tcp_last_error_would_block(self->sck))
        {
            /* hard receive error */
            self->status = TRANS_STATUS_DOWN;
            return 1;
        }
        else if (!self->trans_can_recv(self, self->sck, 100) &&
                 self->is_term != 0 && self->is_term())
        {
            /* term */
            self->status = TRANS_STATUS_DOWN;
            return 1;
        }
    }

    return 0;
}
/*****************************************************************************/
/**
 * Blocking read of exactly size bytes into the transport's own input
 * stream; see trans_force_read_s() for the full contract.
 */
int
trans_force_read(struct trans *self, int size)
{
    struct stream *target = self->in_s;

    return trans_force_read_s(self, target, size);
}
/*****************************************************************************/
/**
 * Blocking write of the whole of out_s (data..end). Any previously queued
 * output is flushed first, so the new data is sent after it. While the
 * socket would block, it polls in 100 ms steps and gives up if is_term()
 * reports termination.
 * @return 0 on success; 1 if the transport is not up or on any failure
 *         (failures mark the transport TRANS_STATUS_DOWN)
 */
int
trans_force_write_s(struct trans *self, struct stream *out_s)
{
    const char *cursor;
    int remaining;
    int sent;

    if (self->status != TRANS_STATUS_UP)
    {
        return 1;
    }

    cursor = out_s->data;
    remaining = (int) (out_s->end - out_s->data);

    if (trans_send_waiting(self, 1) != 0)
    {
        self->status = TRANS_STATUS_DOWN;
        return 1;
    }

    while (remaining > 0)
    {
        sent = self->trans_send(self, cursor, remaining);

        if (sent == 0)
        {
            /* error */
            self->status = TRANS_STATUS_DOWN;
            return 1;
        }

        if (sent != -1)
        {
            cursor += sent;
            remaining -= sent;
            continue;
        }

        if (!g_tcp_last_error_would_block(self->sck))
        {
            /* error */
            self->status = TRANS_STATUS_DOWN;
            return 1;
        }

        if (!g_tcp_can_send(self->sck, 100) &&
                self->is_term != 0 && self->is_term())
        {
            /* term */
            self->status = TRANS_STATUS_DOWN;
            return 1;
        }
    }

    return 0;
}
/*****************************************************************************/
/**
 * Blocking write of the transport's own output stream; see
 * trans_force_write_s() for the full contract.
 */
int
trans_force_write(struct trans *self)
{
    struct stream *source_stream = self->out_s;

    return trans_force_write_s(self, source_stream);
}
/*****************************************************************************/
/**
 * Non-blocking write of out_s (data..end). Previously queued output is
 * flushed first. If nothing is queued and the socket is writable, as much
 * of out_s as possible is sent immediately. Whatever is left is copied into
 * a new stream appended to self->wait_s, to be sent later by
 * trans_send_waiting(). When the copy is made on behalf of a different
 * source (si->cur_source), that source's byte counter is increased so its
 * reader can be throttled.
 * @return 0 on success (sent or queued, or nothing to send); 1 if the
 *         transport is not up or a send failed (failures mark the
 *         transport TRANS_STATUS_DOWN)
 */
int
trans_write_copy_s(struct trans *self, struct stream *out_s)
{
    int size;
    int sent;
    struct stream *wait_s;
    struct stream *temp_s;
    char *out_data;

    if (self->status != TRANS_STATUS_UP)
    {
        return 1;
    }

    /* try to send any left over */
    if (trans_send_waiting(self, 0) != 0)
    {
        /* error */
        self->status = TRANS_STATUS_DOWN;
        return 1;
    }

    out_data = out_s->data;
    size = (int) (out_s->end - out_s->data);
    if (size < 1)
    {
        /* Nothing to send. Without this, a 0-byte send returns 0 and is
         * mistaken for a closed connection. */
        return 0;
    }

    if (self->wait_s == 0 && g_tcp_can_send(self->sck, 0))
    {
        /* no left over, so try to send this new data straight away */
        sent = self->trans_send(self, out_data, size);
        if (sent > 0)
        {
            out_data += sent;
            size -= sent;
        }
        else if (sent == 0 || !g_tcp_last_error_would_block(self->sck))
        {
            /* peer closed or hard error: mark the transport down, as the
             * queued-send failure above and trans_force_write_s() do */
            self->status = TRANS_STATUS_DOWN;
            return 1;
        }
    }

    if (size < 1)
    {
        return 0;
    }

    /* did not send right away, have to copy */
    make_stream(wait_s);
    init_stream(wait_s, size);
    if (self->si != 0)
    {
        if ((self->si->cur_source != XRDP_SOURCE_NONE) &&
                (self->si->cur_source != self->my_source))
        {
            self->si->source[self->si->cur_source] += size;
            wait_s->source = self->si->source + self->si->cur_source;
        }
    }
    out_uint8a(wait_s, out_data, size);
    s_mark_end(wait_s);
    wait_s->p = wait_s->data;

    /* append to the tail of the pending-output list */
    if (self->wait_s == 0)
    {
        self->wait_s = wait_s;
    }
    else
    {
        temp_s = self->wait_s;
        while (temp_s->next != 0)
        {
            temp_s = temp_s->next;
        }
        temp_s->next = wait_s;
    }

    return 0;
}
/*****************************************************************************/
/**
 * Non-blocking write (queueing any remainder) of the transport's own
 * output stream; see trans_write_copy_s() for the full contract.
 */
int
trans_write_copy(struct trans *self)
{
    struct stream *source_stream = self->out_s;

    return trans_write_copy_s(self, source_stream);
}
/*****************************************************************************/
/* Adapter that gives g_tcp_local_connect() the same signature as
 * g_tcp_connect(), so trans_connect() can call either one through a single
 * function pointer. For a UNIX-domain socket the path travels in 'port';
 * 'server' is ignored.
 */
static int
local_connect_shim(int fd, const char *server, const char *port)
{
    (void) server; /* unused for local sockets */
    return g_tcp_local_connect(fd, port);
}
/**************************************************************************//**
 * Waits for an asynchronous connect on self->sck to complete.
 * @param self - Transport object
 * @param start_time Start time of connect (from g_time3())
 * @param timeout Total wait budget in ms, measured from start_time
 * @return 0 - connect succeeded, 1 - connect failed, timed out, or
 *         termination was requested
 *
 * Each wait lasts at least 100 ms. If the transport has an is_term hook,
 * each wait is capped at CONNECT_TERM_POLL_MS and is_term() is checked
 * between waits, so termination is noticed on a regular basis.
 */
static int
poll_for_async_connect(struct trans *self, int start_time, int timeout)
{
    int ms_remaining;

    for (ms_remaining = timeout - (g_time3() - start_time);
            ms_remaining > 0;
            ms_remaining = timeout - (g_time3() - start_time))
    {
        int poll_time = (ms_remaining < 100) ? 100 : ms_remaining;

        if (self->is_term != NULL && poll_time > CONNECT_TERM_POLL_MS)
        {
            poll_time = CONNECT_TERM_POLL_MS;
        }

        if (g_tcp_can_send(self->sck, poll_time))
        {
            /* writable => connect finished; report the socket's status */
            return g_sck_socket_ok(self->sck) ? 0 : 1;
        }

        if (self->is_term != NULL && self->is_term())
        {
            break;
        }
    }

    return 1;
}
/*****************************************************************************/
int
trans_connect(struct trans *self, const char *server, const char *port,
int timeout)
{
int start_time = g_time3();
int error;
int ms_before_next_connect;
/*
* Function pointers which we use in the main loop to avoid
* having to switch on the socket mode */
int (*f_alloc_socket)(void);
int (*f_connect)(int fd, const char *server, const char *port);
switch (self->mode)
{
case TRANS_MODE_TCP:
f_alloc_socket = g_tcp_socket;
f_connect = g_tcp_connect;
break;
case TRANS_MODE_UNIX:
f_alloc_socket = g_tcp_local_socket;
f_connect = local_connect_shim;
break;
default:
LOG(LOG_LEVEL_ERROR, "Bad socket mode %d", self->mode);
return 1;
}
while (1)
{
/* Check the program isn't terminating */
if (self->is_term != NULL && self->is_term())
{
error = 1;
break;
}
/* Allocate a new socket */
if (self->sck >= 0)
{
g_tcp_close(self->sck);
}
self->sck = f_alloc_socket();
if (self->sck < 0)
{
error = 1;
break;
}
/* Try to connect asynchronously */
g_tcp_set_non_blocking(self->sck);
error = f_connect(self->sck, server, port);
if (error == 0)
{
/* Connect was immediately successful */
break;
}
if (g_tcp_last_error_would_block(self->sck))
{
/* Async connect is in progress */
if (poll_for_async_connect(self, start_time, timeout) == 0)
{
/* Async connect was successful */
error = 0;
break;
}
/* No need to wait any more before the next connect attempt */
ms_before_next_connect = 0;
}
else
{
/* Connect failed immediately. Wait a bit before the next
* attempt */
ms_before_next_connect = CONNECT_DELAY_ON_FAIL_MS;
}
/* Have we reached the total timeout yet? */
int ms_left = timeout - (g_time3() - start_time);
if (ms_left <= 0)
{
error = 1;
break;
}
/* Sleep a bit before trying again */
if (ms_before_next_connect > 0)
{
if (ms_before_next_connect > ms_left)
{
ms_before_next_connect = ms_left;
}
g_sleep(ms_before_next_connect);
}
}
if (error != 0)
{
if (self->sck >= 0)
{
g_tcp_close(self->sck);
self->sck = -1;
}
self->status = TRANS_STATUS_DOWN;
}
else
{
self->status = TRANS_STATUS_UP; /* ok */
self->type1 = TRANS_TYPE_CLIENT; /* client */
}
return error;
}
/*****************************************************************************/
/**
 * Turns the transport into a non-blocking listener for its mode:
 * TCP / TCP4 / TCP6 / VSOCK bind 'port' on 'address'. UNIX mode binds a
 * socket file at path 'port', deleting any existing file first, records the
 * path in listen_filename and chmods it with g_chmod_hex(port, 0x0660).
 * Any socket the transport already held is closed first.
 * On success the transport is marked UP with type TRANS_TYPE_LISTENER.
 * @return 0 on success, 1 on failure
 */
int
trans_listen_address(struct trans *self, const char *port, const char *address)
{
    if (self->sck >= 0)
    {
        g_tcp_close(self->sck);
        /* Forget the closed descriptor. Otherwise an unsupported mode (or
         * an early return) leaves a stale fd that trans_delete() would
         * close a second time, possibly hitting an unrelated descriptor. */
        self->sck = -1;
    }

    if (self->mode == TRANS_MODE_TCP) /* tcp */
    {
        self->sck = g_tcp_socket();
        if (self->sck < 0)
        {
            return 1;
        }
        g_tcp_set_non_blocking(self->sck);
        if (g_tcp_bind_address(self->sck, port, address) == 0)
        {
            if (g_tcp_listen(self->sck) == 0)
            {
                self->status = TRANS_STATUS_UP; /* ok */
                self->type1 = TRANS_TYPE_LISTENER; /* listener */
                return 0;
            }
        }
    }
    else if (self->mode == TRANS_MODE_UNIX) /* unix socket */
    {
        g_free(self->listen_filename);
        self->listen_filename = 0;
        g_file_delete(port);
        self->sck = g_tcp_local_socket();
        if (self->sck < 0)
        {
            return 1;
        }
        g_tcp_set_non_blocking(self->sck);
        if (g_tcp_local_bind(self->sck, port) == 0)
        {
            self->listen_filename = g_strdup(port);
            if (g_tcp_listen(self->sck) == 0)
            {
                g_chmod_hex(port, 0x0660);
                self->status = TRANS_STATUS_UP; /* ok */
                self->type1 = TRANS_TYPE_LISTENER; /* listener */
                return 0;
            }
        }
    }
    else if (self->mode == TRANS_MODE_VSOCK) /* vsock socket */
    {
        self->sck = g_sck_vsock_socket();
        if (self->sck < 0)
        {
            return 1;
        }
        g_tcp_set_non_blocking(self->sck);
        if (g_sck_vsock_bind_address(self->sck, port, address) == 0)
        {
            if (g_tcp_listen(self->sck) == 0)
            {
                self->status = TRANS_STATUS_UP; /* ok */
                self->type1 = TRANS_TYPE_LISTENER; /* listener */
                return 0;
            }
        }
    }
    else if (self->mode == TRANS_MODE_TCP4) /* tcp4 */
    {
        self->sck = g_tcp4_socket();
        if (self->sck < 0)
        {
            return 1;
        }
        g_tcp_set_non_blocking(self->sck);
        if (g_tcp4_bind_address(self->sck, port, address) == 0)
        {
            if (g_tcp_listen(self->sck) == 0)
            {
                self->status = TRANS_STATUS_UP; /* ok */
                self->type1 = TRANS_TYPE_LISTENER; /* listener */
                return 0;
            }
        }
    }
    else if (self->mode == TRANS_MODE_TCP6) /* tcp6 */
    {
        self->sck = g_tcp6_socket();
        if (self->sck < 0)
        {
            return 1;
        }
        g_tcp_set_non_blocking(self->sck);
        if (g_tcp6_bind_address(self->sck, port, address) == 0)
        {
            if (g_tcp_listen(self->sck) == 0)
            {
                self->status = TRANS_STATUS_UP; /* ok */
                self->type1 = TRANS_TYPE_LISTENER; /* listener */
                return 0;
            }
        }
    }

    return 1;
}
/*****************************************************************************/
/**
 * Listens on 'port' using the wildcard address string "0.0.0.0"; see
 * trans_listen_address() for details.
 * @return 0 on success, 1 on failure
 */
int
trans_listen(struct trans *self, const char *port)
{
    static const char any_address[] = "0.0.0.0";

    return trans_listen_address(self, port, any_address);
}
/*****************************************************************************/
/**
 * Accessor for the transport's input stream.
 * @return self->in_s, or NULL when self is NULL
 */
struct stream *
trans_get_in_s(struct trans *self)
{
    return (self != NULL) ? self->in_s : (struct stream *) NULL;
}
/*****************************************************************************/
/**
 * Resets the transport's output stream (growing its buffer to at least
 * size bytes if needed) and returns it ready for writing.
 * @return self->out_s, or NULL when self is NULL
 */
struct stream *
trans_get_out_s(struct trans *self, int size)
{
    if (self == NULL)
    {
        return NULL;
    }

    init_stream(self->out_s, size);
    return self->out_s;
}
/*****************************************************************************/
/**
 * Upgrades the transport to TLS as the server side: creates a TLS session
 * from key/cert, performs the accept handshake with the given protocol and
 * cipher settings, then routes recv/send/can_recv through the TLS callbacks
 * and records the negotiated protocol version and cipher name.
 * On failure the callbacks are left unchanged.
 * @return 0 on success, 1 on error (returns error)
 */
int
trans_set_tls_mode(struct trans *self, const char *key, const char *cert,
                   long ssl_protocols, const char *tls_ciphers)
{
    self->tls = ssl_tls_create(self, key, cert);
    if (self->tls == NULL)
    {
        LOG(LOG_LEVEL_ERROR, "trans_set_tls_mode: ssl_tls_create malloc error");
        return 1;
    }

    if (ssl_tls_accept(self->tls, ssl_protocols, tls_ciphers) != 0)
    {
        LOG(LOG_LEVEL_ERROR, "trans_set_tls_mode: ssl_tls_accept failed");
        return 1;
    }

    /* From here on all I/O goes through the TLS session */
    self->trans_recv = trans_tls_recv;
    self->trans_send = trans_tls_send;
    self->trans_can_recv = trans_tls_can_recv;

    self->ssl_protocol = ssl_get_version(self->tls);
    self->cipher_name = ssl_get_cipher_name(self->tls);

    return 0;
}
/*****************************************************************************/
/**
 * Ends TLS on the transport.
 * With a TLS session, returns the result of ssl_tls_disconnect() and
 * changes nothing else. Without one, (re)installs the plain TCP callbacks
 * and returns 0.
 * NOTE(review): the TCP callbacks are only restored when no TLS session
 * exists. After a TLS disconnect, self->tls stays set and the TLS callbacks
 * remain installed. Confirm this is intended.
 * @return 0 on success, non-zero on error (returns error)
 */
int
trans_shutdown_tls_mode(struct trans *self)
{
    if (self->tls == NULL)
    {
        /* point the I/O callbacks back at plain TCP */
        self->trans_recv = trans_tcp_recv;
        self->trans_send = trans_tcp_send;
        self->trans_can_recv = trans_tcp_can_recv;
        return 0;
    }

    return ssl_tls_disconnect(self->tls);
}