/** * xrdp: A Remote Desktop Protocol server. * * Copyright (C) Jay Sorg 2004-2014 * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * * generic transport */ #if defined(HAVE_CONFIG_H) #include #endif #include "os_calls.h" #include "string_calls.h" #include "trans.h" #include "arch.h" #include "parse.h" #include "ssl_calls.h" #include "log.h" #define MAX_SBYTES 0 /** Time between polls of is_term when connecting */ #define CONNECT_TERM_POLL_MS 3000 /** Time we wait before another connect() attempt if one fails immediately */ #define CONNECT_DELAY_ON_FAIL_MS 2000 /*****************************************************************************/ int trans_tls_recv(struct trans *self, char *ptr, int len) { if (self->tls == NULL) { return 1; } return ssl_tls_read(self->tls, ptr, len); } /*****************************************************************************/ int trans_tls_send(struct trans *self, const char *data, int len) { if (self->tls == NULL) { return 1; } return ssl_tls_write(self->tls, data, len); } /*****************************************************************************/ int trans_tls_can_recv(struct trans *self, int sck, int millis) { if (self->tls == NULL) { return 1; } return ssl_tls_can_recv(self->tls, sck, millis); } /*****************************************************************************/ int trans_tcp_recv(struct trans *self, char *ptr, int len) { return g_tcp_recv(self->sck, ptr, len, 0); } /*****************************************************************************/ int trans_tcp_send(struct trans *self, const char *data, int len) { return g_tcp_send(self->sck, data, len, 0); } /*****************************************************************************/ int trans_tcp_can_recv(struct trans *self, int sck, int millis) { return g_sck_can_recv(sck, millis); } /*****************************************************************************/ struct trans * trans_create(int mode, int in_size, int out_size) { struct trans *self = (struct trans *) NULL; self = (struct trans *) g_malloc(sizeof(struct trans), 1); if (self != NULL) { self->sck = -1; make_stream(self->in_s); init_stream(self->in_s, in_size); make_stream(self->out_s); init_stream(self->out_s, out_size); self->mode = mode; self->tls = 0; /* assign tcp calls by default */ self->trans_recv = trans_tcp_recv; self->trans_send = trans_tcp_send; self->trans_can_recv = trans_tcp_can_recv; } return self; } /*****************************************************************************/ void trans_delete(struct trans *self) { if (self == 0) { return; } /* Call the user-specified destructor if one exists */ if (self->extra_destructor != NULL) { self->extra_destructor(self); } free_stream(self->in_s); free_stream(self->out_s); if (self->sck >= 0) { g_tcp_close(self->sck); } self->sck = -1; if (self->listen_filename != 0) { g_file_delete(self->listen_filename); g_free(self->listen_filename); } if (self->tls != 0) { ssl_tls_delete(self->tls); } g_free(self); } /*****************************************************************************/ void trans_delete_from_child(struct trans *self) { if (self == 0) { return; } if (self->listen_filename != 0) { g_free(self->listen_filename); self->listen_filename = 0; } trans_delete(self); } /*****************************************************************************/ int trans_get_wait_objs(struct trans *self, tbus *objs, int *count) { if (self == 0) { return 1; } if (self->status != TRANS_STATUS_UP) { return 1; } objs[*count] = self->sck; (*count)++; if (self->tls != NULL && (objs[*count] = ssl_get_rwo(self->tls)) != 0) { (*count)++; } return 0; } /*****************************************************************************/ int trans_get_wait_objs_rw(struct trans *self, tbus *robjs, int *rcount, tbus *wobjs, int *wcount, int *timeout) { if (self == 0) { return 1; } if (self->status != TRANS_STATUS_UP) { return 1; } if ((self->si != 0) && (self->si->source[self->my_source] > MAX_SBYTES)) { } else { if (trans_get_wait_objs(self, robjs, rcount) != 0) { return 1; } } if (self->wait_s != 0) { wobjs[*wcount] = self->sck; (*wcount)++; } return 0; } /*****************************************************************************/ int trans_send_waiting(struct trans *self, int block) { struct stream *temp_s; int bytes; int sent; int timeout; int cont; timeout = block ? 100 : 0; cont = 1; while (cont) { if (self->wait_s != 0) { temp_s = self->wait_s; if (g_tcp_can_send(self->sck, timeout)) { bytes = (int) (temp_s->end - temp_s->p); sent = self->trans_send(self, temp_s->p, bytes); if (sent > 0) { temp_s->p += sent; if (temp_s->source != 0) { temp_s->source[0] -= sent; } if (temp_s->p >= temp_s->end) { self->wait_s = temp_s->next; free_stream(temp_s); } } else if (sent == 0) { return 1; } else { if (!g_tcp_last_error_would_block(self->sck)) { return 1; } } } else if (block) { /* check for term here */ if (self->is_term != 0) { if (self->is_term()) { /* term */ return 1; } } } } else { break; } cont = block; } return 0; } /*****************************************************************************/ int trans_check_wait_objs(struct trans *self) { tbus in_sck = (tbus) 0; struct trans *in_trans = (struct trans *) NULL; int read_bytes = 0; int to_read = 0; int read_so_far = 0; int rv = 0; enum xrdp_source cur_source; if (self == 0) { return 1; } if (self->status != TRANS_STATUS_UP) { return 1; } rv = 0; if (self->type1 == TRANS_TYPE_LISTENER) /* listening */ { if (g_sck_can_recv(self->sck, 0)) { in_sck = g_sck_accept(self->sck); if (in_sck == -1) { if (g_tcp_last_error_would_block(self->sck)) { /* ok, but shouldn't happen */ } else { /* error */ self->status = TRANS_STATUS_DOWN; return 1; } } if (in_sck != -1) { if (self->trans_conn_in != 0) /* is function assigned */ { in_trans = trans_create(self->mode, self->in_s->size, self->out_s->size); in_trans->sck = in_sck; in_trans->type1 = TRANS_TYPE_SERVER; in_trans->status = TRANS_STATUS_UP; in_trans->is_term = self->is_term; g_sck_set_non_blocking(in_sck); if (self->trans_conn_in(self, in_trans) != 0) { trans_delete(in_trans); } } else { g_tcp_close(in_sck); } } } } else /* connected server or client (2 or 3) */ { if (self->si != 0 && self->si->source[self->my_source] > MAX_SBYTES) { } else if (self->trans_can_recv(self, self->sck, 0)) { cur_source = XRDP_SOURCE_NONE; if (self->si != 0) { cur_source = self->si->cur_source; self->si->cur_source = self->my_source; } read_so_far = (int) (self->in_s->end - self->in_s->data); to_read = self->header_size - read_so_far; if (to_read > 0) { read_bytes = self->trans_recv(self, self->in_s->end, to_read); if (read_bytes == -1) { if (g_tcp_last_error_would_block(self->sck)) { /* ok, but shouldn't happen */ } else { /* error */ self->status = TRANS_STATUS_DOWN; if (self->si != 0) { self->si->cur_source = cur_source; } return 1; } } else if (read_bytes == 0) { /* error */ self->status = TRANS_STATUS_DOWN; if (self->si != 0) { self->si->cur_source = cur_source; } return 1; } else { self->in_s->end += read_bytes; } } read_so_far = (int) (self->in_s->end - self->in_s->data); if (read_so_far == self->header_size) { if (self->trans_data_in != 0) { rv = self->trans_data_in(self); if (self->no_stream_init_on_data_in == 0) { init_stream(self->in_s, 0); } } } if (self->si != 0) { self->si->cur_source = cur_source; } } if (trans_send_waiting(self, 0) != 0) { /* error */ self->status = TRANS_STATUS_DOWN; return 1; } } return rv; } /*****************************************************************************/ int trans_force_read_s(struct trans *self, struct stream *in_s, int size) { int rcvd; if (self->status != TRANS_STATUS_UP || size < 0 || !s_check_rem_out(in_s, size)) { return 1; } while (size > 0) { rcvd = self->trans_recv(self, in_s->end, size); if (rcvd == -1) { if (g_tcp_last_error_would_block(self->sck)) { if (!self->trans_can_recv(self, self->sck, 100)) { /* check for term here */ if (self->is_term != 0) { if (self->is_term()) { /* term */ self->status = TRANS_STATUS_DOWN; return 1; } } } } else { /* error */ self->status = TRANS_STATUS_DOWN; return 1; } } else if (rcvd == 0) { /* error */ self->status = TRANS_STATUS_DOWN; return 1; } else { in_s->end += rcvd; size -= rcvd; } } return 0; } /*****************************************************************************/ int trans_force_read(struct trans *self, int size) { return trans_force_read_s(self, self->in_s, size); } /*****************************************************************************/ int trans_force_write_s(struct trans *self, struct stream *out_s) { int size; int total; int sent; if (self->status != TRANS_STATUS_UP) { return 1; } size = (int) (out_s->end - out_s->data); total = 0; if (trans_send_waiting(self, 1) != 0) { self->status = TRANS_STATUS_DOWN; return 1; } while (total < size) { sent = self->trans_send(self, out_s->data + total, size - total); if (sent == -1) { if (g_tcp_last_error_would_block(self->sck)) { if (!g_tcp_can_send(self->sck, 100)) { /* check for term here */ if (self->is_term != 0) { if (self->is_term()) { /* term */ self->status = TRANS_STATUS_DOWN; return 1; } } } } else { /* error */ self->status = TRANS_STATUS_DOWN; return 1; } } else if (sent == 0) { /* error */ self->status = TRANS_STATUS_DOWN; return 1; } else { total = total + sent; } } return 0; } /*****************************************************************************/ int trans_force_write(struct trans *self) { return trans_force_write_s(self, self->out_s); } /*****************************************************************************/ int trans_write_copy_s(struct trans *self, struct stream *out_s) { int size; int sent; struct stream *wait_s; struct stream *temp_s; char *out_data; if (self->status != TRANS_STATUS_UP) { return 1; } /* try to send any left over */ if (trans_send_waiting(self, 0) != 0) { /* error */ self->status = TRANS_STATUS_DOWN; return 1; } out_data = out_s->data; sent = 0; size = (int) (out_s->end - out_s->data); if (self->wait_s == 0) { /* if no left over, try to send this new data */ if (g_tcp_can_send(self->sck, 0)) { sent = self->trans_send(self, out_s->data, size); if (sent > 0) { out_data += sent; size -= sent; } else if (sent == 0) { return 1; } else { if (!g_tcp_last_error_would_block(self->sck)) { return 1; } } } } if (size < 1) { return 0; } /* did not send right away, have to copy */ make_stream(wait_s); init_stream(wait_s, size); if (self->si != 0) { if ((self->si->cur_source != XRDP_SOURCE_NONE) && (self->si->cur_source != self->my_source)) { self->si->source[self->si->cur_source] += size; wait_s->source = self->si->source + self->si->cur_source; } } out_uint8a(wait_s, out_data, size); s_mark_end(wait_s); wait_s->p = wait_s->data; if (self->wait_s == 0) { self->wait_s = wait_s; } else { temp_s = self->wait_s; while (temp_s->next != 0) { temp_s = temp_s->next; } temp_s->next = wait_s; } return 0; } /*****************************************************************************/ int trans_write_copy(struct trans *self) { return trans_write_copy_s(self, self->out_s); } /*****************************************************************************/ /* Shim to apply the function signature of g_tcp_connect() * to g_tcp_local_connect() */ static int local_connect_shim(int fd, const char *server, const char *port) { return g_tcp_local_connect(fd, port); } /**************************************************************************//** * Waits for an asynchronous connect to complete. * @param self - Transport object * @param start_time Start time of connect (from g_time3()) * @param timeout Total wait timeout * @return 0 - connect succeeded, 1 - Connect failed * * If the transport is set up for checking a termination object, this * on a regular basis. */ static int poll_for_async_connect(struct trans *self, int start_time, int timeout) { int rv = 1; int ms_remaining = timeout - (g_time3() - start_time); while (ms_remaining > 0) { int poll_time = ms_remaining; /* Lower bound for waititng for a result */ if (poll_time < 100) { poll_time = 100; } /* Limit the wait time if we need to poll for termination */ if (self->is_term != NULL && poll_time > CONNECT_TERM_POLL_MS) { poll_time = CONNECT_TERM_POLL_MS; } if (g_tcp_can_send(self->sck, poll_time)) { /* Connect has finished - return the socket status */ rv = g_sck_socket_ok(self->sck) ? 0 : 1; break; } /* Check for program termination */ if (self->is_term != NULL && self->is_term()) { break; } ms_remaining = timeout - (g_time3() - start_time); } return rv; } /*****************************************************************************/ int trans_connect(struct trans *self, const char *server, const char *port, int timeout) { int start_time = g_time3(); int error; int ms_before_next_connect; /* * Function pointers which we use in the main loop to avoid * having to switch on the socket mode */ int (*f_alloc_socket)(void); int (*f_connect)(int fd, const char *server, const char *port); switch (self->mode) { case TRANS_MODE_TCP: f_alloc_socket = g_tcp_socket; f_connect = g_tcp_connect; break; case TRANS_MODE_UNIX: f_alloc_socket = g_tcp_local_socket; f_connect = local_connect_shim; break; default: LOG(LOG_LEVEL_ERROR, "Bad socket mode %d", self->mode); return 1; } while (1) { /* Check the program isn't terminating */ if (self->is_term != NULL && self->is_term()) { error = 1; break; } /* Allocate a new socket */ if (self->sck >= 0) { g_tcp_close(self->sck); } self->sck = f_alloc_socket(); if (self->sck < 0) { error = 1; break; } /* Try to connect asynchronously */ g_tcp_set_non_blocking(self->sck); error = f_connect(self->sck, server, port); if (error == 0) { /* Connect was immediately successful */ break; } if (g_tcp_last_error_would_block(self->sck)) { /* Async connect is in progress */ if (poll_for_async_connect(self, start_time, timeout) == 0) { /* Async connect was successful */ error = 0; break; } /* No need to wait any more before the next connect attempt */ ms_before_next_connect = 0; } else { /* Connect failed immediately. Wait a bit before the next * attempt */ ms_before_next_connect = CONNECT_DELAY_ON_FAIL_MS; } /* Have we reached the total timeout yet? */ int ms_left = timeout - (g_time3() - start_time); if (ms_left <= 0) { error = 1; break; } /* Sleep a bit before trying again */ if (ms_before_next_connect > 0) { if (ms_before_next_connect > ms_left) { ms_before_next_connect = ms_left; } g_sleep(ms_before_next_connect); } } if (error != 0) { if (self->sck >= 0) { g_tcp_close(self->sck); self->sck = -1; } self->status = TRANS_STATUS_DOWN; } else { self->status = TRANS_STATUS_UP; /* ok */ self->type1 = TRANS_TYPE_CLIENT; /* client */ } return error; } /*****************************************************************************/ /** * @return 0 on success, 1 on failure */ int trans_listen_address(struct trans *self, const char *port, const char *address) { if (self->sck >= 0) { g_tcp_close(self->sck); } if (self->mode == TRANS_MODE_TCP) /* tcp */ { self->sck = g_tcp_socket(); if (self->sck < 0) { return 1; } g_tcp_set_non_blocking(self->sck); if (g_tcp_bind_address(self->sck, port, address) == 0) { if (g_tcp_listen(self->sck) == 0) { self->status = TRANS_STATUS_UP; /* ok */ self->type1 = TRANS_TYPE_LISTENER; /* listener */ return 0; } } } else if (self->mode == TRANS_MODE_UNIX) /* unix socket */ { g_free(self->listen_filename); self->listen_filename = 0; g_file_delete(port); self->sck = g_tcp_local_socket(); if (self->sck < 0) { return 1; } g_tcp_set_non_blocking(self->sck); if (g_tcp_local_bind(self->sck, port) == 0) { self->listen_filename = g_strdup(port); if (g_tcp_listen(self->sck) == 0) { g_chmod_hex(port, 0x0660); self->status = TRANS_STATUS_UP; /* ok */ self->type1 = TRANS_TYPE_LISTENER; /* listener */ return 0; } } } else if (self->mode == TRANS_MODE_VSOCK) /* vsock socket */ { self->sck = g_sck_vsock_socket(); if (self->sck < 0) { return 1; } g_tcp_set_non_blocking(self->sck); if (g_sck_vsock_bind_address(self->sck, port, address) == 0) { if (g_tcp_listen(self->sck) == 0) { self->status = TRANS_STATUS_UP; /* ok */ self->type1 = TRANS_TYPE_LISTENER; /* listener */ return 0; } } } else if (self->mode == TRANS_MODE_TCP4) /* tcp4 */ { self->sck = g_tcp4_socket(); if (self->sck < 0) { return 1; } g_tcp_set_non_blocking(self->sck); if (g_tcp4_bind_address(self->sck, port, address) == 0) { if (g_tcp_listen(self->sck) == 0) { self->status = TRANS_STATUS_UP; /* ok */ self->type1 = TRANS_TYPE_LISTENER; /* listener */ return 0; } } } else if (self->mode == TRANS_MODE_TCP6) /* tcp6 */ { self->sck = g_tcp6_socket(); if (self->sck < 0) { return 1; } g_tcp_set_non_blocking(self->sck); if (g_tcp6_bind_address(self->sck, port, address) == 0) { if (g_tcp_listen(self->sck) == 0) { self->status = TRANS_STATUS_UP; /* ok */ self->type1 = TRANS_TYPE_LISTENER; /* listener */ return 0; } } } return 1; } /*****************************************************************************/ int trans_listen(struct trans *self, const char *port) { return trans_listen_address(self, port, "0.0.0.0"); } /*****************************************************************************/ struct stream * trans_get_in_s(struct trans *self) { struct stream *rv = (struct stream *) NULL; if (self == NULL) { rv = (struct stream *) NULL; } else { rv = self->in_s; } return rv; } /*****************************************************************************/ struct stream * trans_get_out_s(struct trans *self, int size) { struct stream *rv = (struct stream *) NULL; if (self == NULL) { rv = (struct stream *) NULL; } else { init_stream(self->out_s, size); rv = self->out_s; } return rv; } /*****************************************************************************/ /* returns error */ int trans_set_tls_mode(struct trans *self, const char *key, const char *cert, long ssl_protocols, const char *tls_ciphers) { self->tls = ssl_tls_create(self, key, cert); if (self->tls == NULL) { LOG(LOG_LEVEL_ERROR, "trans_set_tls_mode: ssl_tls_create malloc error"); return 1; } if (ssl_tls_accept(self->tls, ssl_protocols, tls_ciphers) != 0) { LOG(LOG_LEVEL_ERROR, "trans_set_tls_mode: ssl_tls_accept failed"); return 1; } /* assign tls functions */ self->trans_recv = trans_tls_recv; self->trans_send = trans_tls_send; self->trans_can_recv = trans_tls_can_recv; self->ssl_protocol = ssl_get_version(self->tls); self->cipher_name = ssl_get_cipher_name(self->tls); return 0; } /*****************************************************************************/ /* returns error */ int trans_shutdown_tls_mode(struct trans *self) { if (self->tls != NULL) { return ssl_tls_disconnect(self->tls); } /* assign callback back to tcp cal */ self->trans_recv = trans_tcp_recv; self->trans_send = trans_tcp_send; self->trans_can_recv = trans_tcp_can_recv; return 0; }