diff --git a/common/string_calls.c b/common/string_calls.c index 151a5dc4..d1d7500a 100644 --- a/common/string_calls.c +++ b/common/string_calls.c @@ -891,3 +891,69 @@ g_strnjoin(char *dest, int dest_len, const char *joiner, const char *src[], int return dest; } + +int +g_bitmask_to_str(int bitmask, const struct bitmask_string bitdefs[], + char delim, char *buff, int bufflen) +{ + int rlen = 0; /* Returned length */ + + if (bufflen <= 0) /* Caller error */ + { + rlen = -1; + } + else + { + char *p = buff; + /* Find the last writeable character in the buffer */ + const char *last = buff + (bufflen - 1); + + const struct bitmask_string *b; + + for (b = &bitdefs[0] ; b->mask != 0; ++b) + { + if ((bitmask & b->mask) != 0) + { + if (p > buff) + { + /* Not first item - append separator */ + if (p < last) + { + *p++ = delim; + } + ++rlen; + } + + int slen = g_strlen(b->str); + int copylen = MIN(slen, last - p); + g_memcpy(p, b->str, copylen); + p += copylen; + rlen += slen; + + /* Remove the bit so we can check for undefined bits later*/ + bitmask &= ~b->mask; + } + } + + if (bitmask != 0) + { + /* Bits left which aren't named by the user */ + if (p > buff) + { + if (p < last) + { + *p++ = delim; + } + ++rlen; + } + /* This call will terminate the return buffer */ + rlen += g_snprintf(p, last - p + 1, "0x%x", bitmask); + } + else + { + *p = '\0'; + } + } + + return rlen; +} diff --git a/common/string_calls.h b/common/string_calls.h index 2767a754..8682febe 100644 --- a/common/string_calls.h +++ b/common/string_calls.h @@ -37,6 +37,21 @@ struct info_string_tag #define INFO_STRING_END_OF_LIST { '\0', NULL } +/** + * Map a bitmask to a string value + * + * + * This structure is used by g_bitmask_to_str() to specify the + * string for each bit in the bitmask + */ +struct bitmask_string +{ + int mask; + const char *str; +}; + +#define BITMASK_STRING_END_OF_LIST { 0, NULL } + /** * Processes a format string for general info * @@ -140,6 +155,25 @@ g_bytes_to_hexdump(const char *src, int len); int g_get_display_num_from_display(const char *display_text); +/** + * Converts a bitmask into a string for output purposes + * + * @param bitmask Bitmask to convert + * @param bitdefs Definitions for strings for bits + * @param delim Delimiter to use between strings + * @param buff Output buff + * @param bufflen Length of buff, including terminator '`\0' + * + * @return Total length excluding terminator which would be written, as + * in snprintf(). Can be used to check for overflow + * + * @note Any undefined bits in the bitmask are appended to the output as + * a hexadecimal constant. + */ +int +g_bitmask_to_str(int bitmask, const struct bitmask_string[], + char delim, char *buff, int bufflen); + int g_strlen(const char *text); const char *g_strchr(const char *text, int c); const char *g_strnchr(const char *text, int c, int len); diff --git a/libxrdp/xrdp_iso.c b/libxrdp/xrdp_iso.c index c00f4568..f2593694 100644 --- a/libxrdp/xrdp_iso.c +++ b/libxrdp/xrdp_iso.c @@ -30,10 +30,51 @@ #include "libxrdp.h" #include "ms-rdpbcgr.h" +#include "string_calls.h" #include "log.h" +/*****************************************************************************/ +/** + * Converts a protocol mask ([MS-RDPBCGR] 2.2.1.1.1 to a string) + * + * @param protocol Protocol mask + * @param buff Output buffer + * @param bufflen total length of buff + * @return As for snprintf() + * + * The string "RDP" is always added to the output, even if other bits + * are set + */ +static int +protocol_mask_to_str(int protocol, char *buff, int bufflen) +{ + char delim = '|'; + static const struct bitmask_string bits[] = + { + { PROTOCOL_SSL, "SSL" }, + { PROTOCOL_HYBRID, "HYBRID" }, + { PROTOCOL_RDSTLS, "RDSTLS" }, + { PROTOCOL_HYBRID_EX, "HYBRID_EX"}, + BITMASK_STRING_END_OF_LIST + }; + + int rlen = g_bitmask_to_str(protocol, bits, delim, buff, bufflen); + + /* Append "RDP" */ + if (rlen == 0) + { + /* String is empty */ + rlen = g_snprintf(buff, bufflen, "RDP"); + } + else if (rlen < bufflen) + { + rlen += g_snprintf(buff + rlen, bufflen - rlen, "%cRDP", delim); + } + + return rlen; +} /*****************************************************************************/ struct xrdp_iso * @@ -64,61 +105,96 @@ xrdp_iso_delete(struct xrdp_iso *self) static int xrdp_iso_negotiate_security(struct xrdp_iso *self) { + char requested_str[64]; + const char *selected_str = ""; + const char *configured_str = ""; + int rv = 0; struct xrdp_client_info *client_info = &(self->mcs_layer->sec_layer->rdp_layer->client_info); - self->selectedProtocol = client_info->security_layer; + /* Can we do TLS/SSL? (basic check) */ + int ssl_capable = g_file_readable(client_info->certificate) && + g_file_readable(client_info->key_file); + /* Work out what's actually configured in xrdp.ini. The + * selection happens later, but we can do some error checking here */ switch (client_info->security_layer) { case PROTOCOL_RDP: + configured_str = "RDP"; break; + case PROTOCOL_SSL: - if (self->requestedProtocol & PROTOCOL_SSL) + /* We *must* use TLS. Check we can offer it, and it's requested */ + if (ssl_capable) { - if (!g_file_readable(client_info->certificate) || - !g_file_readable(client_info->key_file)) + configured_str = "SSL"; + if ((self->requestedProtocol & PROTOCOL_SSL) == 0) { - /* certificate or privkey is not readable */ - LOG(LOG_LEVEL_ERROR, "Cannot accept TLS connections because " - "certificate or private key file is not readable. " - "certificate file: [%s], private key file: [%s]", - client_info->certificate, client_info->key_file); - self->failureCode = SSL_CERT_NOT_ON_SERVER; + LOG(LOG_LEVEL_ERROR, "Server requires TLS for security, " + "but the client did not request TLS."); + self->failureCode = SSL_REQUIRED_BY_SERVER; rv = 1; /* error */ } - else - { - self->selectedProtocol = PROTOCOL_SSL; - } } else { - LOG(LOG_LEVEL_ERROR, "Server requires TLS for security, " - "but the client did not request TLS."); - self->failureCode = SSL_REQUIRED_BY_SERVER; + configured_str = ""; + LOG(LOG_LEVEL_ERROR, "Cannot accept TLS connections because " + "certificate or private key file is not readable. " + "certificate file: [%s], private key file: [%s]", + client_info->certificate, client_info->key_file); + self->failureCode = SSL_CERT_NOT_ON_SERVER; rv = 1; /* error */ } break; case PROTOCOL_HYBRID: case PROTOCOL_HYBRID_EX: default: - if ((self->requestedProtocol & PROTOCOL_SSL) && - g_file_readable(client_info->certificate) && - g_file_readable(client_info->key_file)) + /* We don't yet support CredSSP */ + if (ssl_capable) { - /* that's a patch since we don't support CredSSP for now */ - self->selectedProtocol = PROTOCOL_SSL; + configured_str = "SSL|RDP"; } else { - self->selectedProtocol = PROTOCOL_RDP; + /* + * Tell the user we can't offer TLS, but this isn't fatal */ + configured_str = "RDP"; + LOG(LOG_LEVEL_WARNING, "Cannot accept TLS connections because " + "certificate or private key file is not readable. " + "certificate file: [%s], private key file: [%s]", + client_info->certificate, client_info->key_file); } break; } - LOG(LOG_LEVEL_DEBUG, "Security layer: requested %d, selected %d", - self->requestedProtocol, self->selectedProtocol); + /* Currently the choice comes down to RDP or SSL */ + if (rv != 0) + { + self->selectedProtocol = PROTOCOL_RDP; + selected_str = ""; + } + else if (ssl_capable && (self->requestedProtocol & + client_info->security_layer & + PROTOCOL_SSL) != 0) + { + self->selectedProtocol = PROTOCOL_SSL; + selected_str = "SSL"; + } + else + { + self->selectedProtocol = PROTOCOL_RDP; + selected_str = "RDP"; + } + + protocol_mask_to_str(self->requestedProtocol, + requested_str, sizeof(requested_str)); + + LOG(LOG_LEVEL_INFO, "Security protocol: configured [%s], requested [%s]," + " selected [%s]", + configured_str, requested_str, selected_str); + return rv; } diff --git a/tests/common/test_string_calls.c b/tests/common/test_string_calls.c index 35c63755..54750c8e 100644 --- a/tests/common/test_string_calls.c +++ b/tests/common/test_string_calls.c @@ -213,12 +213,123 @@ START_TEST(test_strnjoin__when_always__then_doesnt_write_beyond_end_of_destinati } END_TEST +/******************************************************************************/ +START_TEST(test_bm2str__no_bits_defined) +{ + int rv; + char buff[64]; + + static const struct bitmask_string bits[] = + { + BITMASK_STRING_END_OF_LIST + }; + + rv = g_bitmask_to_str(0xffff, bits, ',', buff, sizeof(buff)); + + ck_assert_str_eq(buff, "0xffff"); + ck_assert_int_eq(rv, 6); +} +END_TEST + +START_TEST(test_bm2str__all_bits_defined) +{ + int rv; + char buff[64]; + + static const struct bitmask_string bits[] = + { + {1 << 0, "BIT_0"}, + {1 << 1, "BIT_1"}, + {1 << 6, "BIT_6"}, + {1 << 7, "BIT_7"}, + BITMASK_STRING_END_OF_LIST + }; + + int bitmask = 1 << 0 | 1 << 1 | 1 << 6 | 1 << 7; + + rv = g_bitmask_to_str(bitmask, bits, '|', buff, sizeof(buff)); + + ck_assert_str_eq(buff, "BIT_0|BIT_1|BIT_6|BIT_7"); + ck_assert_int_eq(rv, (6 * 4) - 1); +} +END_TEST + +START_TEST(test_bm2str__some_bits_undefined) +{ + int rv; + char buff[64]; + + static const struct bitmask_string bits[] = + { + {1 << 0, "BIT_0"}, + {1 << 1, "BIT_1"}, + {1 << 6, "BIT_6"}, + {1 << 7, "BIT_7"}, + BITMASK_STRING_END_OF_LIST + }; + + int bitmask = 1 << 0 | 1 << 1 | 1 << 16; + + rv = g_bitmask_to_str(bitmask, bits, ':', buff, sizeof(buff)); + + ck_assert_str_eq(buff, "BIT_0:BIT_1:0x10000"); + ck_assert_int_eq(rv, (6 * 2) + 7); +} +END_TEST + +START_TEST(test_bm2str__overflow_all_bits_defined) +{ + int rv; + char buff[10]; + + static const struct bitmask_string bits[] = + { + {1 << 0, "BIT_0"}, + {1 << 1, "BIT_1"}, + {1 << 6, "BIT_6"}, + {1 << 7, "BIT_7"}, + BITMASK_STRING_END_OF_LIST + }; + + int bitmask = 1 << 0 | 1 << 1 | 1 << 6 | 1 << 7; + + rv = g_bitmask_to_str(bitmask, bits, '+', buff, sizeof(buff)); + + ck_assert_str_eq(buff, "BIT_0+BIT"); + ck_assert_int_eq(rv, (4 * 6) - 1); +} +END_TEST + +START_TEST(test_bm2str__overflow_some_bits_undefined) +{ + int rv; + char buff[16]; + + static const struct bitmask_string bits[] = + { + {1 << 0, "BIT_0"}, + {1 << 1, "BIT_1"}, + {1 << 6, "BIT_6"}, + {1 << 7, "BIT_7"}, + BITMASK_STRING_END_OF_LIST + }; + + int bitmask = 1 << 0 | 1 << 1 | 1 << 16; + + rv = g_bitmask_to_str(bitmask, bits, '#', buff, sizeof(buff)); + + ck_assert_str_eq(buff, "BIT_0#BIT_1#0x1"); + ck_assert_int_eq(rv, (6 * 2) + 7); +} +END_TEST + /******************************************************************************/ Suite * make_suite_test_string(void) { Suite *s; TCase *tc_strnjoin; + TCase *tc_bm2str; s = suite_create("String"); @@ -236,5 +347,13 @@ make_suite_test_string(void) tcase_add_test(tc_strnjoin, test_strnjoin__when_destination_has_contents__returns_joined_string_with_content_overwritten); tcase_add_test(tc_strnjoin, test_strnjoin__when_always__then_doesnt_write_beyond_end_of_destination); + tc_bm2str = tcase_create("bm2str"); + suite_add_tcase(s, tc_bm2str); + tcase_add_test(tc_bm2str, test_bm2str__no_bits_defined); + tcase_add_test(tc_bm2str, test_bm2str__all_bits_defined); + tcase_add_test(tc_bm2str, test_bm2str__some_bits_undefined); + tcase_add_test(tc_bm2str, test_bm2str__overflow_all_bits_defined); + tcase_add_test(tc_bm2str, test_bm2str__overflow_some_bits_undefined); + return s; }