From a5fdf9e00637db0de9a970f10761595403186837 Mon Sep 17 00:00:00 2001 From: Armin Novak Date: Thu, 27 Sep 2018 15:04:41 +0200 Subject: [PATCH 01/13] Refactored gateway HTTP to be self contained. --- libfreerdp/core/gateway/http.c | 342 +++++++++++++++++++-------- libfreerdp/core/gateway/http.h | 84 ++----- libfreerdp/core/gateway/ncacn_http.c | 42 ++-- libfreerdp/core/gateway/rdg.c | 66 +++--- libfreerdp/core/gateway/rpc_client.c | 2 +- 5 files changed, 321 insertions(+), 215 deletions(-) diff --git a/libfreerdp/core/gateway/http.c b/libfreerdp/core/gateway/http.c index 06097347c..7dc35db39 100644 --- a/libfreerdp/core/gateway/http.c +++ b/libfreerdp/core/gateway/http.c @@ -40,6 +40,50 @@ #define RESPONSE_SIZE_LIMIT 64 * 1024 * 1024 +struct _http_context +{ + char* Method; + char* URI; + char* UserAgent; + char* Host; + char* Accept; + char* CacheControl; + char* Connection; + char* Pragma; + char* RdgConnectionId; + char* RdgAuthScheme; +}; + +struct _http_request +{ + char* Method; + char* URI; + char* AuthScheme; + char* AuthParam; + char* Authorization; + size_t ContentLength; + char* Content; + char* TransferEncoding; +}; + +struct _http_response +{ + size_t count; + char** lines; + + long StatusCode; + const char* ReasonPhrase; + + size_t ContentLength; + const char* ContentType; + + size_t BodyLength; + BYTE* BodyContent; + + wListDictionary* Authenticates; + wStream* data; +}; + static char* string_strnstr(const char* str1, const char* str2, size_t slen) { char c, sc; @@ -84,6 +128,9 @@ HttpContext* http_context_new(void) BOOL http_context_set_method(HttpContext* context, const char* Method) { + if (!context || !Method) + return FALSE; + free(context->Method); context->Method = _strdup(Method); @@ -93,8 +140,19 @@ BOOL http_context_set_method(HttpContext* context, const char* Method) return TRUE; } +const char* http_context_get_uri(HttpContext* context) +{ + if (!context) + return NULL; + + return context->URI; +} + BOOL http_context_set_uri(HttpContext* context, const char* URI) { + if (!context || !URI) + return FALSE; + free(context->URI); context->URI = _strdup(URI); @@ -106,6 +164,9 @@ BOOL http_context_set_uri(HttpContext* context, const char* URI) BOOL http_context_set_user_agent(HttpContext* context, const char* UserAgent) { + if (!context || !UserAgent) + return FALSE; + free(context->UserAgent); context->UserAgent = _strdup(UserAgent); @@ -117,6 +178,9 @@ BOOL http_context_set_user_agent(HttpContext* context, const char* UserAgent) BOOL http_context_set_host(HttpContext* context, const char* Host) { + if (!context || !Host) + return FALSE; + free(context->Host); context->Host = _strdup(Host); @@ -128,6 +192,9 @@ BOOL http_context_set_host(HttpContext* context, const char* Host) BOOL http_context_set_accept(HttpContext* context, const char* Accept) { + if (!context || !Accept) + return FALSE; + free(context->Accept); context->Accept = _strdup(Accept); @@ -139,6 +206,9 @@ BOOL http_context_set_accept(HttpContext* context, const char* Accept) BOOL http_context_set_cache_control(HttpContext* context, const char* CacheControl) { + if (!context || !CacheControl) + return FALSE; + free(context->CacheControl); context->CacheControl = _strdup(CacheControl); @@ -150,6 +220,9 @@ BOOL http_context_set_cache_control(HttpContext* context, const char* CacheContr BOOL http_context_set_connection(HttpContext* context, const char* Connection) { + if (!context || !Connection) + return FALSE; + free(context->Connection); context->Connection = _strdup(Connection); @@ -161,6 +234,9 @@ BOOL http_context_set_connection(HttpContext* context, const char* Connection) BOOL http_context_set_pragma(HttpContext* context, const char* Pragma) { + if (!context || !Pragma) + return FALSE; + free(context->Pragma); context->Pragma = _strdup(Pragma); @@ -172,6 +248,9 @@ BOOL http_context_set_pragma(HttpContext* context, const char* Pragma) BOOL http_context_set_rdg_connection_id(HttpContext* context, const char* RdgConnectionId) { + if (!context || !RdgConnectionId) + return FALSE; + free(context->RdgConnectionId); context->RdgConnectionId = _strdup(RdgConnectionId); @@ -183,6 +262,9 @@ BOOL http_context_set_rdg_connection_id(HttpContext* context, const char* RdgCon BOOL http_context_set_rdg_auth_scheme(HttpContext* context, const char* RdgAuthScheme) { + if (!context || !RdgAuthScheme) + return FALSE; + free(context->RdgAuthScheme); context->RdgAuthScheme = _strdup(RdgAuthScheme); return context->RdgAuthScheme != NULL; @@ -208,6 +290,9 @@ void http_context_free(HttpContext* context) BOOL http_request_set_method(HttpRequest* request, const char* Method) { + if (!request || !Method) + return FALSE; + free(request->Method); request->Method = _strdup(Method); @@ -219,6 +304,9 @@ BOOL http_request_set_method(HttpRequest* request, const char* Method) BOOL http_request_set_uri(HttpRequest* request, const char* URI) { + if (!request || !URI) + return FALSE; + free(request->URI); request->URI = _strdup(URI); @@ -230,6 +318,9 @@ BOOL http_request_set_uri(HttpRequest* request, const char* URI) BOOL http_request_set_auth_scheme(HttpRequest* request, const char* AuthScheme) { + if (!request || !AuthScheme) + return FALSE; + free(request->AuthScheme); request->AuthScheme = _strdup(AuthScheme); @@ -241,6 +332,9 @@ BOOL http_request_set_auth_scheme(HttpRequest* request, const char* AuthScheme) BOOL http_request_set_auth_param(HttpRequest* request, const char* AuthParam) { + if (!request || !AuthParam) + return FALSE; + free(request->AuthParam); request->AuthParam = _strdup(AuthParam); @@ -252,6 +346,9 @@ BOOL http_request_set_auth_param(HttpRequest* request, const char* AuthParam) BOOL http_request_set_transfer_encoding(HttpRequest* request, const char* TransferEncoding) { + if (!request || !TransferEncoding) + return FALSE; + free(request->TransferEncoding); request->TransferEncoding = _strdup(TransferEncoding); @@ -261,149 +358,125 @@ BOOL http_request_set_transfer_encoding(HttpRequest* request, const char* Transf return TRUE; } -static char* http_encode_body_line(const char* param, const char* value) +static BOOL http_encode_print(wStream* s, const char* fmt, ...) { - char* line; - int length; - length = strlen(param) + strlen(value) + 2; - line = (char*) malloc(length + 1); + char* str; + va_list ap; + size_t length, used; - if (!line) - return NULL; + if (!s || !fmt) + return FALSE; - sprintf_s(line, length + 1, "%s: %s", param, value); - return line; + va_start(ap, fmt); + length = vsnprintf(NULL, 0, fmt, ap) + 1; + va_end(ap); + + if (!Stream_EnsureRemainingCapacity(s, length)) + return FALSE; + + str = Stream_Pointer(s); + va_start(ap, fmt); + used = vsnprintf(str, length, fmt, ap); + va_end(ap); + + /* Strip the trailing '\0' from the string. */ + if ((used + 1) != length) + return FALSE; + + Stream_Seek(s, used); + return TRUE; } -static char* http_encode_content_length_line(int ContentLength) +static BOOL http_encode_body_line(wStream* s, const char* param, const char* value) { - const char* key = "Content-Length:"; - char* line; - int length; - char str[32]; - _itoa_s(ContentLength, str, sizeof(str), 10); - length = strlen(key) + strlen(str) + 2; - line = (char*) malloc(length + 1); + if (!s || !param || !value) + return FALSE; - if (!line) - return NULL; - - sprintf_s(line, length + 1, "%s %s", key, str); - return line; + return http_encode_print(s, "%s: %s\r\n", param, value); } -static char* http_encode_header_line(const char* Method, const char* URI) +static BOOL http_encode_content_length_line(wStream* s, size_t ContentLength) { - const char* key = "HTTP/1.1"; - char* line; - int length; - length = strlen(key) + strlen(Method) + strlen(URI) + 2; - line = (char*)malloc(length + 1); - - if (!line) - return NULL; - - sprintf_s(line, length + 1, "%s %s %s", Method, URI, key); - return line; + return http_encode_print(s, "Content-Length: %"PRIdz"\r\n", ContentLength); } -static char* http_encode_authorization_line(const char* AuthScheme, const char* AuthParam) +static BOOL http_encode_header_line(wStream* s, const char* Method, const char* URI) { - const char* key = "Authorization:"; - char* line; - int length; - length = strlen(key) + strlen(AuthScheme) + strlen(AuthParam) + 3; - line = (char*) malloc(length + 1); + if (!s || !Method || !URI) + return FALSE; - if (!line) - return NULL; + return http_encode_print(s, "%s %s HTTP/1.1\r\n", Method, URI); +} - sprintf_s(line, length + 1, "%s %s %s", key, AuthScheme, AuthParam); - return line; +static BOOL http_encode_authorization_line(wStream* s, const char* AuthScheme, + const char* AuthParam) +{ + if (!s || !AuthScheme || !AuthParam) + return FALSE; + + return http_encode_print(s, "Authorization: %s %s\r\n", AuthScheme, AuthParam); } wStream* http_request_write(HttpContext* context, HttpRequest* request) { wStream* s; - int i, count; - char** lines; - int length = 0; - count = 0; - lines = (char**) calloc(32, sizeof(char*)); - if (!lines) + if (!context || !request) return NULL; - lines[count++] = http_encode_header_line(request->Method, request->URI); - lines[count++] = http_encode_body_line("Cache-Control", context->CacheControl); - lines[count++] = http_encode_body_line("Connection", context->Connection); - lines[count++] = http_encode_body_line("Pragma", context->Pragma); - lines[count++] = http_encode_body_line("Accept", context->Accept); - lines[count++] = http_encode_body_line("User-Agent", context->UserAgent); - lines[count++] = http_encode_body_line("Host", context->Host); + s = Stream_New(NULL, 1024); + + if (!s) + return NULL; + + if (!http_encode_header_line(s, request->Method, request->URI) || + !http_encode_body_line(s, "Cache-Control", context->CacheControl) || + !http_encode_body_line(s, "Connection", context->Connection) || + !http_encode_body_line(s, "Pragma", context->Pragma) || + !http_encode_body_line(s, "Accept", context->Accept) || + !http_encode_body_line(s, "User-Agent", context->UserAgent) || + !http_encode_body_line(s, "Host", context->Host)) + goto fail; if (context->RdgConnectionId) - lines[count++] = http_encode_body_line("RDG-Connection-Id", context->RdgConnectionId); + { + if (!http_encode_body_line(s, "RDG-Connection-Id", context->RdgConnectionId)) + goto fail; + } if (context->RdgAuthScheme) - lines[count++] = http_encode_body_line("RDG-Auth-Scheme", context->RdgAuthScheme); + { + if (!http_encode_body_line(s, "RDG-Auth-Scheme", context->RdgAuthScheme)) + goto fail; + } if (request->TransferEncoding) { - lines[count++] = http_encode_body_line("Transfer-Encoding", request->TransferEncoding); + if (!http_encode_body_line(s, "Transfer-Encoding", request->TransferEncoding)) + goto fail; } else { - lines[count++] = http_encode_content_length_line(request->ContentLength); + if (!http_encode_content_length_line(s, request->ContentLength)) + goto fail; } if (request->Authorization) { - lines[count++] = http_encode_body_line("Authorization", request->Authorization); + if (!http_encode_body_line(s, "Authorization", request->Authorization)) + goto fail; } else if (request->AuthScheme && request->AuthParam) { - lines[count++] = http_encode_authorization_line(request->AuthScheme, request->AuthParam); - } - - /* check that everything went well */ - for (i = 0; i < count; i++) - { - if (!lines[i]) - goto out_free; - } - - for (i = 0; i < count; i++) - { - length += (strlen(lines[i]) + 2); /* add +2 for each '\r\n' character */ - } - - length += 2; /* empty line "\r\n" at end of header */ - length += 1; /* null terminator */ - s = Stream_New(NULL, length); - - if (!s) - goto out_free; - - for (i = 0; i < count; i++) - { - Stream_Write(s, lines[i], strlen(lines[i])); - Stream_Write(s, "\r\n", 2); - free(lines[i]); + if (!http_encode_authorization_line(s, request->AuthScheme, request->AuthParam)) + goto fail; } Stream_Write(s, "\r\n", 2); - free(lines); - Stream_Write(s, "\0", 1); /* append null terminator */ - Stream_Rewind(s, 1); /* don't include null terminator in length */ - Stream_SetLength(s, Stream_GetPosition(s)); + Stream_SealLength(s); return s; -out_free: - - for (i = 0; i < count; i++) - free(lines[i]); - - free(lines); +fail: + Stream_Free(s, TRUE); return NULL; } @@ -433,6 +506,9 @@ static BOOL http_response_parse_header_status_line(HttpResponse* response, char* char* status_code; char* reason_phrase; + if (!response) + return FALSE; + if (status_line) separator = strchr(status_line, ' '); @@ -470,6 +546,9 @@ static BOOL http_response_parse_header_field(HttpResponse* response, const char* { BOOL status = TRUE; + if (!response || !name) + return FALSE; + if (_stricmp(name, "Content-Length") == 0) { unsigned long long val; @@ -520,7 +599,7 @@ static BOOL http_response_parse_header_field(HttpResponse* response, const char* authValue = NULL; } - status = ListDictionary_Add(response->Authenticates, (void*)authScheme, (void*)authValue); + status = ListDictionary_Add(response->Authenticates, authScheme, authValue); } return status; @@ -599,14 +678,17 @@ static BOOL http_response_parse_header(HttpResponse* response) return TRUE; } -void http_response_print(HttpResponse* response) +BOOL http_response_print(HttpResponse* response) { - int i; + size_t i; + + if (!response) + return FALSE; for (i = 0; i < response->count; i++) - { WLog_ERR(TAG, "%s", response->lines[i]); - } + + return TRUE; } HttpResponse* http_response_recv(rdpTls* tls) @@ -813,3 +895,55 @@ void http_response_free(HttpResponse* response) Stream_Free(response->data, TRUE); free(response); } + +const char* http_request_get_uri(HttpRequest* request) +{ + if (!request) + return NULL; + + return request->URI; +} + +SSIZE_T http_request_get_content_length(HttpRequest* request) +{ + if (!request) + return -1; + + return request->ContentLength; +} + +BOOL http_request_set_content_length(HttpRequest* request, size_t length) +{ + if (!request) + return FALSE; + + request->ContentLength = length; + return TRUE; +} + +long http_response_get_status_code(HttpResponse* response) +{ + if (!response) + return -1; + + return response->StatusCode; +} + +SSIZE_T http_response_get_body_length(HttpResponse* response) +{ + if (!response) + return -1; + + return response->BodyLength; +} + +const char* http_response_get_auth_token(HttpResponse* respone, const char* method) +{ + if (!respone || !method) + return NULL; + + if (!ListDictionary_Contains(respone->Authenticates, method)) + return NULL; + + return ListDictionary_GetItemValue(respone->Authenticates, method); +} diff --git a/libfreerdp/core/gateway/http.h b/libfreerdp/core/gateway/http.h index 3e2581840..0427b86c7 100644 --- a/libfreerdp/core/gateway/http.h +++ b/libfreerdp/core/gateway/http.h @@ -20,33 +20,21 @@ #ifndef FREERDP_LIB_CORE_GATEWAY_HTTP_H #define FREERDP_LIB_CORE_GATEWAY_HTTP_H -typedef struct _http_context HttpContext; -typedef struct _http_request HttpRequest; -typedef struct _http_response HttpResponse; - -#include -#include -#include - #include #include -struct _http_context -{ - char* Method; - char* URI; - char* UserAgent; - char* Host; - char* Accept; - char* CacheControl; - char* Connection; - char* Pragma; - char* RdgConnectionId; - char* RdgAuthScheme; -}; +#include +#include + +/* HTTP context */ +typedef struct _http_context HttpContext; + +FREERDP_LOCAL HttpContext* http_context_new(void); +FREERDP_LOCAL void http_context_free(HttpContext* context); FREERDP_LOCAL BOOL http_context_set_method(HttpContext* context, const char* Method); +FREERDP_LOCAL const char* http_context_get_uri(HttpContext* context); FREERDP_LOCAL BOOL http_context_set_uri(HttpContext* context, const char* URI); FREERDP_LOCAL BOOL http_context_set_user_agent(HttpContext* context, const char* UserAgent); @@ -65,23 +53,18 @@ FREERDP_LOCAL BOOL http_context_set_rdg_connection_id(HttpContext* context, FREERDP_LOCAL BOOL http_context_set_rdg_auth_scheme(HttpContext* context, const char* RdgAuthScheme); -HttpContext* http_context_new(void); -void http_context_free(HttpContext* context); +/* HTTP request */ +typedef struct _http_request HttpRequest; -struct _http_request -{ - char* Method; - char* URI; - char* AuthScheme; - char* AuthParam; - char* Authorization; - int ContentLength; - char* Content; - char* TransferEncoding; -}; +FREERDP_LOCAL HttpRequest* http_request_new(void); +FREERDP_LOCAL void http_request_free(HttpRequest* request); FREERDP_LOCAL BOOL http_request_set_method(HttpRequest* request, const char* Method); +FREERDP_LOCAL SSIZE_T http_request_get_content_length(HttpRequest* request); +FREERDP_LOCAL BOOL http_request_set_content_length(HttpRequest* request, size_t length); + +FREERDP_LOCAL const char* http_request_get_uri(HttpRequest* request); FREERDP_LOCAL BOOL http_request_set_uri(HttpRequest* request, const char* URI); FREERDP_LOCAL BOOL http_request_set_auth_scheme(HttpRequest* request, const char* AuthScheme); @@ -93,32 +76,17 @@ FREERDP_LOCAL BOOL http_request_set_transfer_encoding(HttpRequest* request, FREERDP_LOCAL wStream* http_request_write(HttpContext* context, HttpRequest* request); -FREERDP_LOCAL HttpRequest* http_request_new(void); -FREERDP_LOCAL void http_request_free(HttpRequest* request); - -struct _http_response -{ - size_t count; - char** lines; - - long StatusCode; - const char* ReasonPhrase; - - size_t ContentLength; - const char* ContentType; - - size_t BodyLength; - BYTE* BodyContent; - - wListDictionary* Authenticates; - wStream* data; -}; - -FREERDP_LOCAL void http_response_print(HttpResponse* response); - -FREERDP_LOCAL HttpResponse* http_response_recv(rdpTls* tls); +/* HTTP response */ +typedef struct _http_response HttpResponse; FREERDP_LOCAL HttpResponse* http_response_new(void); FREERDP_LOCAL void http_response_free(HttpResponse* response); +FREERDP_LOCAL BOOL http_response_print(HttpResponse* response); +FREERDP_LOCAL HttpResponse* http_response_recv(rdpTls* tls); + +FREERDP_LOCAL long http_response_get_status_code(HttpResponse* response); +FREERDP_LOCAL SSIZE_T http_response_get_body_length(HttpResponse* response); +FREERDP_LOCAL const char* http_response_get_auth_token(HttpResponse* respone, const char* method); + #endif /* FREERDP_LIB_CORE_GATEWAY_HTTP_H */ diff --git a/libfreerdp/core/gateway/ncacn_http.c b/libfreerdp/core/gateway/ncacn_http.c index d8fb8df68..9ced0da99 100644 --- a/libfreerdp/core/gateway/ncacn_http.c +++ b/libfreerdp/core/gateway/ncacn_http.c @@ -34,25 +34,35 @@ static wStream* rpc_ntlm_http_request(rdpRpc* rpc, HttpContext* http, const char* method, int contentLength, SecBuffer* ntlmToken) { - wStream* s; + wStream* s = NULL; HttpRequest* request; char* base64NtlmToken = NULL; + const char* uri; + + if (!rpc || !http || !method || !ntlmToken) + goto fail; + request = http_request_new(); if (ntlmToken) base64NtlmToken = crypto_base64_encode(ntlmToken->pvBuffer, ntlmToken->cbBuffer); - http_request_set_method(request, method); - request->ContentLength = contentLength; - http_request_set_uri(request, http->URI); + uri = http_context_get_uri(http); + + if (!http_request_set_method(request, method) || + !http_request_set_content_length(request, contentLength) || + !http_request_set_uri(request, uri)) + return NULL; if (base64NtlmToken) { - http_request_set_auth_scheme(request, "NTLM"); - http_request_set_auth_param(request, base64NtlmToken); + if (!http_request_set_auth_scheme(request, "NTLM") || + !http_request_set_auth_param(request, base64NtlmToken)) + goto fail; } s = http_request_write(http, request); +fail: http_request_free(request); free(base64NtlmToken); return s; @@ -85,16 +95,10 @@ int rpc_ncacn_http_recv_in_channel_response(rdpRpc* rpc, RpcInChannel* inChannel int ntlmTokenLength = 0; BYTE* ntlmTokenData = NULL; rdpNtlm* ntlm = inChannel->ntlm; + token64 = http_response_get_auth_token(response, "NTLM"); - if (ListDictionary_Contains(response->Authenticates, "NTLM")) - { - token64 = ListDictionary_GetItemValue(response->Authenticates, "NTLM"); - - if (!token64) - return -1; - + if (token64) crypto_base64_decode(token64, strlen(token64), &ntlmTokenData, &ntlmTokenLength); - } if (ntlmTokenData && ntlmTokenLength) { @@ -210,16 +214,10 @@ int rpc_ncacn_http_recv_out_channel_response(rdpRpc* rpc, RpcOutChannel* outChan int ntlmTokenLength = 0; BYTE* ntlmTokenData = NULL; rdpNtlm* ntlm = outChannel->ntlm; + token64 = http_response_get_auth_token(response, "NTLM"); - if (ListDictionary_Contains(response->Authenticates, "NTLM")) - { - token64 = ListDictionary_GetItemValue(response->Authenticates, "NTLM"); - - if (!token64) - return -1; - + if (token64) crypto_base64_decode(token64, strlen(token64), &ntlmTokenData, &ntlmTokenLength); - } if (ntlmTokenData && ntlmTokenLength) { diff --git a/libfreerdp/core/gateway/rdg.c b/libfreerdp/core/gateway/rdg.c index 77599cf59..11b24286c 100644 --- a/libfreerdp/core/gateway/rdg.c +++ b/libfreerdp/core/gateway/rdg.c @@ -308,11 +308,11 @@ static BOOL rdg_set_ntlm_auth_header(rdpNtlm* ntlm, HttpRequest* request) if (base64NtlmToken) { - http_request_set_auth_scheme(request, "NTLM"); - http_request_set_auth_param(request, base64NtlmToken); + BOOL rc = http_request_set_auth_scheme(request, "NTLM") && + http_request_set_auth_param(request, base64NtlmToken); free(base64NtlmToken); - if (!request->AuthScheme || !request->AuthParam) + if (!rc) return FALSE; } @@ -324,16 +324,19 @@ static wStream* rdg_build_http_request(rdpRdg* rdg, const char* method, { wStream* s = NULL; HttpRequest* request = NULL; - assert(method != NULL); + const char* uri; + + if (!rdg || !method || !transferEncoding) + return NULL; + + uri = http_context_get_uri(rdg->http); request = http_request_new(); if (!request) return NULL; - http_request_set_method(request, method); - http_request_set_uri(request, rdg->http->URI); - - if (!request->Method || !request->URI) + if (!http_request_set_method(request, method) || + !http_request_set_uri(request, uri)) goto out; if (rdg->ntlm) @@ -362,15 +365,21 @@ static BOOL rdg_handle_ntlm_challenge(rdpNtlm* ntlm, HttpResponse* response) char* token64 = NULL; int ntlmTokenLength = 0; BYTE* ntlmTokenData = NULL; + long StatusCode; - if (response->StatusCode != HTTP_STATUS_DENIED) + if (!ntlm || !response) + return FALSE; + + StatusCode = http_response_get_status_code(response); + + if (StatusCode != HTTP_STATUS_DENIED) { - WLog_DBG(TAG, "Unexpected NTLM challenge HTTP status: %d", - response->StatusCode); + WLog_DBG(TAG, "Unexpected NTLM challenge HTTP status: %ld", + StatusCode); return FALSE; } - token64 = ListDictionary_GetItemValue(response->Authenticates, "NTLM"); + token64 = http_response_get_auth_token(response, "NTLM"); if (!token64) return FALSE; @@ -733,6 +742,7 @@ static BOOL rdg_establish_data_connection(rdpRdg* rdg, rdpTls* tls, HttpResponse* response = NULL; int statusCode; int bodyLength; + long StatusCode; if (!rdg_tls_connect(rdg, tls, peerAddress, timeout)) return FALSE; @@ -750,7 +760,9 @@ static BOOL rdg_establish_data_connection(rdpRdg* rdg, rdpTls* tls, if (!response) return FALSE; - if (response->StatusCode == HTTP_STATUS_NOT_FOUND) + StatusCode = http_response_get_status_code(response); + + if (StatusCode == HTTP_STATUS_NOT_FOUND) { WLog_INFO(TAG, "RD Gateway does not support HTTP transport."); @@ -779,8 +791,8 @@ static BOOL rdg_establish_data_connection(rdpRdg* rdg, rdpTls* tls, if (!response) return FALSE; - statusCode = response->StatusCode; - bodyLength = response->BodyLength; + statusCode = http_response_get_status_code(response); + bodyLength = http_response_get_body_length(response); http_response_free(response); WLog_DBG(TAG, "%s authorization result: %d", method, statusCode); @@ -1285,18 +1297,14 @@ rdpRdg* rdg_new(rdpTransport* transport) if (!rdg->http) goto rdg_alloc_error; - http_context_set_uri(rdg->http, "/remoteDesktopGateway/"); - http_context_set_accept(rdg->http, "*/*"); - http_context_set_cache_control(rdg->http, "no-cache"); - http_context_set_pragma(rdg->http, "no-cache"); - http_context_set_connection(rdg->http, "Keep-Alive"); - http_context_set_user_agent(rdg->http, "MS-RDGateway/1.0"); - http_context_set_host(rdg->http, rdg->settings->GatewayHostname); - http_context_set_rdg_connection_id(rdg->http, bracedUuid); - - if (!rdg->http->URI || !rdg->http->Accept || !rdg->http->CacheControl || - !rdg->http->Pragma || !rdg->http->Connection || !rdg->http->UserAgent - || !rdg->http->Host || !rdg->http->RdgConnectionId) + if (!http_context_set_uri(rdg->http, "/remoteDesktopGateway/") || + !http_context_set_accept(rdg->http, "*/*") || + !http_context_set_cache_control(rdg->http, "no-cache") || + !http_context_set_pragma(rdg->http, "no-cache") || + !http_context_set_connection(rdg->http, "Keep-Alive") || + !http_context_set_user_agent(rdg->http, "MS-RDGateway/1.0") || + !http_context_set_host(rdg->http, rdg->settings->GatewayHostname) || + !http_context_set_rdg_connection_id(rdg->http, bracedUuid)) { goto rdg_alloc_error; } @@ -1306,9 +1314,7 @@ rdpRdg* rdg_new(rdpTransport* transport) switch (rdg->extAuth) { case HTTP_EXTENDED_AUTH_PAA: - http_context_set_rdg_auth_scheme(rdg->http, "PAA"); - - if (!rdg->http->RdgAuthScheme) + if (!http_context_set_rdg_auth_scheme(rdg->http, "PAA")) goto rdg_alloc_error; break; diff --git a/libfreerdp/core/gateway/rpc_client.c b/libfreerdp/core/gateway/rpc_client.c index 0bee974d5..b6614b83e 100644 --- a/libfreerdp/core/gateway/rpc_client.c +++ b/libfreerdp/core/gateway/rpc_client.c @@ -528,7 +528,7 @@ static int rpc_client_default_out_channel_recv(rdpRpc* rpc) if (!response) return -1; - statusCode = response->StatusCode; + statusCode = http_response_get_status_code(response); if (statusCode != HTTP_STATUS_OK) { From d748adbf1441b44f4247a52c2a9da6a12701d0e9 Mon Sep 17 00:00:00 2001 From: Armin Novak Date: Thu, 27 Sep 2018 15:19:41 +0200 Subject: [PATCH 02/13] Refactored gateway ncacn HTTP to be self contained. --- libfreerdp/core/gateway/ncacn_http.c | 104 ++++++++++++++++++--------- libfreerdp/core/gateway/ncacn_http.h | 26 +++---- libfreerdp/core/gateway/rpc.c | 12 ++-- libfreerdp/core/gateway/rpc_client.c | 22 +++--- 4 files changed, 95 insertions(+), 69 deletions(-) diff --git a/libfreerdp/core/gateway/ncacn_http.c b/libfreerdp/core/gateway/ncacn_http.c index 9ced0da99..a07ed972e 100644 --- a/libfreerdp/core/gateway/ncacn_http.c +++ b/libfreerdp/core/gateway/ncacn_http.c @@ -31,15 +31,15 @@ #define TAG FREERDP_TAG("core.gateway.ntlm") -static wStream* rpc_ntlm_http_request(rdpRpc* rpc, HttpContext* http, const char* method, +static wStream* rpc_ntlm_http_request(HttpContext* http, const char* method, int contentLength, SecBuffer* ntlmToken) { wStream* s = NULL; - HttpRequest* request; + HttpRequest* request = NULL; char* base64NtlmToken = NULL; const char* uri; - if (!rpc || !http || !method || !ntlmToken) + if (!http || !method || !ntlmToken) goto fail; request = http_request_new(); @@ -68,17 +68,23 @@ fail: return s; } -int rpc_ncacn_http_send_in_channel_request(rdpRpc* rpc, RpcInChannel* inChannel) +BOOL rpc_ncacn_http_send_in_channel_request(RpcInChannel* inChannel) { wStream* s; int status; int contentLength; BOOL continueNeeded; - rdpNtlm* ntlm = inChannel->ntlm; - HttpContext* http = inChannel->http; + rdpNtlm* ntlm; + HttpContext* http; + + if (!inChannel || !inChannel->ntlm || !inChannel->http) + return FALSE; + + ntlm = inChannel->ntlm; + http = inChannel->http; continueNeeded = ntlm_authenticate(ntlm); contentLength = (continueNeeded) ? 0 : 0x40000000; - s = rpc_ntlm_http_request(rpc, http, "RPC_IN_DATA", contentLength, &ntlm->outputBuffer[0]); + s = rpc_ntlm_http_request(http, "RPC_IN_DATA", contentLength, &ntlm->outputBuffer[0]); if (!s) return -1; @@ -88,13 +94,18 @@ int rpc_ncacn_http_send_in_channel_request(rdpRpc* rpc, RpcInChannel* inChannel) return (status > 0) ? 1 : -1; } -int rpc_ncacn_http_recv_in_channel_response(rdpRpc* rpc, RpcInChannel* inChannel, +BOOL rpc_ncacn_http_recv_in_channel_response(RpcInChannel* inChannel, HttpResponse* response) { char* token64 = NULL; int ntlmTokenLength = 0; BYTE* ntlmTokenData = NULL; - rdpNtlm* ntlm = inChannel->ntlm; + rdpNtlm* ntlm; + + if (!inChannel || !response || !inChannel->ntlm) + return FALSE; + + ntlm = inChannel->ntlm; token64 = http_response_get_auth_token(response, "NTLM"); if (token64) @@ -106,16 +117,26 @@ int rpc_ncacn_http_recv_in_channel_response(rdpRpc* rpc, RpcInChannel* inChannel ntlm->inputBuffer[0].cbBuffer = ntlmTokenLength; } - return 1; + return TRUE; } -int rpc_ncacn_http_ntlm_init(rdpRpc* rpc, RpcChannel* channel) +BOOL rpc_ncacn_http_ntlm_init(rdpContext* context, RpcChannel* channel) { - rdpTls* tls = channel->tls; - rdpNtlm* ntlm = channel->ntlm; - rdpContext* context = rpc->context; - rdpSettings* settings = rpc->settings; - freerdp* instance = context->instance; + rdpTls* tls; + rdpNtlm* ntlm; + rdpSettings* settings; + freerdp* instance; + + if (!context || !channel) + return FALSE; + + tls = channel->tls; + ntlm = channel->ntlm; + settings = context->settings; + instance = context->instance; + + if (!tls || !ntlm || !instance || !settings) + return FALSE; if (!settings->GatewayPassword || !settings->GatewayUsername || !strlen(settings->GatewayPassword) || !strlen(settings->GatewayUsername)) @@ -128,7 +149,7 @@ int rpc_ncacn_http_ntlm_init(rdpRpc* rpc, RpcChannel* channel) if (!proceed) { freerdp_set_last_error(context, FREERDP_ERROR_CONNECT_CANCELLED); - return 0; + return TRUE; } if (settings->GatewayUseSameCredentials) @@ -138,7 +159,7 @@ int rpc_ncacn_http_ntlm_init(rdpRpc* rpc, RpcChannel* channel) free(settings->Username); if (!(settings->Username = _strdup(settings->GatewayUsername))) - return -1; + return FALSE; } if (settings->GatewayDomain) @@ -146,7 +167,7 @@ int rpc_ncacn_http_ntlm_init(rdpRpc* rpc, RpcChannel* channel) free(settings->Domain); if (!(settings->Domain = _strdup(settings->GatewayDomain))) - return -1; + return FALSE; } if (settings->GatewayPassword) @@ -154,7 +175,7 @@ int rpc_ncacn_http_ntlm_init(rdpRpc* rpc, RpcChannel* channel) free(settings->Password); if (!(settings->Password = _strdup(settings->GatewayPassword))) - return -1; + return FALSE; } } } @@ -163,33 +184,43 @@ int rpc_ncacn_http_ntlm_init(rdpRpc* rpc, RpcChannel* channel) if (!ntlm_client_init(ntlm, TRUE, settings->GatewayUsername, settings->GatewayDomain, settings->GatewayPassword, tls->Bindings)) { - return 0; + return TRUE; } if (!ntlm_client_make_spn(ntlm, _T("HTTP"), settings->GatewayHostname)) { - return 0; + return TRUE; } - return 1; + return TRUE; } -void rpc_ncacn_http_ntlm_uninit(rdpRpc* rpc, RpcChannel* channel) +void rpc_ncacn_http_ntlm_uninit(RpcChannel* channel) { + if (!channel) + return; + ntlm_client_uninit(channel->ntlm); ntlm_free(channel->ntlm); channel->ntlm = NULL; } -int rpc_ncacn_http_send_out_channel_request(rdpRpc* rpc, RpcOutChannel* outChannel, +BOOL rpc_ncacn_http_send_out_channel_request(RpcOutChannel* outChannel, BOOL replacement) { + BOOL rc = TRUE; wStream* s; int status; int contentLength; BOOL continueNeeded; - rdpNtlm* ntlm = outChannel->ntlm; - HttpContext* http = outChannel->http; + rdpNtlm* ntlm; + HttpContext* http; + + if (!outChannel || !outChannel->ntlm || !outChannel->http) + return FALSE; + + ntlm = outChannel->ntlm; + http = outChannel->http; continueNeeded = ntlm_authenticate(ntlm); if (!replacement) @@ -197,23 +228,30 @@ int rpc_ncacn_http_send_out_channel_request(rdpRpc* rpc, RpcOutChannel* outChann else contentLength = (continueNeeded) ? 0 : 120; - s = rpc_ntlm_http_request(rpc, http, "RPC_OUT_DATA", contentLength, &ntlm->outputBuffer[0]); + s = rpc_ntlm_http_request(http, "RPC_OUT_DATA", contentLength, &ntlm->outputBuffer[0]); if (!s) return -1; - status = rpc_out_channel_write(outChannel, Stream_Buffer(s), Stream_Length(s)); + if (rpc_out_channel_write(outChannel, Stream_Buffer(s), Stream_Length(s)) < 0) + rc = FALSE; + Stream_Free(s, TRUE); - return (status > 0) ? 1 : -1; + return rc; } -int rpc_ncacn_http_recv_out_channel_response(rdpRpc* rpc, RpcOutChannel* outChannel, +BOOL rpc_ncacn_http_recv_out_channel_response(RpcOutChannel* outChannel, HttpResponse* response) { char* token64 = NULL; int ntlmTokenLength = 0; BYTE* ntlmTokenData = NULL; - rdpNtlm* ntlm = outChannel->ntlm; + rdpNtlm* ntlm; + + if (!outChannel || !response || !outChannel->ntlm) + return FALSE; + + ntlm = outChannel->ntlm; token64 = http_response_get_auth_token(response, "NTLM"); if (token64) @@ -225,5 +263,5 @@ int rpc_ncacn_http_recv_out_channel_response(rdpRpc* rpc, RpcOutChannel* outChan ntlm->inputBuffer[0].cbBuffer = ntlmTokenLength; } - return 1; + return TRUE; } diff --git a/libfreerdp/core/gateway/ncacn_http.h b/libfreerdp/core/gateway/ncacn_http.h index 228047646..59cc956a9 100644 --- a/libfreerdp/core/gateway/ncacn_http.h +++ b/libfreerdp/core/gateway/ncacn_http.h @@ -20,29 +20,21 @@ #ifndef FREERDP_LIB_CORE_GATEWAY_NCACN_HTTP_H #define FREERDP_LIB_CORE_GATEWAY_NCACN_HTTP_H -#include -#include - -#include -#include #include -#include - #include "rpc.h" #include "http.h" -FREERDP_LOCAL int rpc_ncacn_http_ntlm_init(rdpRpc* rpc, RpcChannel* channel); -FREERDP_LOCAL void rpc_ncacn_http_ntlm_uninit(rdpRpc* rpc, RpcChannel* channel); +FREERDP_LOCAL BOOL rpc_ncacn_http_ntlm_init(rdpContext* context, RpcChannel* channel); +FREERDP_LOCAL void rpc_ncacn_http_ntlm_uninit(RpcChannel* channel); -FREERDP_LOCAL int rpc_ncacn_http_send_in_channel_request(rdpRpc* rpc, - RpcInChannel* inChannel); -FREERDP_LOCAL int rpc_ncacn_http_recv_in_channel_response(rdpRpc* rpc, - RpcInChannel* inChannel, HttpResponse* response); +FREERDP_LOCAL BOOL rpc_ncacn_http_send_in_channel_request(RpcInChannel* inChannel); +FREERDP_LOCAL BOOL rpc_ncacn_http_recv_in_channel_response(RpcInChannel* inChannel, + HttpResponse* response); -FREERDP_LOCAL int rpc_ncacn_http_send_out_channel_request(rdpRpc* rpc, - RpcOutChannel* outChannel, BOOL replacement); -FREERDP_LOCAL int rpc_ncacn_http_recv_out_channel_response(rdpRpc* rpc, - RpcOutChannel* outChannel, HttpResponse* response); +FREERDP_LOCAL BOOL rpc_ncacn_http_send_out_channel_request(RpcOutChannel* outChannel, + BOOL replacement); +FREERDP_LOCAL BOOL rpc_ncacn_http_recv_out_channel_response(RpcOutChannel* outChannel, + HttpResponse* response); #endif /* FREERDP_LIB_CORE_GATEWAY_NCACN_HTTP_H */ diff --git a/libfreerdp/core/gateway/rpc.c b/libfreerdp/core/gateway/rpc.c index 6b770ff9d..107077450 100644 --- a/libfreerdp/core/gateway/rpc.c +++ b/libfreerdp/core/gateway/rpc.c @@ -791,12 +791,12 @@ static int rpc_in_channel_connect(RpcInChannel* inChannel, int timeout) rpc_in_channel_transition_to_state(inChannel, CLIENT_IN_CHANNEL_STATE_CONNECTED); - if (rpc_ncacn_http_ntlm_init(rpc, (RpcChannel*) inChannel) < 0) + if (!rpc_ncacn_http_ntlm_init(rpc->context, (RpcChannel*) inChannel)) return -1; /* Send IN Channel Request */ - if (rpc_ncacn_http_send_in_channel_request(rpc, inChannel) < 0) + if (!rpc_ncacn_http_send_in_channel_request(inChannel)) { WLog_ERR(TAG, "rpc_ncacn_http_send_in_channel_request failure"); return -1; @@ -817,12 +817,12 @@ static int rpc_out_channel_connect(RpcOutChannel* outChannel, int timeout) rpc_out_channel_transition_to_state(outChannel, CLIENT_OUT_CHANNEL_STATE_CONNECTED); - if (rpc_ncacn_http_ntlm_init(rpc, (RpcChannel*) outChannel) < 0) + if (!rpc_ncacn_http_ntlm_init(rpc->context, (RpcChannel*) outChannel)) return FALSE; /* Send OUT Channel Request */ - if (rpc_ncacn_http_send_out_channel_request(rpc, outChannel, FALSE) < 0) + if (!rpc_ncacn_http_send_out_channel_request(outChannel, FALSE)) { WLog_ERR(TAG, "rpc_ncacn_http_send_out_channel_request failure"); return FALSE; @@ -843,12 +843,12 @@ int rpc_out_channel_replacement_connect(RpcOutChannel* outChannel, int timeout) rpc_out_channel_transition_to_state(outChannel, CLIENT_OUT_CHANNEL_STATE_CONNECTED); - if (rpc_ncacn_http_ntlm_init(rpc, (RpcChannel*) outChannel) < 0) + if (!rpc_ncacn_http_ntlm_init(rpc->context, (RpcChannel*) outChannel)) return FALSE; /* Send OUT Channel Request */ - if (rpc_ncacn_http_send_out_channel_request(rpc, outChannel, TRUE) < 0) + if (!rpc_ncacn_http_send_out_channel_request(outChannel, TRUE)) { WLog_ERR(TAG, "rpc_ncacn_http_send_out_channel_request failure"); return FALSE; diff --git a/libfreerdp/core/gateway/rpc_client.c b/libfreerdp/core/gateway/rpc_client.c index b6614b83e..634c089fb 100644 --- a/libfreerdp/core/gateway/rpc_client.c +++ b/libfreerdp/core/gateway/rpc_client.c @@ -474,7 +474,7 @@ static int rpc_client_default_out_channel_recv(rdpRpc* rpc) if (outChannel->State == CLIENT_OUT_CHANNEL_STATE_SECURITY) { /* Receive OUT Channel Response */ - if (rpc_ncacn_http_recv_out_channel_response(rpc, outChannel, response) < 0) + if (!rpc_ncacn_http_recv_out_channel_response(outChannel, response)) { http_response_free(response); WLog_ERR(TAG, "rpc_ncacn_http_recv_out_channel_response failure"); @@ -483,14 +483,14 @@ static int rpc_client_default_out_channel_recv(rdpRpc* rpc) /* Send OUT Channel Request */ - if (rpc_ncacn_http_send_out_channel_request(rpc, outChannel, FALSE) < 0) + if (!rpc_ncacn_http_send_out_channel_request(outChannel, FALSE)) { http_response_free(response); WLog_ERR(TAG, "rpc_ncacn_http_send_out_channel_request failure"); return -1; } - rpc_ncacn_http_ntlm_uninit(rpc, (RpcChannel*)outChannel); + rpc_ncacn_http_ntlm_uninit((RpcChannel*)outChannel); rpc_out_channel_transition_to_state(outChannel, CLIENT_OUT_CHANNEL_STATE_NEGOTIATED); @@ -653,15 +653,11 @@ static int rpc_client_nondefault_out_channel_recv(rdpRpc* rpc) { if (nextOutChannel->State == CLIENT_OUT_CHANNEL_STATE_SECURITY) { - status = rpc_ncacn_http_recv_out_channel_response(rpc, nextOutChannel, response); - - if (status >= 0) + if (rpc_ncacn_http_recv_out_channel_response(nextOutChannel, response)) { - status = rpc_ncacn_http_send_out_channel_request(rpc, nextOutChannel, TRUE); - - if (status >= 0) + if (rpc_ncacn_http_send_out_channel_request(nextOutChannel, TRUE)) { - rpc_ncacn_http_ntlm_uninit(rpc, (RpcChannel*) nextOutChannel); + rpc_ncacn_http_ntlm_uninit((RpcChannel*) nextOutChannel); status = rts_send_OUT_R1_A3_pdu(rpc); if (status >= 0) @@ -738,7 +734,7 @@ int rpc_client_in_channel_recv(rdpRpc* rpc) if (inChannel->State == CLIENT_IN_CHANNEL_STATE_SECURITY) { - if (rpc_ncacn_http_recv_in_channel_response(rpc, inChannel, response) < 0) + if (!rpc_ncacn_http_recv_in_channel_response(inChannel, response)) { WLog_ERR(TAG, "rpc_ncacn_http_recv_in_channel_response failure"); http_response_free(response); @@ -747,14 +743,14 @@ int rpc_client_in_channel_recv(rdpRpc* rpc) /* Send IN Channel Request */ - if (rpc_ncacn_http_send_in_channel_request(rpc, inChannel) < 0) + if (!rpc_ncacn_http_send_in_channel_request(inChannel)) { WLog_ERR(TAG, "rpc_ncacn_http_send_in_channel_request failure"); http_response_free(response); return -1; } - rpc_ncacn_http_ntlm_uninit(rpc, (RpcChannel*) inChannel); + rpc_ncacn_http_ntlm_uninit((RpcChannel*) inChannel); rpc_in_channel_transition_to_state(inChannel, CLIENT_IN_CHANNEL_STATE_NEGOTIATED); From 8a677d6cf2f4d40afec89b1dfa48e383f60dfa84 Mon Sep 17 00:00:00 2001 From: Armin Novak Date: Thu, 27 Sep 2018 16:05:14 +0200 Subject: [PATCH 03/13] Refactored rdg channel structs. --- libfreerdp/core/gateway/http.c | 10 +- libfreerdp/core/gateway/ncacn_http.c | 17 +-- libfreerdp/core/gateway/ncacn_http.h | 8 +- libfreerdp/core/gateway/rdg.c | 2 +- libfreerdp/core/gateway/rpc.c | 218 ++++++++++----------------- libfreerdp/core/gateway/rpc.h | 38 +++-- libfreerdp/core/gateway/rpc_client.c | 63 ++++---- libfreerdp/core/gateway/rts.c | 28 ++-- libfreerdp/core/gateway/tsg.c | 28 ++-- 9 files changed, 167 insertions(+), 245 deletions(-) diff --git a/libfreerdp/core/gateway/http.c b/libfreerdp/core/gateway/http.c index 7dc35db39..2d4e93dc8 100644 --- a/libfreerdp/core/gateway/http.c +++ b/libfreerdp/core/gateway/http.c @@ -374,7 +374,7 @@ static BOOL http_encode_print(wStream* s, const char* fmt, ...) if (!Stream_EnsureRemainingCapacity(s, length)) return FALSE; - str = Stream_Pointer(s); + str = (char*)Stream_Pointer(s); va_start(ap, fmt); used = vsnprintf(str, length, fmt, ap); va_end(ap); @@ -724,7 +724,7 @@ HttpResponse* http_response_recv(rdpTls* tls) #ifdef HAVE_VALGRIND_MEMCHECK_H VALGRIND_MAKE_MEM_DEFINED(Stream_Pointer(response->data), status); #endif - Stream_Seek(response->data, status); + Stream_Seek(response->data, (size_t)status); if (Stream_GetRemainingLength(response->data) < 1024) { @@ -831,7 +831,7 @@ HttpResponse* http_response_recv(rdpTls* tls) continue; } - Stream_Seek(response->data, status); + Stream_Seek(response->data, (size_t)status); response->BodyLength += status; if (response->BodyLength > RESPONSE_SIZE_LIMIT) @@ -909,7 +909,7 @@ SSIZE_T http_request_get_content_length(HttpRequest* request) if (!request) return -1; - return request->ContentLength; + return (SSIZE_T)request->ContentLength; } BOOL http_request_set_content_length(HttpRequest* request, size_t length) @@ -934,7 +934,7 @@ SSIZE_T http_response_get_body_length(HttpResponse* response) if (!response) return -1; - return response->BodyLength; + return (SSIZE_T)response->BodyLength; } const char* http_response_get_auth_token(HttpResponse* respone, const char* method) diff --git a/libfreerdp/core/gateway/ncacn_http.c b/libfreerdp/core/gateway/ncacn_http.c index a07ed972e..c9d831269 100644 --- a/libfreerdp/core/gateway/ncacn_http.c +++ b/libfreerdp/core/gateway/ncacn_http.c @@ -68,7 +68,7 @@ fail: return s; } -BOOL rpc_ncacn_http_send_in_channel_request(RpcInChannel* inChannel) +BOOL rpc_ncacn_http_send_in_channel_request(RpcChannel* inChannel) { wStream* s; int status; @@ -89,15 +89,15 @@ BOOL rpc_ncacn_http_send_in_channel_request(RpcInChannel* inChannel) if (!s) return -1; - status = rpc_in_channel_write(inChannel, Stream_Buffer(s), Stream_Length(s)); + status = rpc_channel_write(inChannel, Stream_Buffer(s), Stream_Length(s)); Stream_Free(s, TRUE); return (status > 0) ? 1 : -1; } -BOOL rpc_ncacn_http_recv_in_channel_response(RpcInChannel* inChannel, +BOOL rpc_ncacn_http_recv_in_channel_response(RpcChannel* inChannel, HttpResponse* response) { - char* token64 = NULL; + const char* token64 = NULL; int ntlmTokenLength = 0; BYTE* ntlmTokenData = NULL; rdpNtlm* ntlm; @@ -205,12 +205,11 @@ void rpc_ncacn_http_ntlm_uninit(RpcChannel* channel) channel->ntlm = NULL; } -BOOL rpc_ncacn_http_send_out_channel_request(RpcOutChannel* outChannel, +BOOL rpc_ncacn_http_send_out_channel_request(RpcChannel* outChannel, BOOL replacement) { BOOL rc = TRUE; wStream* s; - int status; int contentLength; BOOL continueNeeded; rdpNtlm* ntlm; @@ -233,17 +232,17 @@ BOOL rpc_ncacn_http_send_out_channel_request(RpcOutChannel* outChannel, if (!s) return -1; - if (rpc_out_channel_write(outChannel, Stream_Buffer(s), Stream_Length(s)) < 0) + if (rpc_channel_write(outChannel, Stream_Buffer(s), Stream_Length(s)) < 0) rc = FALSE; Stream_Free(s, TRUE); return rc; } -BOOL rpc_ncacn_http_recv_out_channel_response(RpcOutChannel* outChannel, +BOOL rpc_ncacn_http_recv_out_channel_response(RpcChannel* outChannel, HttpResponse* response) { - char* token64 = NULL; + const char* token64 = NULL; int ntlmTokenLength = 0; BYTE* ntlmTokenData = NULL; rdpNtlm* ntlm; diff --git a/libfreerdp/core/gateway/ncacn_http.h b/libfreerdp/core/gateway/ncacn_http.h index 59cc956a9..9736d5983 100644 --- a/libfreerdp/core/gateway/ncacn_http.h +++ b/libfreerdp/core/gateway/ncacn_http.h @@ -28,13 +28,13 @@ FREERDP_LOCAL BOOL rpc_ncacn_http_ntlm_init(rdpContext* context, RpcChannel* channel); FREERDP_LOCAL void rpc_ncacn_http_ntlm_uninit(RpcChannel* channel); -FREERDP_LOCAL BOOL rpc_ncacn_http_send_in_channel_request(RpcInChannel* inChannel); -FREERDP_LOCAL BOOL rpc_ncacn_http_recv_in_channel_response(RpcInChannel* inChannel, +FREERDP_LOCAL BOOL rpc_ncacn_http_send_in_channel_request(RpcChannel* inChannel); +FREERDP_LOCAL BOOL rpc_ncacn_http_recv_in_channel_response(RpcChannel* inChannel, HttpResponse* response); -FREERDP_LOCAL BOOL rpc_ncacn_http_send_out_channel_request(RpcOutChannel* outChannel, +FREERDP_LOCAL BOOL rpc_ncacn_http_send_out_channel_request(RpcChannel* outChannel, BOOL replacement); -FREERDP_LOCAL BOOL rpc_ncacn_http_recv_out_channel_response(RpcOutChannel* outChannel, +FREERDP_LOCAL BOOL rpc_ncacn_http_recv_out_channel_response(RpcChannel* outChannel, HttpResponse* response); #endif /* FREERDP_LIB_CORE_GATEWAY_NCACN_HTTP_H */ diff --git a/libfreerdp/core/gateway/rdg.c b/libfreerdp/core/gateway/rdg.c index 11b24286c..a60c1ee2b 100644 --- a/libfreerdp/core/gateway/rdg.c +++ b/libfreerdp/core/gateway/rdg.c @@ -362,7 +362,7 @@ out: static BOOL rdg_handle_ntlm_challenge(rdpNtlm* ntlm, HttpResponse* response) { - char* token64 = NULL; + const char* token64 = NULL; int ntlmTokenLength = 0; BYTE* ntlmTokenData = NULL; long StatusCode; diff --git a/libfreerdp/core/gateway/rpc.c b/libfreerdp/core/gateway/rpc.c index 107077450..93fd2465e 100644 --- a/libfreerdp/core/gateway/rpc.c +++ b/libfreerdp/core/gateway/rpc.c @@ -327,42 +327,43 @@ BOOL rpc_get_stub_data_info(rdpRpc* rpc, BYTE* buffer, UINT32* offset, UINT32* l return TRUE; } -int rpc_out_channel_read(RpcOutChannel* outChannel, BYTE* data, int length) +SSIZE_T rpc_channel_read(RpcChannel* channel, wStream* s, size_t length) { int status; - status = BIO_read(outChannel->tls->bio, data, length); + + if (!channel) + return -1; + + status = BIO_read(channel->tls->bio, Stream_Pointer(s), length); if (status > 0) { #ifdef HAVE_VALGRIND_MEMCHECK_H VALGRIND_MAKE_MEM_DEFINED(data, status); #endif + Stream_Seek(s, (size_t)status); return status; } - if (BIO_should_retry(outChannel->tls->bio)) + if (BIO_should_retry(channel->tls->bio)) return 0; return -1; } -int rpc_in_channel_write(RpcInChannel* inChannel, const BYTE* data, int length) +SSIZE_T rpc_channel_write(RpcChannel* channel, const BYTE* data, size_t length) { int status; - status = tls_write_all(inChannel->tls, data, length); + + if (!channel) + return -1; + + status = tls_write_all(channel->tls, data, length); return status; } -int rpc_out_channel_write(RpcOutChannel* outChannel, const BYTE* data, int length) +BOOL rpc_in_channel_transition_to_state(RpcInChannel* inChannel, CLIENT_IN_CHANNEL_STATE state) { - int status; - status = tls_write_all(outChannel->tls, data, length); - return status; -} - -int rpc_in_channel_transition_to_state(RpcInChannel* inChannel, CLIENT_IN_CHANNEL_STATE state) -{ - int status = 1; const char* str = "CLIENT_IN_CHANNEL_STATE_UNKNOWN"; switch (state) @@ -396,67 +397,63 @@ int rpc_in_channel_transition_to_state(RpcInChannel* inChannel, CLIENT_IN_CHANNE break; } + if (!inChannel) + return FALSE; + inChannel->State = state; WLog_DBG(TAG, "%s", str); - return status; + return TRUE; } -static int rpc_in_channel_rpch_init(rdpRpc* rpc, RpcInChannel* inChannel) +static int rpc_channel_rpch_init(rdpRpc* rpc, RpcChannel* channel, const char* inout) { HttpContext* http; - inChannel->ntlm = ntlm_new(); - if (!inChannel->ntlm) + if (!rpc || !channel || !inout) return -1; - inChannel->http = http_context_new(); + channel->ntlm = ntlm_new(); + rts_generate_cookie((BYTE*) &channel->Cookie); + channel->rpc = rpc; - if (!inChannel->http) + if (!channel->ntlm) + return -1; + + channel->http = http_context_new(); + + if (!channel->http) + return -1; + + http = channel->http; + + if (!http_context_set_method(http, inout) || + !http_context_set_uri(http, "/rpc/rpcproxy.dll?localhost:3388") || + !http_context_set_accept(http, "application/rpc") || + !http_context_set_cache_control(http, "no-cache") || + !http_context_set_connection(http, "Keep-Alive") || !http_context_set_user_agent(http, "MSRPC") || + !http_context_set_host(http, rpc->settings->GatewayHostname) || + !http_context_set_pragma(http, + "ResourceTypeUuid=44e265dd-7daf-42cd-8560-3cdb6e7a2729, " + "SessionId=fbd9c34f-397d-471d-a109-1b08cc554624")) return -1; - http = inChannel->http; - http_context_set_method(http, "RPC_IN_DATA"); - http_context_set_uri(http, "/rpc/rpcproxy.dll?localhost:3388"); - http_context_set_accept(http, "application/rpc"); - http_context_set_cache_control(http, "no-cache"); - http_context_set_connection(http, "Keep-Alive"); - http_context_set_user_agent(http, "MSRPC"); - http_context_set_host(http, rpc->settings->GatewayHostname); - http_context_set_pragma(http, "ResourceTypeUuid=44e265dd-7daf-42cd-8560-3cdb6e7a2729"); return 1; } static int rpc_in_channel_init(rdpRpc* rpc, RpcInChannel* inChannel) { - rts_generate_cookie((BYTE*) &inChannel->Cookie); - inChannel->rpc = rpc; inChannel->State = CLIENT_IN_CHANNEL_STATE_INITIAL; inChannel->BytesSent = 0; inChannel->SenderAvailableWindow = rpc->ReceiveWindow; inChannel->PingOriginator.ConnectionTimeout = 30; inChannel->PingOriginator.KeepAliveInterval = 0; - if (rpc_in_channel_rpch_init(rpc, inChannel) < 0) + if (rpc_channel_rpch_init(rpc, &inChannel->common, "RPC_IN_DATA") < 0) return -1; return 1; } -static void rpc_in_channel_rpch_uninit(RpcInChannel* inChannel) -{ - if (inChannel->ntlm) - { - ntlm_free(inChannel->ntlm); - inChannel->ntlm = NULL; - } - - if (inChannel->http) - { - http_context_free(inChannel->http); - inChannel->http = NULL; - } -} - static RpcInChannel* rpc_in_channel_new(rdpRpc* rpc) { RpcInChannel* inChannel = NULL; @@ -470,25 +467,19 @@ static RpcInChannel* rpc_in_channel_new(rdpRpc* rpc) return inChannel; } -static void rpc_in_channel_free(RpcInChannel* inChannel) +void rpc_channel_free(RpcChannel* channel) { - if (!inChannel) + if (!channel) return; - rpc_in_channel_rpch_uninit(inChannel); - - if (inChannel->tls) - { - tls_free(inChannel->tls); - inChannel->tls = NULL; - } - - free(inChannel); + ntlm_free(channel->ntlm); + http_context_free(channel->http); + tls_free(channel->tls); + free(channel); } -int rpc_out_channel_transition_to_state(RpcOutChannel* outChannel, CLIENT_OUT_CHANNEL_STATE state) +BOOL rpc_out_channel_transition_to_state(RpcOutChannel* outChannel, CLIENT_OUT_CHANNEL_STATE state) { - int status = 1; const char* str = "CLIENT_OUT_CHANNEL_STATE_UNKNOWN"; switch (state) @@ -534,42 +525,16 @@ int rpc_out_channel_transition_to_state(RpcOutChannel* outChannel, CLIENT_OUT_CH break; } + if (!outChannel) + return FALSE; + outChannel->State = state; WLog_DBG(TAG, "%s", str); - return status; -} - -static int rpc_out_channel_rpch_init(rdpRpc* rpc, RpcOutChannel* outChannel) -{ - HttpContext* http; - outChannel->ntlm = ntlm_new(); - - if (!outChannel->ntlm) - return -1; - - outChannel->http = http_context_new(); - - if (!outChannel->http) - return -1; - - http = outChannel->http; - http_context_set_method(http, "RPC_OUT_DATA"); - http_context_set_uri(http, "/rpc/rpcproxy.dll?localhost:3388"); - http_context_set_accept(http, "application/rpc"); - http_context_set_cache_control(http, "no-cache"); - http_context_set_connection(http, "Keep-Alive"); - http_context_set_user_agent(http, "MSRPC"); - http_context_set_host(http, rpc->settings->GatewayHostname); - http_context_set_pragma(http, - "ResourceTypeUuid=44e265dd-7daf-42cd-8560-3cdb6e7a2729, " - "SessionId=fbd9c34f-397d-471d-a109-1b08cc554624"); - return 1; + return TRUE; } static int rpc_out_channel_init(rdpRpc* rpc, RpcOutChannel* outChannel) { - rts_generate_cookie((BYTE*) &outChannel->Cookie); - outChannel->rpc = rpc; outChannel->State = CLIENT_OUT_CHANNEL_STATE_INITIAL; outChannel->BytesReceived = 0; outChannel->ReceiverAvailableWindow = rpc->ReceiveWindow; @@ -577,27 +542,12 @@ static int rpc_out_channel_init(rdpRpc* rpc, RpcOutChannel* outChannel) outChannel->ReceiveWindowSize = rpc->ReceiveWindow; outChannel->AvailableWindowAdvertised = rpc->ReceiveWindow; - if (rpc_out_channel_rpch_init(rpc, outChannel) < 0) + if (rpc_channel_rpch_init(rpc, &outChannel->common, "RPC_OUT_DATA") < 0) return -1; return 1; } -static void rpc_out_channel_rpch_uninit(RpcOutChannel* outChannel) -{ - if (outChannel->ntlm) - { - ntlm_free(outChannel->ntlm); - outChannel->ntlm = NULL; - } - - if (outChannel->http) - { - http_context_free(outChannel->http); - outChannel->http = NULL; - } -} - RpcOutChannel* rpc_out_channel_new(rdpRpc* rpc) { RpcOutChannel* outChannel = NULL; @@ -611,26 +561,9 @@ RpcOutChannel* rpc_out_channel_new(rdpRpc* rpc) return outChannel; } -void rpc_out_channel_free(RpcOutChannel* outChannel) -{ - if (!outChannel) - return; - - rpc_out_channel_rpch_uninit(outChannel); - - if (outChannel->tls) - { - tls_free(outChannel->tls); - outChannel->tls = NULL; - } - - free(outChannel); -} - -int rpc_virtual_connection_transition_to_state(rdpRpc* rpc, +BOOL rpc_virtual_connection_transition_to_state(rdpRpc* rpc, RpcVirtualConnection* connection, VIRTUAL_CONNECTION_STATE state) { - int status = 1; const char* str = "VIRTUAL_CONNECTION_STATE_UNKNOWN"; switch (state) @@ -660,9 +593,12 @@ int rpc_virtual_connection_transition_to_state(rdpRpc* rpc, break; } + if (!connection) + return FALSE; + connection->State = state; WLog_DBG(TAG, "%s", str); - return status; + return TRUE; } static RpcVirtualConnection* rpc_virtual_connection_new(rdpRpc* rpc) @@ -699,10 +635,10 @@ static void rpc_virtual_connection_free(RpcVirtualConnection* connection) if (!connection) return; - rpc_in_channel_free(connection->DefaultInChannel); - rpc_in_channel_free(connection->NonDefaultInChannel); - rpc_out_channel_free(connection->DefaultOutChannel); - rpc_out_channel_free(connection->NonDefaultOutChannel); + rpc_channel_free(&connection->DefaultInChannel->common); + rpc_channel_free(&connection->NonDefaultInChannel->common); + rpc_channel_free(&connection->DefaultOutChannel->common); + rpc_channel_free(&connection->NonDefaultOutChannel->common); free(connection); } @@ -782,47 +718,49 @@ static int rpc_channel_tls_connect(RpcChannel* channel, int timeout) static int rpc_in_channel_connect(RpcInChannel* inChannel, int timeout) { - rdpRpc* rpc = inChannel->rpc; + rdpRpc* rpc = inChannel->common.rpc; /* Connect IN Channel */ - if (rpc_channel_tls_connect((RpcChannel*) inChannel, timeout) < 0) + if (rpc_channel_tls_connect(&inChannel->common, timeout) < 0) return -1; rpc_in_channel_transition_to_state(inChannel, CLIENT_IN_CHANNEL_STATE_CONNECTED); - if (!rpc_ncacn_http_ntlm_init(rpc->context, (RpcChannel*) inChannel)) + if (!rpc_ncacn_http_ntlm_init(rpc->context, &inChannel->common)) return -1; /* Send IN Channel Request */ - if (!rpc_ncacn_http_send_in_channel_request(inChannel)) + if (!rpc_ncacn_http_send_in_channel_request(&inChannel->common)) { WLog_ERR(TAG, "rpc_ncacn_http_send_in_channel_request failure"); return -1; } - rpc_in_channel_transition_to_state(inChannel, CLIENT_IN_CHANNEL_STATE_SECURITY); + if (!rpc_in_channel_transition_to_state(inChannel, CLIENT_IN_CHANNEL_STATE_SECURITY)) + return -1; + return 1; } static int rpc_out_channel_connect(RpcOutChannel* outChannel, int timeout) { - rdpRpc* rpc = outChannel->rpc; + rdpRpc* rpc = outChannel->common.rpc; /* Connect OUT Channel */ - if (rpc_channel_tls_connect((RpcChannel*) outChannel, timeout) < 0) + if (rpc_channel_tls_connect(&outChannel->common, timeout) < 0) return -1; rpc_out_channel_transition_to_state(outChannel, CLIENT_OUT_CHANNEL_STATE_CONNECTED); - if (!rpc_ncacn_http_ntlm_init(rpc->context, (RpcChannel*) outChannel)) + if (!rpc_ncacn_http_ntlm_init(rpc->context, &outChannel->common)) return FALSE; /* Send OUT Channel Request */ - if (!rpc_ncacn_http_send_out_channel_request(outChannel, FALSE)) + if (!rpc_ncacn_http_send_out_channel_request(&outChannel->common, FALSE)) { WLog_ERR(TAG, "rpc_ncacn_http_send_out_channel_request failure"); return FALSE; @@ -834,7 +772,7 @@ static int rpc_out_channel_connect(RpcOutChannel* outChannel, int timeout) int rpc_out_channel_replacement_connect(RpcOutChannel* outChannel, int timeout) { - rdpRpc* rpc = outChannel->rpc; + rdpRpc* rpc = outChannel->common.rpc; /* Connect OUT Channel */ @@ -848,7 +786,7 @@ int rpc_out_channel_replacement_connect(RpcOutChannel* outChannel, int timeout) /* Send OUT Channel Request */ - if (!rpc_ncacn_http_send_out_channel_request(outChannel, TRUE)) + if (!rpc_ncacn_http_send_out_channel_request(&outChannel->common, TRUE)) { WLog_ERR(TAG, "rpc_ncacn_http_send_out_channel_request failure"); return FALSE; diff --git a/libfreerdp/core/gateway/rpc.h b/libfreerdp/core/gateway/rpc.h index 1d2b4c51e..4aebd178b 100644 --- a/libfreerdp/core/gateway/rpc.h +++ b/libfreerdp/core/gateway/rpc.h @@ -585,17 +585,14 @@ struct rpc_client_call }; typedef struct rpc_client_call RpcClientCall; -#define RPC_CHANNEL_COMMON() \ - rdpRpc* rpc; \ - BIO* bio; \ - rdpTls* tls; \ - rdpNtlm* ntlm; \ - HttpContext* http; \ - BYTE Cookie[16] - struct rpc_channel { - RPC_CHANNEL_COMMON(); + rdpRpc* rpc; + BIO* bio; + rdpTls* tls; + rdpNtlm* ntlm; + HttpContext* http; + BYTE Cookie[16]; }; typedef struct rpc_channel RpcChannel; @@ -627,7 +624,7 @@ struct rpc_in_channel { /* Sending Channel */ - RPC_CHANNEL_COMMON(); + RpcChannel common; CLIENT_IN_CHANNEL_STATE State; @@ -664,7 +661,7 @@ struct rpc_out_channel { /* Receiving Channel */ - RPC_CHANNEL_COMMON(); + RpcChannel common; CLIENT_OUT_CHANNEL_STATE State; @@ -774,25 +771,24 @@ FREERDP_LOCAL UINT32 rpc_offset_pad(UINT32* offset, UINT32 pad); FREERDP_LOCAL BOOL rpc_get_stub_data_info(rdpRpc* rpc, BYTE* header, UINT32* offset, UINT32* length); -FREERDP_LOCAL int rpc_in_channel_write(RpcInChannel* inChannel, - const BYTE* data, int length); +FREERDP_LOCAL SSIZE_T rpc_channel_write(RpcChannel* channel, + const BYTE* data, size_t length); -FREERDP_LOCAL int rpc_out_channel_read(RpcOutChannel* outChannel, BYTE* data, - int length); -FREERDP_LOCAL int rpc_out_channel_write(RpcOutChannel* outChannel, - const BYTE* data, int length); +FREERDP_LOCAL SSIZE_T rpc_channel_read(RpcChannel* channel, wStream* s, + size_t length); + +FREERDP_LOCAL void rpc_channel_free(RpcChannel* channel); FREERDP_LOCAL RpcOutChannel* rpc_out_channel_new(rdpRpc* rpc); FREERDP_LOCAL int rpc_out_channel_replacement_connect(RpcOutChannel* outChannel, int timeout); -FREERDP_LOCAL void rpc_out_channel_free(RpcOutChannel* outChannel); -FREERDP_LOCAL int rpc_in_channel_transition_to_state(RpcInChannel* inChannel, +FREERDP_LOCAL BOOL rpc_in_channel_transition_to_state(RpcInChannel* inChannel, CLIENT_IN_CHANNEL_STATE state); -FREERDP_LOCAL int rpc_out_channel_transition_to_state(RpcOutChannel* outChannel, +FREERDP_LOCAL BOOL rpc_out_channel_transition_to_state(RpcOutChannel* outChannel, CLIENT_OUT_CHANNEL_STATE state); -FREERDP_LOCAL int rpc_virtual_connection_transition_to_state(rdpRpc* rpc, +FREERDP_LOCAL BOOL rpc_virtual_connection_transition_to_state(rdpRpc* rpc, RpcVirtualConnection* connection, VIRTUAL_CONNECTION_STATE state); FREERDP_LOCAL BOOL rpc_connect(rdpRpc* rpc, int timeout); diff --git a/libfreerdp/core/gateway/rpc_client.c b/libfreerdp/core/gateway/rpc_client.c index 634c089fb..3aa6d169f 100644 --- a/libfreerdp/core/gateway/rpc_client.c +++ b/libfreerdp/core/gateway/rpc_client.c @@ -459,14 +459,14 @@ static int rpc_client_default_out_channel_recv(rdpRpc* rpc) RpcVirtualConnection* connection = rpc->VirtualConnection; inChannel = connection->DefaultInChannel; outChannel = connection->DefaultOutChannel; - BIO_get_event(outChannel->tls->bio, &outChannelEvent); + BIO_get_event(outChannel->common.tls->bio, &outChannelEvent); if (outChannel->State < CLIENT_OUT_CHANNEL_STATE_OPENED) { if (WaitForSingleObject(outChannelEvent, 0) != WAIT_OBJECT_0) return 1; - response = http_response_recv(outChannel->tls); + response = http_response_recv(outChannel->common.tls); if (!response) return -1; @@ -474,7 +474,7 @@ static int rpc_client_default_out_channel_recv(rdpRpc* rpc) if (outChannel->State == CLIENT_OUT_CHANNEL_STATE_SECURITY) { /* Receive OUT Channel Response */ - if (!rpc_ncacn_http_recv_out_channel_response(outChannel, response)) + if (!rpc_ncacn_http_recv_out_channel_response(&outChannel->common, response)) { http_response_free(response); WLog_ERR(TAG, "rpc_ncacn_http_recv_out_channel_response failure"); @@ -483,14 +483,14 @@ static int rpc_client_default_out_channel_recv(rdpRpc* rpc) /* Send OUT Channel Request */ - if (!rpc_ncacn_http_send_out_channel_request(outChannel, FALSE)) + if (!rpc_ncacn_http_send_out_channel_request(&outChannel->common, FALSE)) { http_response_free(response); WLog_ERR(TAG, "rpc_ncacn_http_send_out_channel_request failure"); return -1; } - rpc_ncacn_http_ntlm_uninit((RpcChannel*)outChannel); + rpc_ncacn_http_ntlm_uninit(&outChannel->common); rpc_out_channel_transition_to_state(outChannel, CLIENT_OUT_CHANNEL_STATE_NEGOTIATED); @@ -523,7 +523,7 @@ static int rpc_client_default_out_channel_recv(rdpRpc* rpc) if (WaitForSingleObject(outChannelEvent, 0) != WAIT_OBJECT_0) return 1; - response = http_response_recv(outChannel->tls); + response = http_response_recv(outChannel->common.tls); if (!response) return -1; @@ -560,21 +560,16 @@ static int rpc_client_default_out_channel_recv(rdpRpc* rpc) { while (Stream_GetPosition(fragment) < RPC_COMMON_FIELDS_LENGTH) { - status = rpc_out_channel_read(outChannel, Stream_Pointer(fragment), - RPC_COMMON_FIELDS_LENGTH - Stream_GetPosition(fragment)); + status = rpc_channel_read(&outChannel->common, fragment, + RPC_COMMON_FIELDS_LENGTH - Stream_GetPosition(fragment)); if (status < 0) return -1; - if (!status) + if (Stream_GetPosition(fragment) < RPC_COMMON_FIELDS_LENGTH) return 0; - - Stream_Seek(fragment, status); } - if (Stream_GetPosition(fragment) < RPC_COMMON_FIELDS_LENGTH) - return status; - header = (rpcconn_common_hdr_t*)Stream_Buffer(fragment); if (header->frag_length > rpc->max_recv_frag) @@ -587,8 +582,8 @@ static int rpc_client_default_out_channel_recv(rdpRpc* rpc) while (Stream_GetPosition(fragment) < header->frag_length) { - status = rpc_out_channel_read(outChannel, Stream_Pointer(fragment), - header->frag_length - Stream_GetPosition(fragment)); + status = rpc_channel_read(&outChannel->common, fragment, + header->frag_length - Stream_GetPosition(fragment)); if (status < 0) { @@ -596,16 +591,10 @@ static int rpc_client_default_out_channel_recv(rdpRpc* rpc) return -1; } - if (!status) + if (Stream_GetPosition(fragment) < header->frag_length) return 0; - - Stream_Seek(fragment, status); } - if (status < 0) - return -1; - - if (Stream_GetPosition(fragment) >= header->frag_length) { /* complete fragment received */ Stream_SealLength(fragment); @@ -618,7 +607,7 @@ static int rpc_client_default_out_channel_recv(rdpRpc* rpc) /* channel recycling may update channel pointers */ if (outChannel->State == CLIENT_OUT_CHANNEL_STATE_RECYCLED && connection->NonDefaultOutChannel) { - rpc_out_channel_free(connection->DefaultOutChannel); + rpc_channel_free(&connection->DefaultOutChannel->common); connection->DefaultOutChannel = connection->NonDefaultOutChannel; connection->NonDefaultOutChannel = NULL; rpc_out_channel_transition_to_state(connection->DefaultOutChannel, CLIENT_OUT_CHANNEL_STATE_OPENED); @@ -642,22 +631,22 @@ static int rpc_client_nondefault_out_channel_recv(rdpRpc* rpc) RpcOutChannel* nextOutChannel; HANDLE nextOutChannelEvent = NULL; nextOutChannel = rpc->VirtualConnection->NonDefaultOutChannel; - BIO_get_event(nextOutChannel->tls->bio, &nextOutChannelEvent); + BIO_get_event(nextOutChannel->common.tls->bio, &nextOutChannelEvent); if (WaitForSingleObject(nextOutChannelEvent, 0) != WAIT_OBJECT_0) return 1; - response = http_response_recv(nextOutChannel->tls); + response = http_response_recv(nextOutChannel->common.tls); if (response) { if (nextOutChannel->State == CLIENT_OUT_CHANNEL_STATE_SECURITY) { - if (rpc_ncacn_http_recv_out_channel_response(nextOutChannel, response)) + if (rpc_ncacn_http_recv_out_channel_response(&nextOutChannel->common, response)) { - if (rpc_ncacn_http_send_out_channel_request(nextOutChannel, TRUE)) + if (rpc_ncacn_http_send_out_channel_request(&nextOutChannel->common, TRUE)) { - rpc_ncacn_http_ntlm_uninit((RpcChannel*) nextOutChannel); + rpc_ncacn_http_ntlm_uninit(&nextOutChannel->common); status = rts_send_OUT_R1_A3_pdu(rpc); if (status >= 0) @@ -720,21 +709,21 @@ int rpc_client_in_channel_recv(rdpRpc* rpc) RpcVirtualConnection* connection = rpc->VirtualConnection; inChannel = connection->DefaultInChannel; outChannel = connection->DefaultOutChannel; - BIO_get_event(inChannel->tls->bio, &InChannelEvent); + BIO_get_event(inChannel->common.tls->bio, &InChannelEvent); if (WaitForSingleObject(InChannelEvent, 0) != WAIT_OBJECT_0) return 1; if (inChannel->State < CLIENT_IN_CHANNEL_STATE_OPENED) { - response = http_response_recv(inChannel->tls); + response = http_response_recv(inChannel->common.tls); if (!response) return -1; if (inChannel->State == CLIENT_IN_CHANNEL_STATE_SECURITY) { - if (!rpc_ncacn_http_recv_in_channel_response(inChannel, response)) + if (!rpc_ncacn_http_recv_in_channel_response(&inChannel->common, response)) { WLog_ERR(TAG, "rpc_ncacn_http_recv_in_channel_response failure"); http_response_free(response); @@ -743,14 +732,14 @@ int rpc_client_in_channel_recv(rdpRpc* rpc) /* Send IN Channel Request */ - if (!rpc_ncacn_http_send_in_channel_request(inChannel)) + if (!rpc_ncacn_http_send_in_channel_request(&inChannel->common)) { WLog_ERR(TAG, "rpc_ncacn_http_send_in_channel_request failure"); http_response_free(response); return -1; } - rpc_ncacn_http_ntlm_uninit((RpcChannel*) inChannel); + rpc_ncacn_http_ntlm_uninit(&inChannel->common); rpc_in_channel_transition_to_state(inChannel, CLIENT_IN_CHANNEL_STATE_NEGOTIATED); @@ -779,7 +768,7 @@ int rpc_client_in_channel_recv(rdpRpc* rpc) } else { - response = http_response_recv(inChannel->tls); + response = http_response_recv(inChannel->common.tls); if (!response) return -1; @@ -840,8 +829,8 @@ int rpc_in_channel_send_pdu(RpcInChannel* inChannel, BYTE* buffer, UINT32 length int status; RpcClientCall* clientCall; rpcconn_common_hdr_t* header; - rdpRpc* rpc = inChannel->rpc; - status = rpc_in_channel_write(inChannel, buffer, length); + rdpRpc* rpc = inChannel->common.rpc; + status = rpc_channel_write(&inChannel->common, buffer, length); if (status <= 0) return -1; diff --git a/libfreerdp/core/gateway/rts.c b/libfreerdp/core/gateway/rts.c index bcf7c5bdb..a0df68962 100644 --- a/libfreerdp/core/gateway/rts.c +++ b/libfreerdp/core/gateway/rts.c @@ -416,7 +416,7 @@ int rts_send_CONN_A1_pdu(rdpRpc* rpc) header.NumberOfCommands = 4; WLog_DBG(TAG, "Sending CONN/A1 RTS PDU"); VirtualConnectionCookie = (BYTE*) & (connection->Cookie); - OUTChannelCookie = (BYTE*) & (outChannel->Cookie); + OUTChannelCookie = (BYTE*) & (outChannel->common.Cookie); ReceiveWindowSize = outChannel->ReceiveWindow; buffer = (BYTE*) malloc(header.frag_length); @@ -430,7 +430,7 @@ int rts_send_CONN_A1_pdu(rdpRpc* rpc) rts_cookie_command_write(&buffer[48], OUTChannelCookie); /* OUTChannelCookie (20 bytes) */ rts_receive_window_size_command_write(&buffer[68], ReceiveWindowSize); /* ReceiveWindowSize (8 bytes) */ - status = rpc_out_channel_write(outChannel, buffer, header.frag_length); + status = rpc_channel_write(&outChannel->common, buffer, header.frag_length); free(buffer); return (status > 0) ? 1 : -1; } @@ -463,7 +463,7 @@ int rts_send_CONN_B1_pdu(rdpRpc* rpc) header.NumberOfCommands = 6; WLog_DBG(TAG, "Sending CONN/B1 RTS PDU"); VirtualConnectionCookie = (BYTE*) & (connection->Cookie); - INChannelCookie = (BYTE*) & (inChannel->Cookie); + INChannelCookie = (BYTE*) & (inChannel->common.Cookie); AssociationGroupId = (BYTE*) & (connection->AssociationGroupId); buffer = (BYTE*) malloc(header.frag_length); @@ -482,7 +482,7 @@ int rts_send_CONN_B1_pdu(rdpRpc* rpc) rts_association_group_id_command_write(&buffer[84], AssociationGroupId); /* AssociationGroupId (20 bytes) */ length = header.frag_length; - status = rpc_in_channel_write(inChannel, buffer, length); + status = rpc_channel_write(&inChannel->common, buffer, length); free(buffer); return (status > 0) ? 1 : -1; } @@ -531,7 +531,7 @@ static int rts_send_keep_alive_pdu(rdpRpc* rpc) rts_client_keepalive_command_write(&buffer[20], rpc->CurrentKeepAliveInterval); /* ClientKeepAlive (8 bytes) */ length = header.frag_length; - status = rpc_in_channel_write(inChannel, buffer, length); + status = rpc_channel_write(&inChannel->common, buffer, length); free(buffer); return (status > 0) ? 1 : -1; } @@ -555,7 +555,7 @@ int rts_send_flow_control_ack_pdu(rdpRpc* rpc) WLog_DBG(TAG, "Sending FlowControlAck RTS PDU"); BytesReceived = outChannel->BytesReceived; AvailableWindow = outChannel->AvailableWindowAdvertised; - ChannelCookie = (BYTE*) & (outChannel->Cookie); + ChannelCookie = (BYTE*) & (outChannel->common.Cookie); outChannel->ReceiverAvailableWindow = outChannel->AvailableWindowAdvertised; buffer = (BYTE*) malloc(header.frag_length); @@ -567,7 +567,7 @@ int rts_send_flow_control_ack_pdu(rdpRpc* rpc) /* FlowControlAck Command (28 bytes) */ rts_flow_control_ack_command_write(&buffer[28], BytesReceived, AvailableWindow, ChannelCookie); length = header.frag_length; - status = rpc_in_channel_write(inChannel, buffer, length); + status = rpc_channel_write(&inChannel->common, buffer, length); free(buffer); return (status > 0) ? 1 : -1; } @@ -642,7 +642,7 @@ static int rts_send_ping_pdu(rdpRpc* rpc) CopyMemory(buffer, ((BYTE*) &header), 20); /* RTS Header (20 bytes) */ length = header.frag_length; - status = rpc_in_channel_write(inChannel, buffer, length); + status = rpc_channel_write(&inChannel->common, buffer, length); free(buffer); return (status > 0) ? 1 : -1; } @@ -735,7 +735,7 @@ static int rts_send_OUT_R2_A7_pdu(rdpRpc* rpc) header.Flags = RTS_FLAG_OUT_CHANNEL; header.NumberOfCommands = 3; WLog_DBG(TAG, "Sending OUT_R2/A7 RTS PDU"); - SuccessorChannelCookie = (BYTE*) & (nextOutChannel->Cookie); + SuccessorChannelCookie = (BYTE*) & (nextOutChannel->common.Cookie); buffer = (BYTE*) malloc(header.frag_length); if (!buffer) @@ -746,7 +746,7 @@ static int rts_send_OUT_R2_A7_pdu(rdpRpc* rpc) rts_cookie_command_write(&buffer[28], SuccessorChannelCookie); /* SuccessorChannelCookie (20 bytes) */ rts_version_command_write(&buffer[48]); /* Version (8 bytes) */ - status = rpc_in_channel_write(inChannel, buffer, header.frag_length); + status = rpc_channel_write(&inChannel->common, buffer, header.frag_length); free(buffer); return (status > 0) ? 1 : -1; } @@ -769,7 +769,7 @@ static int rts_send_OUT_R2_C1_pdu(rdpRpc* rpc) CopyMemory(buffer, ((BYTE*) &header), 20); /* RTS Header (20 bytes) */ rts_empty_command_write(&buffer[20]); /* Empty command (4 bytes) */ - status = rpc_out_channel_write(nextOutChannel, buffer, header.frag_length); + status = rpc_channel_write(&nextOutChannel->common, buffer, header.frag_length); free(buffer); return (status > 0) ? 1 : -1; } @@ -792,8 +792,8 @@ int rts_send_OUT_R1_A3_pdu(rdpRpc* rpc) header.NumberOfCommands = 5; WLog_DBG(TAG, "Sending OUT_R1/A3 RTS PDU"); VirtualConnectionCookie = (BYTE*) & (connection->Cookie); - PredecessorChannelCookie = (BYTE*) & (outChannel->Cookie); - SuccessorChannelCookie = (BYTE*) & (nextOutChannel->Cookie); + PredecessorChannelCookie = (BYTE*) & (outChannel->common.Cookie); + SuccessorChannelCookie = (BYTE*) & (nextOutChannel->common.Cookie); ReceiveWindowSize = outChannel->ReceiveWindow; buffer = (BYTE*) malloc(header.frag_length); @@ -810,7 +810,7 @@ int rts_send_OUT_R1_A3_pdu(rdpRpc* rpc) SuccessorChannelCookie); /* SuccessorChannelCookie (20 bytes) */ rts_receive_window_size_command_write(&buffer[88], ReceiveWindowSize); /* ReceiveWindowSize (8 bytes) */ - status = rpc_out_channel_write(nextOutChannel, buffer, header.frag_length); + status = rpc_channel_write(&nextOutChannel->common, buffer, header.frag_length); free(buffer); return (status > 0) ? 1 : -1; } diff --git a/libfreerdp/core/gateway/tsg.c b/libfreerdp/core/gateway/tsg.c index e4e5599e7..a3000c7ae 100644 --- a/libfreerdp/core/gateway/tsg.c +++ b/libfreerdp/core/gateway/tsg.c @@ -1595,44 +1595,44 @@ DWORD tsg_get_event_handles(rdpTsg* tsg, HANDLE* events, DWORD count) else return 0; - if (connection->DefaultInChannel && connection->DefaultInChannel->tls) + if (connection->DefaultInChannel && connection->DefaultInChannel->common.tls) { if (events && (nCount < count)) { - BIO_get_event(connection->DefaultInChannel->tls->bio, &events[nCount]); + BIO_get_event(connection->DefaultInChannel->common.tls->bio, &events[nCount]); nCount++; } else return 0; } - if (connection->NonDefaultInChannel && connection->NonDefaultInChannel->tls) + if (connection->NonDefaultInChannel && connection->NonDefaultInChannel->common.tls) { if (events && (nCount < count)) { - BIO_get_event(connection->NonDefaultInChannel->tls->bio, &events[nCount]); + BIO_get_event(connection->NonDefaultInChannel->common.tls->bio, &events[nCount]); nCount++; } else return 0; } - if (connection->DefaultOutChannel && connection->DefaultOutChannel->tls) + if (connection->DefaultOutChannel && connection->DefaultOutChannel->common.tls) { if (events && (nCount < count)) { - BIO_get_event(connection->DefaultOutChannel->tls->bio, &events[nCount]); + BIO_get_event(connection->DefaultOutChannel->common.tls->bio, &events[nCount]); nCount++; } else return 0; } - if (connection->NonDefaultOutChannel && connection->NonDefaultOutChannel->tls) + if (connection->NonDefaultOutChannel && connection->NonDefaultOutChannel->common.tls) { if (events && (nCount < count)) { - BIO_get_event(connection->NonDefaultOutChannel->tls->bio, &events[nCount]); + BIO_get_event(connection->NonDefaultOutChannel->common.tls->bio, &events[nCount]); nCount++; } else @@ -1928,8 +1928,8 @@ static long transport_bio_tsg_ctrl(BIO* bio, int cmd, long arg1, void* arg2) if (cmd == BIO_CTRL_FLUSH) { - (void)BIO_flush(inChannel->tls->bio); - (void)BIO_flush(outChannel->tls->bio); + (void)BIO_flush(inChannel->common.tls->bio); + (void)BIO_flush(outChannel->common.tls->bio); status = 1; } else if (cmd == BIO_C_GET_EVENT) @@ -1946,18 +1946,18 @@ static long transport_bio_tsg_ctrl(BIO* bio, int cmd, long arg1, void* arg2) } else if (cmd == BIO_C_READ_BLOCKED) { - BIO* bio = outChannel->bio; + BIO* bio = outChannel->common.bio; status = BIO_read_blocked(bio); } else if (cmd == BIO_C_WRITE_BLOCKED) { - BIO* bio = inChannel->bio; + BIO* bio = inChannel->common.bio; status = BIO_write_blocked(bio); } else if (cmd == BIO_C_WAIT_READ) { int timeout = (int) arg1; - BIO* bio = outChannel->bio; + BIO* bio = outChannel->common.bio; if (BIO_read_blocked(bio)) return BIO_wait_read(bio, timeout); @@ -1969,7 +1969,7 @@ static long transport_bio_tsg_ctrl(BIO* bio, int cmd, long arg1, void* arg2) else if (cmd == BIO_C_WAIT_WRITE) { int timeout = (int) arg1; - BIO* bio = inChannel->bio; + BIO* bio = inChannel->common.bio; if (BIO_write_blocked(bio)) status = BIO_wait_write(bio, timeout); From 47ba37fbcb1d67ae9ce917bd54959fead2096181 Mon Sep 17 00:00:00 2001 From: Armin Novak Date: Thu, 27 Sep 2018 16:08:28 +0200 Subject: [PATCH 04/13] Unified dns resolving of host --- libfreerdp/core/connection.c | 18 ++++------- libfreerdp/core/listener.c | 21 ++++--------- libfreerdp/core/tcp.c | 58 ++++++++++++++++++++---------------- libfreerdp/core/tcp.h | 4 +++ 4 files changed, 47 insertions(+), 54 deletions(-) diff --git a/libfreerdp/core/connection.c b/libfreerdp/core/connection.c index 3f8791686..9248742df 100644 --- a/libfreerdp/core/connection.c +++ b/libfreerdp/core/connection.c @@ -409,21 +409,13 @@ static BOOL rdp_client_reconnect_channels(rdpRdp* rdp, BOOL redirect) static BOOL rdp_client_redirect_resolvable(const char* host) { - int status; - struct addrinfo hints = { 0 }; - struct addrinfo* result = NULL; - hints.ai_family = AF_UNSPEC; - hints.ai_socktype = SOCK_STREAM; - status = getaddrinfo(host, NULL, &hints, &result); - errno = 0; + struct addrinfo* result = freerdp_tcp_resolve_host(host, -1, 0); - if (status == 0) - { - freeaddrinfo(result); - return TRUE; - } + if (!result) + return FALSE; - return FALSE; + freeaddrinfo(result); + return TRUE; } static BOOL rdp_client_redirect_try_fqdn(rdpSettings* settings) diff --git a/libfreerdp/core/listener.c b/libfreerdp/core/listener.c index e0ba23973..2288236a7 100644 --- a/libfreerdp/core/listener.c +++ b/libfreerdp/core/listener.c @@ -51,37 +51,26 @@ static BOOL freerdp_listener_open(freerdp_listener* instance, const char* bind_address, UINT16 port) { + int ai_flags = 0; int status; int sockfd; char addr[64]; void* sin_addr; int option_value; - char servname[16]; struct addrinfo* ai; struct addrinfo* res; - struct addrinfo hints = { 0 }; rdpListener* listener = (rdpListener*) instance->listener; #ifdef _WIN32 u_long arg; #endif - hints.ai_family = AF_UNSPEC; - hints.ai_socktype = SOCK_STREAM; if (!bind_address) - hints.ai_flags = AI_PASSIVE; + ai_flags = AI_PASSIVE; - sprintf_s(servname, sizeof(servname), "%"PRIu16"", port); - status = getaddrinfo(bind_address, servname, &hints, &res); + res = freerdp_tcp_resolve_host(bind_address, port, ai_flags); - if (status != 0) - { -#ifdef _WIN32 - WLog_ERR("getaddrinfo error: %s", gai_strerrorA(status)); -#else - WLog_ERR(TAG, "getaddrinfo"); -#endif + if (!res) return FALSE; - } for (ai = res; ai && (listener->num_sockfds < 5); ai = ai->ai_next) { @@ -148,7 +137,7 @@ static BOOL freerdp_listener_open(freerdp_listener* instance, const char* bind_a WSAEventSelect(sockfd, listener->events[listener->num_sockfds], FD_READ | FD_ACCEPT | FD_CLOSE); listener->num_sockfds++; - WLog_INFO(TAG, "Listening on %s:%s", addr, servname); + WLog_INFO(TAG, "Listening on %s:%d", addr, port); } freeaddrinfo(res); diff --git a/libfreerdp/core/tcp.c b/libfreerdp/core/tcp.c index 7c655cdc7..b425bdb81 100644 --- a/libfreerdp/core/tcp.c +++ b/libfreerdp/core/tcp.c @@ -664,7 +664,7 @@ BIO_METHOD* BIO_s_buffered_socket(void) return bio_methods; } -static char* freerdp_tcp_address_to_string(const struct sockaddr_storage* addr, BOOL* pIPv6) +char* freerdp_tcp_address_to_string(const struct sockaddr_storage* addr, BOOL* pIPv6) { char ipAddress[INET6_ADDRSTRLEN + 1] = { 0 }; struct sockaddr_in6* sockaddr_ipv6 = (struct sockaddr_in6*)addr; @@ -762,16 +762,39 @@ static int freerdp_uds_connect(const char* path) #endif } -static BOOL freerdp_tcp_resolve_hostname(rdpContext* context, const char* hostname) +struct addrinfo* freerdp_tcp_resolve_host(const char* hostname, int port, int ai_flags) { + char* service = NULL; + char port_str[16]; int status; struct addrinfo hints = { 0 }; struct addrinfo* result = NULL; hints.ai_family = AF_UNSPEC; hints.ai_socktype = SOCK_STREAM; - status = getaddrinfo(hostname, NULL, &hints, &result); + hints.ai_flags = ai_flags; + + if (port >= 0) + { + sprintf_s(port_str, sizeof(port_str) - 1, "%d", port); + service = port_str; + } + + status = getaddrinfo(hostname, service, &hints, &result); if (status) + { + freeaddrinfo(result); + return NULL; + } + + return result; +} + +static BOOL freerdp_tcp_is_hostname_resolvable(rdpContext* context, const char* hostname) +{ + struct addrinfo* result = freerdp_tcp_resolve_host(hostname, -1, 0); + + if (!result) { if (!freerdp_get_last_error(context)) freerdp_set_last_error(context, FREERDP_ERROR_DNS_NAME_NOT_FOUND); @@ -870,13 +893,10 @@ static int freerdp_tcp_connect_multi(rdpContext* context, char** hostnames, SOCKET* sockfds; HANDLE* events; DWORD waitStatus; - char port_str[16]; - struct addrinfo hints; struct addrinfo* addr; struct addrinfo* result; struct addrinfo** addrs; struct addrinfo** results; - sprintf_s(port_str, sizeof(port_str) - 1, "%d", port); sockfds = (SOCKET*) calloc(count, sizeof(SOCKET)); events = (HANDLE*) calloc(count + 1, sizeof(HANDLE)); addrs = (struct addrinfo**) calloc(count, sizeof(struct addrinfo*)); @@ -893,19 +913,15 @@ static int freerdp_tcp_connect_multi(rdpContext* context, char** hostnames, for (index = 0; index < count; index++) { - ZeroMemory(&hints, sizeof(hints)); - hints.ai_family = AF_UNSPEC; - hints.ai_socktype = SOCK_STREAM; + int port = -1; if (ports) - sprintf_s(port_str, sizeof(port_str) - 1, "%"PRIu32"", ports[index]); + port = ports[index]; - status = getaddrinfo(hostnames[index], port_str, &hints, &result); + result = freerdp_tcp_resolve_host(hostnames[index], port, 0); - if (status) - { + if (!result) continue; - } addr = result; @@ -1089,7 +1105,6 @@ BOOL freerdp_tcp_set_keep_alive_mode(int sockfd) int freerdp_tcp_connect(rdpContext* context, rdpSettings* settings, const char* hostname, int port, int timeout) { - int status; int sockfd; UINT32 optval; socklen_t optlen; @@ -1130,7 +1145,7 @@ int freerdp_tcp_connect(rdpContext* context, rdpSettings* settings, if (!settings->GatewayEnabled) { - if (!freerdp_tcp_resolve_hostname(context, hostname) || settings->RemoteAssistanceMode) + if (!freerdp_tcp_is_hostname_resolvable(context, hostname) || settings->RemoteAssistanceMode) { if (settings->TargetNetAddressCount > 0) { @@ -1146,23 +1161,16 @@ int freerdp_tcp_connect(rdpContext* context, rdpSettings* settings, if (sockfd <= 0) { - char port_str[16]; char* peerAddress; - struct addrinfo hints; struct addrinfo* addr; struct addrinfo* result; - ZeroMemory(&hints, sizeof(hints)); - hints.ai_family = AF_UNSPEC; - hints.ai_socktype = SOCK_STREAM; - sprintf_s(port_str, sizeof(port_str) - 1, "%d", port); - status = getaddrinfo(hostname, port_str, &hints, &result); + result = freerdp_tcp_resolve_host(hostname, port, 0); - if (status) + if (!result) { if (!freerdp_get_last_error(context)) freerdp_set_last_error(context, FREERDP_ERROR_DNS_NAME_NOT_FOUND); - WLog_ERR(TAG, "getaddrinfo: %s", gai_strerror(status)); return -1; } diff --git a/libfreerdp/core/tcp.h b/libfreerdp/core/tcp.h index ad14255b7..7a3416650 100644 --- a/libfreerdp/core/tcp.h +++ b/libfreerdp/core/tcp.h @@ -69,4 +69,8 @@ FREERDP_LOCAL int freerdp_tcp_connect(rdpContext* context, FREERDP_LOCAL char* freerdp_tcp_get_peer_address(SOCKET sockfd); +FREERDP_LOCAL struct addrinfo* freerdp_tcp_resolve_host(const char* hostname, int port, + int ai_flags); +FREERDP_LOCAL char* freerdp_tcp_address_to_string(const struct sockaddr_storage* addr, BOOL* pIPv6); + #endif /* FREERDP_LIB_CORE_TCP_H */ From f5f155b057931f2c577270a291a8c370f1785f30 Mon Sep 17 00:00:00 2001 From: Armin Novak Date: Thu, 27 Sep 2018 16:23:01 +0200 Subject: [PATCH 05/13] Refactored RpcClient functions --- libfreerdp/core/gateway/ncacn_http.c | 1 - libfreerdp/core/gateway/ntlm.c | 2 +- libfreerdp/core/gateway/ntlm.h | 1 - libfreerdp/core/gateway/rpc.c | 27 +++-------- libfreerdp/core/gateway/rpc_client.c | 72 +++++++++++++++------------- libfreerdp/core/gateway/rpc_client.h | 8 ++-- libfreerdp/core/gateway/tsg.c | 6 +-- 7 files changed, 55 insertions(+), 62 deletions(-) diff --git a/libfreerdp/core/gateway/ncacn_http.c b/libfreerdp/core/gateway/ncacn_http.c index c9d831269..2b9ef8535 100644 --- a/libfreerdp/core/gateway/ncacn_http.c +++ b/libfreerdp/core/gateway/ncacn_http.c @@ -200,7 +200,6 @@ void rpc_ncacn_http_ntlm_uninit(RpcChannel* channel) if (!channel) return; - ntlm_client_uninit(channel->ntlm); ntlm_free(channel->ntlm); channel->ntlm = NULL; } diff --git a/libfreerdp/core/gateway/ntlm.c b/libfreerdp/core/gateway/ntlm.c index 6def662a9..99306db3b 100644 --- a/libfreerdp/core/gateway/ntlm.c +++ b/libfreerdp/core/gateway/ntlm.c @@ -272,7 +272,7 @@ BOOL ntlm_authenticate(rdpNtlm* ntlm) return (status == SEC_I_CONTINUE_NEEDED) ? TRUE : FALSE; } -void ntlm_client_uninit(rdpNtlm* ntlm) +static void ntlm_client_uninit(rdpNtlm* ntlm) { free(ntlm->identity.User); ntlm->identity.User = NULL; diff --git a/libfreerdp/core/gateway/ntlm.h b/libfreerdp/core/gateway/ntlm.h index f384378c0..29bf53b9d 100644 --- a/libfreerdp/core/gateway/ntlm.h +++ b/libfreerdp/core/gateway/ntlm.h @@ -70,7 +70,6 @@ FREERDP_LOCAL BOOL ntlm_authenticate(rdpNtlm* ntlm); FREERDP_LOCAL BOOL ntlm_client_init(rdpNtlm* ntlm, BOOL confidentiality, char* user, char* domain, char* password, SecPkgContext_Bindings* Bindings); -FREERDP_LOCAL void ntlm_client_uninit(rdpNtlm* ntlm); FREERDP_LOCAL BOOL ntlm_client_make_spn(rdpNtlm* ntlm, LPCTSTR ServiceClass, char* hostname); diff --git a/libfreerdp/core/gateway/rpc.c b/libfreerdp/core/gateway/rpc.c index 93fd2465e..dd2e236ba 100644 --- a/libfreerdp/core/gateway/rpc.c +++ b/libfreerdp/core/gateway/rpc.c @@ -855,15 +855,14 @@ rdpRpc* rpc_new(rdpTransport* transport) rpc->CurrentKeepAliveInterval = rpc->KeepAliveInterval; rpc->CurrentKeepAliveTime = 0; rpc->CallId = 2; + rpc->client = rpc_client_new(rpc->settings, rpc->max_recv_frag); - if (rpc_client_new(rpc) < 0) - goto out_free_rpc_client; + if (!rpc->client) + goto out_free; return rpc; -out_free_rpc_client: - rpc_client_free(rpc); out_free: - free(rpc); + rpc_free(rpc); return NULL; } @@ -871,21 +870,9 @@ void rpc_free(rdpRpc* rpc) { if (rpc) { - rpc_client_free(rpc); - - if (rpc->ntlm) - { - ntlm_client_uninit(rpc->ntlm); - ntlm_free(rpc->ntlm); - rpc->ntlm = NULL; - } - - if (rpc->VirtualConnection) - { - rpc_virtual_connection_free(rpc->VirtualConnection); - rpc->VirtualConnection = NULL; - } - + rpc_client_free(rpc->client); + ntlm_free(rpc->ntlm); + rpc_virtual_connection_free(rpc->VirtualConnection); free(rpc); } } diff --git a/libfreerdp/core/gateway/rpc_client.c b/libfreerdp/core/gateway/rpc_client.c index 3aa6d169f..f4b03b630 100644 --- a/libfreerdp/core/gateway/rpc_client.c +++ b/libfreerdp/core/gateway/rpc_client.c @@ -76,11 +76,14 @@ static void rpc_pdu_free(RPC_PDU* pdu) free(pdu); } -static int rpc_client_receive_pipe_write(rdpRpc* rpc, const BYTE* buffer, size_t length) +static int rpc_client_receive_pipe_write(RpcClient* client, const BYTE* buffer, size_t length) { int status = 0; - RpcClient* client = rpc->client; - EnterCriticalSection(&(rpc->client->PipeLock)); + + if (!client || !buffer) + return -1; + + EnterCriticalSection(&(client->PipeLock)); if (ringbuffer_write(&(client->ReceivePipe), buffer, length)) status += (int) length; @@ -88,17 +91,20 @@ static int rpc_client_receive_pipe_write(rdpRpc* rpc, const BYTE* buffer, size_t if (ringbuffer_used(&(client->ReceivePipe)) > 0) SetEvent(client->PipeEvent); - LeaveCriticalSection(&(rpc->client->PipeLock)); + LeaveCriticalSection(&(client->PipeLock)); return status; } -int rpc_client_receive_pipe_read(rdpRpc* rpc, BYTE* buffer, size_t length) +int rpc_client_receive_pipe_read(RpcClient* client, BYTE* buffer, size_t length) { int index = 0; int status = 0; int nchunks = 0; DataChunk chunks[2]; - RpcClient* client = rpc->client; + + if (!client || !buffer) + return -1; + EnterCriticalSection(&(client->PipeLock)); nchunks = ringbuffer_peek(&(client->ReceivePipe), chunks, length); @@ -347,7 +353,7 @@ static int rpc_client_recv_fragment(rdpRpc* rpc, wStream* fragment) rpc->StubCallId, header->common.call_id, rpc->StubFragCount); } - call = rpc_client_call_find_by_id(rpc, rpc->StubCallId); + call = rpc_client_call_find_by_id(rpc->client, rpc->StubCallId); if (!call) return -1; @@ -374,7 +380,7 @@ static int rpc_client_recv_fragment(rdpRpc* rpc, wStream* fragment) } else { - rpc_client_receive_pipe_write(rpc, &buffer[StubOffset], (size_t) StubLength); + rpc_client_receive_pipe_write(rpc->client, &buffer[StubOffset], (size_t) StubLength); rpc->StubFragCount++; if (header->response.alloc_hint == StubLength) @@ -785,23 +791,27 @@ int rpc_client_in_channel_recv(rdpRpc* rpc) * http://msdn.microsoft.com/en-us/library/gg593159/ */ -RpcClientCall* rpc_client_call_find_by_id(rdpRpc* rpc, UINT32 CallId) +RpcClientCall* rpc_client_call_find_by_id(RpcClient* client, UINT32 CallId) { int index; int count; RpcClientCall* clientCall = NULL; - ArrayList_Lock(rpc->client->ClientCallList); - count = ArrayList_Count(rpc->client->ClientCallList); + + if (!client) + return NULL; + + ArrayList_Lock(client->ClientCallList); + count = ArrayList_Count(client->ClientCallList); for (index = 0; index < count; index++) { - clientCall = (RpcClientCall*) ArrayList_GetItem(rpc->client->ClientCallList, index); + clientCall = (RpcClientCall*) ArrayList_GetItem(client->ClientCallList, index); if (clientCall->CallId == CallId) break; } - ArrayList_Unlock(rpc->client->ClientCallList); + ArrayList_Unlock(client->ClientCallList); return clientCall; } @@ -836,7 +846,7 @@ int rpc_in_channel_send_pdu(RpcInChannel* inChannel, BYTE* buffer, UINT32 length return -1; header = (rpcconn_common_hdr_t*) buffer; - clientCall = rpc_client_call_find_by_id(rpc, header->call_id); + clientCall = rpc_client_call_find_by_id(rpc->client, header->call_id); clientCall->State = RPC_CLIENT_CALL_STATE_DISPATCHED; /* @@ -976,49 +986,48 @@ out_free_pdu: return -1; } -int rpc_client_new(rdpRpc* rpc) +RpcClient* rpc_client_new(rdpSettings* settings, UINT32 max_recv_frag) { - RpcClient* client; - client = (RpcClient*) calloc(1, sizeof(RpcClient)); - rpc->client = client; + RpcClient* client = (RpcClient*) calloc(1, sizeof(RpcClient)); - if (!client) - return -1; + if (!client || !settings) + goto fail; client->pdu = rpc_pdu_new(); if (!client->pdu) - return -1; + goto fail; - client->ReceiveFragment = Stream_New(NULL, rpc->max_recv_frag); + client->ReceiveFragment = Stream_New(NULL, max_recv_frag); if (!client->ReceiveFragment) - return -1; + goto fail; client->PipeEvent = CreateEvent(NULL, TRUE, FALSE, NULL); if (!client->PipeEvent) - return -1; + goto fail; if (!ringbuffer_init(&(client->ReceivePipe), 4096)) - return -1; + goto fail; if (!InitializeCriticalSectionAndSpinCount(&(client->PipeLock), 4000)) - return -1; + goto fail; client->ClientCallList = ArrayList_New(TRUE); if (!client->ClientCallList) - return -1; + goto fail; ArrayList_Object(client->ClientCallList)->fnObjectFree = (OBJECT_FREE_FN) rpc_client_call_free; - return 1; + return client; +fail: + rpc_client_free(client); + return NULL; } -void rpc_client_free(rdpRpc* rpc) +void rpc_client_free(RpcClient* client) { - RpcClient* client = rpc->client; - if (!client) return; @@ -1038,5 +1047,4 @@ void rpc_client_free(rdpRpc* rpc) ArrayList_Free(client->ClientCallList); free(client); - rpc->client = NULL; } diff --git a/libfreerdp/core/gateway/rpc_client.h b/libfreerdp/core/gateway/rpc_client.h index a3c31fc4c..24a6480b4 100644 --- a/libfreerdp/core/gateway/rpc_client.h +++ b/libfreerdp/core/gateway/rpc_client.h @@ -24,7 +24,7 @@ #include "rpc.h" -FREERDP_LOCAL RpcClientCall* rpc_client_call_find_by_id(rdpRpc* rpc, +FREERDP_LOCAL RpcClientCall* rpc_client_call_find_by_id(RpcClient* client, UINT32 CallId); FREERDP_LOCAL RpcClientCall* rpc_client_call_new(UINT32 CallId, UINT32 OpNum); @@ -36,13 +36,13 @@ FREERDP_LOCAL int rpc_in_channel_send_pdu(RpcInChannel* inChannel, BYTE* buffer, FREERDP_LOCAL int rpc_client_in_channel_recv(rdpRpc* rpc); FREERDP_LOCAL int rpc_client_out_channel_recv(rdpRpc* rpc); -FREERDP_LOCAL int rpc_client_receive_pipe_read(rdpRpc* rpc, BYTE* buffer, +FREERDP_LOCAL int rpc_client_receive_pipe_read(RpcClient* client, BYTE* buffer, size_t length); FREERDP_LOCAL int rpc_client_write_call(rdpRpc* rpc, BYTE* data, int length, UINT16 opnum); -FREERDP_LOCAL int rpc_client_new(rdpRpc* rpc); -FREERDP_LOCAL void rpc_client_free(rdpRpc* rpc); +FREERDP_LOCAL RpcClient* rpc_client_new(rdpSettings* settings, UINT32 max_recv_frag); +FREERDP_LOCAL void rpc_client_free(RpcClient* client); #endif /* FREERDP_LIB_CORE_GATEWAY_RPC_CLIENT_H */ diff --git a/libfreerdp/core/gateway/tsg.c b/libfreerdp/core/gateway/tsg.c index a3000c7ae..1c7aeccfd 100644 --- a/libfreerdp/core/gateway/tsg.c +++ b/libfreerdp/core/gateway/tsg.c @@ -1399,7 +1399,7 @@ int tsg_recv_pdu(rdpTsg* tsg, RPC_PDU* pdu) break; case TSG_STATE_AUTHORIZED: - call = rpc_client_call_find_by_id(rpc, pdu->CallId); + call = rpc_client_call_find_by_id(rpc->client, pdu->CallId); if (!call) return -1; @@ -1469,7 +1469,7 @@ int tsg_recv_pdu(rdpTsg* tsg, RPC_PDU* pdu) break; case TSG_STATE_PIPE_CREATED: - call = rpc_client_call_find_by_id(rpc, pdu->CallId); + call = rpc_client_call_find_by_id(rpc->client, pdu->CallId); if (!call) return -1; @@ -1769,7 +1769,7 @@ static int tsg_read(rdpTsg* tsg, BYTE* data, UINT32 length) do { - status = rpc_client_receive_pipe_read(rpc, data, (size_t) length); + status = rpc_client_receive_pipe_read(rpc->client, data, (size_t) length); if (status < 0) return -1; From 7ab1251a674fc7664c7ddb0fd8d487ad084850c9 Mon Sep 17 00:00:00 2001 From: Armin Novak Date: Thu, 27 Sep 2018 16:42:27 +0200 Subject: [PATCH 06/13] Refactored rpc_client and resolve gateway only once. --- libfreerdp/core/gateway/rpc.c | 73 ++++++++++++++++++---------- libfreerdp/core/gateway/rpc.h | 28 ++++++----- libfreerdp/core/gateway/rpc_client.c | 45 +++++++++++++++-- libfreerdp/core/gateway/rpc_client.h | 3 +- 4 files changed, 104 insertions(+), 45 deletions(-) diff --git a/libfreerdp/core/gateway/rpc.c b/libfreerdp/core/gateway/rpc.c index dd2e236ba..9106023e5 100644 --- a/libfreerdp/core/gateway/rpc.c +++ b/libfreerdp/core/gateway/rpc.c @@ -405,16 +405,18 @@ BOOL rpc_in_channel_transition_to_state(RpcInChannel* inChannel, CLIENT_IN_CHANN return TRUE; } -static int rpc_channel_rpch_init(rdpRpc* rpc, RpcChannel* channel, const char* inout) +static int rpc_channel_rpch_init(RpcClient* client, RpcChannel* channel, const char* inout) { HttpContext* http; + rdpSettings* settings; - if (!rpc || !channel || !inout) + if (!client || !channel || !inout || !client->context || !client->context->settings) return -1; + settings = client->context->settings; channel->ntlm = ntlm_new(); rts_generate_cookie((BYTE*) &channel->Cookie); - channel->rpc = rpc; + channel->client = client; if (!channel->ntlm) return -1; @@ -431,7 +433,7 @@ static int rpc_channel_rpch_init(rdpRpc* rpc, RpcChannel* channel, const char* i !http_context_set_accept(http, "application/rpc") || !http_context_set_cache_control(http, "no-cache") || !http_context_set_connection(http, "Keep-Alive") || !http_context_set_user_agent(http, "MSRPC") || - !http_context_set_host(http, rpc->settings->GatewayHostname) || + !http_context_set_host(http, settings->GatewayHostname) || !http_context_set_pragma(http, "ResourceTypeUuid=44e265dd-7daf-42cd-8560-3cdb6e7a2729, " "SessionId=fbd9c34f-397d-471d-a109-1b08cc554624")) @@ -448,7 +450,7 @@ static int rpc_in_channel_init(rdpRpc* rpc, RpcInChannel* inChannel) inChannel->PingOriginator.ConnectionTimeout = 30; inChannel->PingOriginator.KeepAliveInterval = 0; - if (rpc_channel_rpch_init(rpc, &inChannel->common, "RPC_IN_DATA") < 0) + if (rpc_channel_rpch_init(rpc->client, &inChannel->common, "RPC_IN_DATA") < 0) return -1; return 1; @@ -542,7 +544,7 @@ static int rpc_out_channel_init(rdpRpc* rpc, RpcOutChannel* outChannel) outChannel->ReceiveWindowSize = rpc->ReceiveWindow; outChannel->AvailableWindowAdvertised = rpc->ReceiveWindow; - if (rpc_channel_rpch_init(rpc, &outChannel->common, "RPC_OUT_DATA") < 0) + if (rpc_channel_rpch_init(rpc->client, &outChannel->common, "RPC_OUT_DATA") < 0) return -1; return 1; @@ -649,20 +651,26 @@ static int rpc_channel_tls_connect(RpcChannel* channel, int timeout) int tlsStatus; BIO* socketBio; BIO* bufferedBio; - rdpRpc* rpc = channel->rpc; - rdpContext* context = rpc->context; - rdpSettings* settings = context->settings; - const char* peerHostname = settings->GatewayHostname; - UINT16 peerPort = settings->GatewayPort; - const char* proxyUsername = settings->ProxyUsername, *proxyPassword = settings->ProxyPassword; - BOOL isProxyConnection = proxy_prepare(settings, &peerHostname, &peerPort, &proxyUsername, - &proxyPassword); - sockfd = freerdp_tcp_connect(context, settings, peerHostname, - peerPort, timeout); + rdpContext* context; + rdpSettings* settings; + const char* proxyUsername; + const char* proxyPassword; - if (sockfd < 0) - return -1; + if (!channel || !channel->client || !channel->client->context || + !channel->client->context->settings) + return FALSE; + context = channel->client->context; + settings = context->settings; + proxyUsername = settings->ProxyUsername; + proxyPassword = settings->ProxyPassword; + { + sockfd = freerdp_tcp_connect(context, settings, channel->client->host, + channel->client->port, timeout); + + if (sockfd < 0) + return FALSE; + } socketBio = BIO_new(BIO_s_simple_socket()); if (!socketBio) @@ -679,7 +687,7 @@ static int rpc_channel_tls_connect(RpcChannel* channel, int timeout) if (!BIO_set_nonblock(bufferedBio, TRUE)) return -1; - if (isProxyConnection) + if (channel->client->isProxy) { if (!proxy_connect(settings, bufferedBio, proxyUsername, proxyPassword, settings->GatewayHostname, settings->GatewayPort)) @@ -718,7 +726,12 @@ static int rpc_channel_tls_connect(RpcChannel* channel, int timeout) static int rpc_in_channel_connect(RpcInChannel* inChannel, int timeout) { - rdpRpc* rpc = inChannel->common.rpc; + rdpContext* context; + + if (!inChannel || !inChannel->common.client || !inChannel->common.client->context) + return -1; + + context = inChannel->common.client->context; /* Connect IN Channel */ @@ -727,7 +740,7 @@ static int rpc_in_channel_connect(RpcInChannel* inChannel, int timeout) rpc_in_channel_transition_to_state(inChannel, CLIENT_IN_CHANNEL_STATE_CONNECTED); - if (!rpc_ncacn_http_ntlm_init(rpc->context, &inChannel->common)) + if (!rpc_ncacn_http_ntlm_init(context, &inChannel->common)) return -1; /* Send IN Channel Request */ @@ -746,7 +759,10 @@ static int rpc_in_channel_connect(RpcInChannel* inChannel, int timeout) static int rpc_out_channel_connect(RpcOutChannel* outChannel, int timeout) { - rdpRpc* rpc = outChannel->common.rpc; + rdpContext* context; + + if (!outChannel || !outChannel->common.client || !outChannel->common.client->context) + return -1; /* Connect OUT Channel */ @@ -755,7 +771,7 @@ static int rpc_out_channel_connect(RpcOutChannel* outChannel, int timeout) rpc_out_channel_transition_to_state(outChannel, CLIENT_OUT_CHANNEL_STATE_CONNECTED); - if (!rpc_ncacn_http_ntlm_init(rpc->context, &outChannel->common)) + if (!rpc_ncacn_http_ntlm_init(context, &outChannel->common)) return FALSE; /* Send OUT Channel Request */ @@ -772,16 +788,19 @@ static int rpc_out_channel_connect(RpcOutChannel* outChannel, int timeout) int rpc_out_channel_replacement_connect(RpcOutChannel* outChannel, int timeout) { - rdpRpc* rpc = outChannel->common.rpc; + rdpContext* context; + + if (!outChannel || !outChannel->common.client || !outChannel->common.client->context) + return -1; /* Connect OUT Channel */ - if (rpc_channel_tls_connect((RpcChannel*) outChannel, timeout) < 0) + if (rpc_channel_tls_connect(&outChannel->common, timeout) < 0) return -1; rpc_out_channel_transition_to_state(outChannel, CLIENT_OUT_CHANNEL_STATE_CONNECTED); - if (!rpc_ncacn_http_ntlm_init(rpc->context, (RpcChannel*) outChannel)) + if (!rpc_ncacn_http_ntlm_init(context, (RpcChannel*) outChannel)) return FALSE; /* Send OUT Channel Request */ @@ -855,7 +874,7 @@ rdpRpc* rpc_new(rdpTransport* transport) rpc->CurrentKeepAliveInterval = rpc->KeepAliveInterval; rpc->CurrentKeepAliveTime = 0; rpc->CallId = 2; - rpc->client = rpc_client_new(rpc->settings, rpc->max_recv_frag); + rpc->client = rpc_client_new(rpc->context, rpc->max_recv_frag); if (!rpc->client) goto out_free; diff --git a/libfreerdp/core/gateway/rpc.h b/libfreerdp/core/gateway/rpc.h index 4aebd178b..49e8c0ded 100644 --- a/libfreerdp/core/gateway/rpc.h +++ b/libfreerdp/core/gateway/rpc.h @@ -585,9 +585,24 @@ struct rpc_client_call }; typedef struct rpc_client_call RpcClientCall; +struct rpc_client +{ + rdpContext* context; + RPC_PDU* pdu; + HANDLE PipeEvent; + RingBuffer ReceivePipe; + wStream* ReceiveFragment; + CRITICAL_SECTION PipeLock; + wArrayList* ClientCallList; + char* host; + UINT16 port; + BOOL isProxy; +}; +typedef struct rpc_client RpcClient; + struct rpc_channel { - rdpRpc* rpc; + RpcClient* client; BIO* bio; rdpTls* tls; rdpNtlm* ntlm; @@ -714,17 +729,6 @@ struct rpc_virtual_connection_cookie_entry typedef struct rpc_virtual_connection_cookie_entry RpcVirtualConnectionCookieEntry; -struct rpc_client -{ - RPC_PDU* pdu; - HANDLE PipeEvent; - RingBuffer ReceivePipe; - wStream* ReceiveFragment; - CRITICAL_SECTION PipeLock; - wArrayList* ClientCallList; -}; -typedef struct rpc_client RpcClient; - struct rdp_rpc { RPC_CLIENT_STATE State; diff --git a/libfreerdp/core/gateway/rpc_client.c b/libfreerdp/core/gateway/rpc_client.c index f4b03b630..fec8cc6a0 100644 --- a/libfreerdp/core/gateway/rpc_client.c +++ b/libfreerdp/core/gateway/rpc_client.c @@ -36,6 +36,7 @@ #include "rpc_fault.h" #include "rpc_client.h" #include "../rdp.h" +#include "../proxy.h" #define TAG FREERDP_TAG("core.gateway.rpc") @@ -839,14 +840,13 @@ int rpc_in_channel_send_pdu(RpcInChannel* inChannel, BYTE* buffer, UINT32 length int status; RpcClientCall* clientCall; rpcconn_common_hdr_t* header; - rdpRpc* rpc = inChannel->common.rpc; status = rpc_channel_write(&inChannel->common, buffer, length); if (status <= 0) return -1; header = (rpcconn_common_hdr_t*) buffer; - clientCall = rpc_client_call_find_by_id(rpc->client, header->call_id); + clientCall = rpc_client_call_find_by_id(inChannel->common.client, header->call_id); clientCall->State = RPC_CLIENT_CALL_STATE_DISPATCHED; /* @@ -986,11 +986,44 @@ out_free_pdu: return -1; } -RpcClient* rpc_client_new(rdpSettings* settings, UINT32 max_recv_frag) +static BOOL rpc_client_resolve_gateway(rdpSettings* settings, char** host, UINT16* port, + BOOL* isProxy) +{ + struct addrinfo* result; + + if (!settings || !host || !port || !isProxy) + return FALSE; + else + { + const char* peerHostname = settings->GatewayHostname; + const char* proxyUsername = settings->ProxyUsername; + const char* proxyPassword = settings->ProxyPassword; + *port = settings->GatewayPort; + *isProxy = proxy_prepare(settings, &peerHostname, port, &proxyUsername, &proxyPassword); + result = freerdp_tcp_resolve_host(peerHostname, *port, 0); + + if (!result) + return FALSE; + + *host = freerdp_tcp_address_to_string(result->ai_addr, NULL); + freeaddrinfo(result); + return TRUE; + } +} + +RpcClient* rpc_client_new(rdpContext* context, UINT32 max_recv_frag) { RpcClient* client = (RpcClient*) calloc(1, sizeof(RpcClient)); - if (!client || !settings) + if (!client) + return NULL; + + if (!rpc_client_resolve_gateway(context->settings, &client->host, &client->port, &client->isProxy)) + goto fail; + + client->context = context; + + if (!client->context) goto fail; client->pdu = rpc_pdu_new(); @@ -1019,7 +1052,7 @@ RpcClient* rpc_client_new(rdpSettings* settings, UINT32 max_recv_frag) if (!client->ClientCallList) goto fail; - ArrayList_Object(client->ClientCallList)->fnObjectFree = (OBJECT_FREE_FN) rpc_client_call_free; + ArrayList_Object(client->ClientCallList)->fnObjectFree = rpc_client_call_free; return client; fail: rpc_client_free(client); @@ -1031,6 +1064,8 @@ void rpc_client_free(RpcClient* client) if (!client) return; + free(client->host); + if (client->ReceiveFragment) Stream_Free(client->ReceiveFragment, TRUE); diff --git a/libfreerdp/core/gateway/rpc_client.h b/libfreerdp/core/gateway/rpc_client.h index 24a6480b4..559d78faa 100644 --- a/libfreerdp/core/gateway/rpc_client.h +++ b/libfreerdp/core/gateway/rpc_client.h @@ -20,6 +20,7 @@ #ifndef FREERDP_LIB_CORE_GATEWAY_RPC_CLIENT_H #define FREERDP_LIB_CORE_GATEWAY_RPC_CLIENT_H +#include #include #include "rpc.h" @@ -42,7 +43,7 @@ FREERDP_LOCAL int rpc_client_receive_pipe_read(RpcClient* client, BYTE* buffer, FREERDP_LOCAL int rpc_client_write_call(rdpRpc* rpc, BYTE* data, int length, UINT16 opnum); -FREERDP_LOCAL RpcClient* rpc_client_new(rdpSettings* settings, UINT32 max_recv_frag); +FREERDP_LOCAL RpcClient* rpc_client_new(rdpContext* context, UINT32 max_recv_frag); FREERDP_LOCAL void rpc_client_free(RpcClient* client); #endif /* FREERDP_LIB_CORE_GATEWAY_RPC_CLIENT_H */ From 9516c251c715caa16af215bf9d44c2f0513e779d Mon Sep 17 00:00:00 2001 From: Armin Novak Date: Fri, 28 Sep 2018 08:43:43 +0200 Subject: [PATCH 07/13] Made TSG struct opaque --- libfreerdp/core/gateway/rpc.c | 4 + libfreerdp/core/gateway/rpc_client.c | 65 +- libfreerdp/core/gateway/rpc_client.h | 4 +- libfreerdp/core/gateway/tsg.c | 1671 +++++++++++++++----------- libfreerdp/core/gateway/tsg.h | 209 +--- libfreerdp/core/rdp.c | 9 +- libfreerdp/core/transport.c | 2 +- 7 files changed, 999 insertions(+), 965 deletions(-) diff --git a/libfreerdp/core/gateway/rpc.c b/libfreerdp/core/gateway/rpc.c index 9106023e5..68a32b952 100644 --- a/libfreerdp/core/gateway/rpc.c +++ b/libfreerdp/core/gateway/rpc.c @@ -764,6 +764,8 @@ static int rpc_out_channel_connect(RpcOutChannel* outChannel, int timeout) if (!outChannel || !outChannel->common.client || !outChannel->common.client->context) return -1; + context = outChannel->common.client->context; + /* Connect OUT Channel */ if (rpc_channel_tls_connect(&outChannel->common, timeout) < 0) @@ -793,6 +795,8 @@ int rpc_out_channel_replacement_connect(RpcOutChannel* outChannel, int timeout) if (!outChannel || !outChannel->common.client || !outChannel->common.client->context) return -1; + context = outChannel->common.client->context; + /* Connect OUT Channel */ if (rpc_channel_tls_connect(&outChannel->common, timeout) < 0) diff --git a/libfreerdp/core/gateway/rpc_client.c b/libfreerdp/core/gateway/rpc_client.c index fec8cc6a0..af814e093 100644 --- a/libfreerdp/core/gateway/rpc_client.c +++ b/libfreerdp/core/gateway/rpc_client.c @@ -273,7 +273,7 @@ static int rpc_client_recv_pdu(rdpRpc* rpc, RPC_PDU* pdu) rpc_client_transition_to_state(rpc, RPC_CLIENT_STATE_CONTEXT_NEGOTIATED); - if (tsg_proxy_begin(tsg) < 0) + if (!tsg_proxy_begin(tsg)) { WLog_ERR(TAG, "tsg_proxy_begin failure"); return -1; @@ -288,7 +288,10 @@ static int rpc_client_recv_pdu(rdpRpc* rpc, RPC_PDU* pdu) } else if (rpc->State >= RPC_CLIENT_STATE_CONTEXT_NEGOTIATED) { - status = tsg_recv_pdu(tsg, pdu); + if (!tsg_recv_pdu(tsg, pdu)) + status = -1; + else + status = 1; } return status; @@ -331,7 +334,7 @@ static int rpc_client_recv_fragment(rdpRpc* rpc, wStream* fragment) TerminateEventArgs e; rpc->result = *((UINT32*) &buffer[StubOffset]); freerdp_abort_connect(rpc->context->instance); - rpc->transport->tsg->state = TSG_STATE_TUNNEL_CLOSE_PENDING; + tsg_set_state(rpc->transport->tsg, TSG_STATE_TUNNEL_CLOSE_PENDING); EventArgsInit(&e, "freerdp"); e.code = 0; PubSub_OnTerminate(rpc->context->pubSub, rpc->context, &e); @@ -865,7 +868,7 @@ int rpc_in_channel_send_pdu(RpcInChannel* inChannel, BYTE* buffer, UINT32 length return status; } -int rpc_client_write_call(rdpRpc* rpc, BYTE* data, int length, UINT16 opnum) +BOOL rpc_client_write_call(rdpRpc* rpc, wStream* s, UINT16 opnum) { SECURITY_STATUS status; UINT32 offset; @@ -873,19 +876,36 @@ int rpc_client_write_call(rdpRpc* rpc, BYTE* data, int length, UINT16 opnum) UINT32 stub_data_pad; SecBuffer Buffers[2]; SecBufferDesc Message; - RpcClientCall* clientCall; - rdpNtlm* ntlm = rpc->ntlm; + RpcClientCall* clientCall = NULL; + rdpNtlm* ntlm; SECURITY_STATUS encrypt_status; rpcconn_request_hdr_t* request_pdu = NULL; - RpcVirtualConnection* connection = rpc->VirtualConnection; - RpcInChannel* inChannel = connection->DefaultInChannel; + RpcVirtualConnection* connection; + RpcInChannel* inChannel; + size_t length; + + if (!s || !rpc) + return FALSE; + + ntlm = rpc->ntlm; + connection = rpc->VirtualConnection; if (!ntlm || !ntlm->table) { WLog_ERR(TAG, "invalid ntlm context"); - return -1; + return FALSE; } + if (!connection) + return FALSE; + + inChannel = connection->DefaultInChannel; + + if (!inChannel) + return FALSE; + + Stream_SealLength(s); + length = Stream_Length(s); status = ntlm->table->QueryContextAttributes(&ntlm->context, SECPKG_ATTR_SIZES, &ntlm->ContextSizes); @@ -893,14 +913,14 @@ int rpc_client_write_call(rdpRpc* rpc, BYTE* data, int length, UINT16 opnum) { WLog_ERR(TAG, "QueryContextAttributes SECPKG_ATTR_SIZES failure %s [0x%08"PRIX32"]", GetSecurityStatusString(status), status); - return -1; + return FALSE; } ZeroMemory(&Buffers, sizeof(Buffers)); request_pdu = (rpcconn_request_hdr_t*) calloc(1, sizeof(rpcconn_request_hdr_t)); if (!request_pdu) - return -1; + return FALSE; rpc_pdu_header_init(rpc, (rpcconn_hdr_t*) request_pdu); request_pdu->ptype = PTYPE_REQUEST; @@ -913,15 +933,15 @@ int rpc_client_write_call(rdpRpc* rpc, BYTE* data, int length, UINT16 opnum) clientCall = rpc_client_call_new(request_pdu->call_id, request_pdu->opnum); if (!clientCall) - goto out_free_pdu; + goto fail; if (ArrayList_Add(rpc->client->ClientCallList, clientCall) < 0) - goto out_free_clientCall; + goto fail; if (request_pdu->opnum == TsProxySetupReceivePipeOpnum) rpc->PipeCallId = request_pdu->call_id; - request_pdu->stub_data = data; + request_pdu->stub_data = Stream_Buffer(s); offset = 24; stub_data_pad = rpc_offset_align(&offset, 8); offset += length; @@ -935,7 +955,7 @@ int rpc_client_write_call(rdpRpc* rpc, BYTE* data, int length, UINT16 opnum) buffer = (BYTE*) calloc(1, request_pdu->frag_length); if (!buffer) - goto out_free_pdu; + goto fail; CopyMemory(buffer, request_pdu, 24); offset = 24; @@ -953,7 +973,7 @@ int rpc_client_write_call(rdpRpc* rpc, BYTE* data, int length, UINT16 opnum) Buffers[1].pvBuffer = calloc(1, Buffers[1].cbBuffer); if (!Buffers[1].pvBuffer) - goto out_free_pdu; + goto fail; Message.cBuffers = 2; Message.ulVersion = SECBUFFER_VERSION; @@ -964,7 +984,7 @@ int rpc_client_write_call(rdpRpc* rpc, BYTE* data, int length, UINT16 opnum) { WLog_ERR(TAG, "EncryptMessage status %s [0x%08"PRIX32"]", GetSecurityStatusString(encrypt_status), encrypt_status); - goto out_free_pdu; + goto fail; } CopyMemory(&buffer[offset], Buffers[1].pvBuffer, Buffers[1].cbBuffer); @@ -972,18 +992,19 @@ int rpc_client_write_call(rdpRpc* rpc, BYTE* data, int length, UINT16 opnum) free(Buffers[1].pvBuffer); if (rpc_in_channel_send_pdu(inChannel, buffer, request_pdu->frag_length) < 0) - length = -1; + goto fail; free(request_pdu); free(buffer); - return length; -out_free_clientCall: + Stream_Free(s, TRUE); + return TRUE; +fail: rpc_client_call_free(clientCall); -out_free_pdu: free(buffer); free(Buffers[1].pvBuffer); free(request_pdu); - return -1; + Stream_Free(s, TRUE); + return FALSE; } static BOOL rpc_client_resolve_gateway(rdpSettings* settings, char** host, UINT16* port, diff --git a/libfreerdp/core/gateway/rpc_client.h b/libfreerdp/core/gateway/rpc_client.h index 559d78faa..3055d34f0 100644 --- a/libfreerdp/core/gateway/rpc_client.h +++ b/libfreerdp/core/gateway/rpc_client.h @@ -21,6 +21,7 @@ #define FREERDP_LIB_CORE_GATEWAY_RPC_CLIENT_H #include +#include #include #include "rpc.h" @@ -40,8 +41,7 @@ FREERDP_LOCAL int rpc_client_out_channel_recv(rdpRpc* rpc); FREERDP_LOCAL int rpc_client_receive_pipe_read(RpcClient* client, BYTE* buffer, size_t length); -FREERDP_LOCAL int rpc_client_write_call(rdpRpc* rpc, BYTE* data, int length, - UINT16 opnum); +FREERDP_LOCAL BOOL rpc_client_write_call(rdpRpc* rpc, wStream* s, UINT16 opnum); FREERDP_LOCAL RpcClient* rpc_client_new(rdpContext* context, UINT32 max_recv_frag); FREERDP_LOCAL void rpc_client_free(RpcClient* client); diff --git a/libfreerdp/core/gateway/tsg.c b/libfreerdp/core/gateway/tsg.c index 1c7aeccfd..5d519819a 100644 --- a/libfreerdp/core/gateway/tsg.c +++ b/libfreerdp/core/gateway/tsg.c @@ -39,6 +39,196 @@ #define TAG FREERDP_TAG("core.gateway.tsg") +typedef WCHAR* RESOURCENAME; + +typedef struct _tsendpointinfo +{ + RESOURCENAME* resourceName; + UINT32 numResourceNames; + RESOURCENAME* alternateResourceNames; + UINT16 numAlternateResourceNames; + UINT32 Port; +} TSENDPOINTINFO, *PTSENDPOINTINFO; + +typedef struct _TSG_PACKET_HEADER +{ + UINT16 ComponentId; + UINT16 PacketId; +} TSG_PACKET_HEADER, *PTSG_PACKET_HEADER; + +typedef struct _TSG_CAPABILITY_NAP +{ + UINT32 capabilities; +} TSG_CAPABILITY_NAP, *PTSG_CAPABILITY_NAP; + +typedef union +{ + TSG_CAPABILITY_NAP tsgCapNap; +} TSG_CAPABILITIES_UNION, *PTSG_CAPABILITIES_UNION; + +typedef struct _TSG_PACKET_CAPABILITIES +{ + UINT32 capabilityType; + TSG_CAPABILITIES_UNION tsgPacket; +} TSG_PACKET_CAPABILITIES, *PTSG_PACKET_CAPABILITIES; + +typedef struct _TSG_PACKET_VERSIONCAPS +{ + TSG_PACKET_HEADER tsgHeader; + PTSG_PACKET_CAPABILITIES tsgCaps; + UINT32 numCapabilities; + UINT16 majorVersion; + UINT16 minorVersion; + UINT16 quarantineCapabilities; +} TSG_PACKET_VERSIONCAPS, *PTSG_PACKET_VERSIONCAPS; + +typedef struct _TSG_PACKET_QUARCONFIGREQUEST +{ + UINT32 flags; +} TSG_PACKET_QUARCONFIGREQUEST, *PTSG_PACKET_QUARCONFIGREQUEST; + +typedef struct _TSG_PACKET_QUARREQUEST +{ + UINT32 flags; + WCHAR* machineName; + UINT32 nameLength; + BYTE* data; + UINT32 dataLen; +} TSG_PACKET_QUARREQUEST, *PTSG_PACKET_QUARREQUEST; + +typedef struct _TSG_REDIRECTION_FLAGS +{ + BOOL enableAllRedirections; + BOOL disableAllRedirections; + BOOL driveRedirectionDisabled; + BOOL printerRedirectionDisabled; + BOOL portRedirectionDisabled; + BOOL reserved; + BOOL clipboardRedirectionDisabled; + BOOL pnpRedirectionDisabled; +} TSG_REDIRECTION_FLAGS, *PTSG_REDIRECTION_FLAGS; + +typedef struct _TSG_PACKET_RESPONSE +{ + UINT32 flags; + UINT32 reserved; + BYTE* responseData; + UINT32 responseDataLen; + TSG_REDIRECTION_FLAGS redirectionFlags; +} TSG_PACKET_RESPONSE, *PTSG_PACKET_RESPONSE; + +typedef struct _TSG_PACKET_QUARENC_RESPONSE +{ + UINT32 flags; + UINT32 certChainLen; + WCHAR* certChainData; + GUID nonce; + PTSG_PACKET_VERSIONCAPS versionCaps; +} TSG_PACKET_QUARENC_RESPONSE, *PTSG_PACKET_QUARENC_RESPONSE; + +typedef struct TSG_PACKET_STRING_MESSAGE +{ + INT32 isDisplayMandatory; + INT32 isConsentMandatory; + UINT32 msgBytes; + WCHAR* msgBuffer; +} TSG_PACKET_STRING_MESSAGE, *PTSG_PACKET_STRING_MESSAGE; + +typedef struct TSG_PACKET_REAUTH_MESSAGE +{ + UINT64 tunnelContext; +} TSG_PACKET_REAUTH_MESSAGE, *PTSG_PACKET_REAUTH_MESSAGE; + +typedef union +{ + PTSG_PACKET_STRING_MESSAGE consentMessage; + PTSG_PACKET_STRING_MESSAGE serviceMessage; + PTSG_PACKET_REAUTH_MESSAGE reauthMessage; +} TSG_PACKET_TYPE_MESSAGE_UNION, *PTSG_PACKET_TYPE_MESSAGE_UNION; + +typedef struct _TSG_PACKET_MSG_RESPONSE +{ + UINT32 msgID; + UINT32 msgType; + INT32 isMsgPresent; + TSG_PACKET_TYPE_MESSAGE_UNION messagePacket; +} TSG_PACKET_MSG_RESPONSE, *PTSG_PACKET_MSG_RESPONSE; + +typedef struct TSG_PACKET_CAPS_RESPONSE +{ + TSG_PACKET_QUARENC_RESPONSE pktQuarEncResponse; + TSG_PACKET_MSG_RESPONSE pktConsentMessage; +} TSG_PACKET_CAPS_RESPONSE, *PTSG_PACKET_CAPS_RESPONSE; + +typedef struct TSG_PACKET_MSG_REQUEST +{ + UINT32 maxMessagesPerBatch; +} TSG_PACKET_MSG_REQUEST, *PTSG_PACKET_MSG_REQUEST; + +typedef struct _TSG_PACKET_AUTH +{ + TSG_PACKET_VERSIONCAPS tsgVersionCaps; + UINT32 cookieLen; + BYTE* cookie; +} TSG_PACKET_AUTH, *PTSG_PACKET_AUTH; + +typedef union +{ + PTSG_PACKET_VERSIONCAPS packetVersionCaps; + PTSG_PACKET_AUTH packetAuth; +} TSG_INITIAL_PACKET_TYPE_UNION, *PTSG_INITIAL_PACKET_TYPE_UNION; + +typedef struct TSG_PACKET_REAUTH +{ + UINT64 tunnelContext; + UINT32 packetId; + TSG_INITIAL_PACKET_TYPE_UNION tsgInitialPacket; +} TSG_PACKET_REAUTH, *PTSG_PACKET_REAUTH; + +typedef union +{ + PTSG_PACKET_HEADER packetHeader; + PTSG_PACKET_VERSIONCAPS packetVersionCaps; + PTSG_PACKET_QUARCONFIGREQUEST packetQuarConfigRequest; + PTSG_PACKET_QUARREQUEST packetQuarRequest; + PTSG_PACKET_RESPONSE packetResponse; + PTSG_PACKET_QUARENC_RESPONSE packetQuarEncResponse; + PTSG_PACKET_CAPS_RESPONSE packetCapsResponse; + PTSG_PACKET_MSG_REQUEST packetMsgRequest; + PTSG_PACKET_MSG_RESPONSE packetMsgResponse; + PTSG_PACKET_AUTH packetAuth; + PTSG_PACKET_REAUTH packetReauth; +} TSG_PACKET_TYPE_UNION; + +typedef struct _TSG_PACKET +{ + UINT32 packetId; + TSG_PACKET_TYPE_UNION tsgPacket; +} TSG_PACKET, *PTSG_PACKET; + +struct rdp_tsg +{ + BIO* bio; + rdpRpc* rpc; + UINT16 Port; + LPWSTR Hostname; + LPWSTR MachineName; + TSG_STATE state; + UINT32 TunnelId; + UINT32 ChannelId; + BOOL reauthSequence; + rdpSettings* settings; + rdpTransport* transport; + UINT64 ReauthTunnelContext; + CONTEXT_HANDLE TunnelContext; + CONTEXT_HANDLE ChannelContext; + CONTEXT_HANDLE NewTunnelContext; + CONTEXT_HANDLE NewChannelContext; + TSG_PACKET_REAUTH packetReauth; + TSG_PACKET_CAPABILITIES tsgCaps; + TSG_PACKET_VERSIONCAPS packetVersionCaps; +}; + static BIO_METHOD* BIO_s_tsg(void); /** * RPC Functions: http://msdn.microsoft.com/en-us/library/windows/desktop/aa378623/ @@ -76,17 +266,15 @@ static BIO_METHOD* BIO_s_tsg(void); * TsProxySendToServerRequest(ChannelContext) */ -static DWORD TsProxySendToServer(handle_t IDL_handle, byte pRpcMessage[], UINT32 count, - UINT32* lengths) +static int TsProxySendToServer(handle_t IDL_handle, const byte pRpcMessage[], UINT32 count, + UINT32* lengths) { wStream* s; - int status; rdpTsg* tsg; - BYTE* buffer; - UINT32 length; - byte* buffer1 = NULL; - byte* buffer2 = NULL; - byte* buffer3 = NULL; + int length; + const byte* buffer1 = NULL; + const byte* buffer2 = NULL; + const byte* buffer3 = NULL; UINT32 buffer1Length; UINT32 buffer2Length; UINT32 buffer3Length; @@ -120,16 +308,10 @@ static DWORD TsProxySendToServer(handle_t IDL_handle, byte pRpcMessage[], UINT32 } length = 28 + totalDataBytes; - buffer = (BYTE*) calloc(1, length); - - if (!buffer) - return -1; - - s = Stream_New(buffer, length); + s = Stream_New(NULL, length); if (!s) { - free(buffer); WLog_ERR(TAG, "Stream_New failed!"); return -1; } @@ -158,16 +340,8 @@ static DWORD TsProxySendToServer(handle_t IDL_handle, byte pRpcMessage[], UINT32 if (buffer3Length > 0) Stream_Write(s, buffer3, buffer3Length); /* buffer3 (variable) */ - Stream_SealLength(s); - status = rpc_client_write_call(tsg->rpc, Stream_Buffer(s), Stream_Length(s), - TsProxySendToServerOpnum); - Stream_Free(s, TRUE); - - if (status <= 0) - { - WLog_ERR(TAG, "rpc_write failed!"); + if (!rpc_client_write_call(tsg->rpc, s, TsProxySendToServerOpnum)) return -1; - } return length; } @@ -185,161 +359,154 @@ static DWORD TsProxySendToServer(handle_t IDL_handle, byte pRpcMessage[], UINT32 static BOOL TsProxyCreateTunnelWriteRequest(rdpTsg* tsg, PTSG_PACKET tsgPacket) { - int status; - UINT32 length; - UINT32 offset = 0; - BYTE* buffer = NULL; - rdpRpc* rpc = tsg->rpc; + BOOL rc = FALSE; + BOOL write = TRUE; + UINT16 opnum = 0; + wStream* s; + rdpRpc* rpc; + + if (!tsg || !tsg->rpc) + return FALSE; + + rpc = tsg->rpc; WLog_DBG(TAG, "TsProxyCreateTunnelWriteRequest"); + s = Stream_New(NULL, 108); - if (tsgPacket->packetId == TSG_PACKET_TYPE_VERSIONCAPS) + if (!s) + return FALSE; + + switch (tsgPacket->packetId) { - PTSG_PACKET_VERSIONCAPS packetVersionCaps = tsgPacket->tsgPacket.packetVersionCaps; - PTSG_CAPABILITY_NAP tsgCapNap = &packetVersionCaps->tsgCaps->tsgPacket.tsgCapNap; - length = 108; - buffer = (BYTE*) malloc(length); + case TSG_PACKET_TYPE_VERSIONCAPS: + { + PTSG_PACKET_VERSIONCAPS packetVersionCaps = tsgPacket->tsgPacket.packetVersionCaps; + PTSG_CAPABILITY_NAP tsgCapNap = &packetVersionCaps->tsgCaps->tsgPacket.tsgCapNap; + Stream_Write_UINT32(s, tsgPacket->packetId); /* PacketId (4 bytes) */ + Stream_Write_UINT32(s, tsgPacket->packetId); /* SwitchValue (4 bytes) */ + Stream_Write_UINT32(s, 0x00020000); /* PacketVersionCapsPtr (4 bytes) */ + Stream_Write_UINT16(s, packetVersionCaps->tsgHeader.ComponentId); /* ComponentId (2 bytes) */ + Stream_Write_UINT16(s, packetVersionCaps->tsgHeader.PacketId); /* PacketId (2 bytes) */ + Stream_Write_UINT32(s, 0x00020004); /* TsgCapsPtr (4 bytes) */ + Stream_Write_UINT32(s, packetVersionCaps->numCapabilities); /* NumCapabilities (4 bytes) */ + Stream_Write_UINT16(s, packetVersionCaps->majorVersion); /* MajorVersion (2 bytes) */ + Stream_Write_UINT16(s, packetVersionCaps->minorVersion); /* MinorVersion (2 bytes) */ + Stream_Write_UINT16(s, + packetVersionCaps->quarantineCapabilities); /* QuarantineCapabilities (2 bytes) */ + /* 4-byte alignment (30 + 2) */ + Stream_Write_UINT16(s, 0x0000); /* pad (2 bytes) */ + Stream_Write_UINT32(s, packetVersionCaps->numCapabilities); /* MaxCount (4 bytes) */ + Stream_Write_UINT32(s, packetVersionCaps->tsgCaps->capabilityType); /* CapabilityType (4 bytes) */ + Stream_Write_UINT32(s, packetVersionCaps->tsgCaps->capabilityType); /* SwitchValue (4 bytes) */ + Stream_Write_UINT32(s, tsgCapNap->capabilities); /* capabilities (4 bytes) */ + /** + * The following 60-byte structure is apparently undocumented, + * but parts of it can be matched to known C706 data structures. + */ + /* + * 8-byte constant (8A E3 13 71 02 F4 36 71) also observed here: + * http://lists.samba.org/archive/cifs-protocol/2010-July/001543.html + */ + Stream_Write_UINT8(s, 0x8A); + Stream_Write_UINT8(s, 0xE3); + Stream_Write_UINT8(s, 0x13); + Stream_Write_UINT8(s, 0x71); + Stream_Write_UINT8(s, 0x02); + Stream_Write_UINT8(s, 0xF4); + Stream_Write_UINT8(s, 0x36); + Stream_Write_UINT8(s, 0x71); + Stream_Write_UINT32(s, 0x00040001); /* 1.4 (version?) */ + Stream_Write_UINT32(s, 0x00000001); /* 1 (element count?) */ + /* p_cont_list_t */ + Stream_Write_UINT8(s, 2); /* ncontext_elem */ + Stream_Write_UINT8(s, 0x40); /* reserved1 */ + Stream_Write_UINT16(s, 0x0028); /* reserved2 */ + /* p_syntax_id_t */ + Stream_Write(s, &TSGU_UUID, sizeof(p_uuid_t)); + Stream_Write_UINT32(s, TSGU_SYNTAX_IF_VERSION); + /* p_syntax_id_t */ + Stream_Write(s, &NDR_UUID, sizeof(p_uuid_t)); + Stream_Write_UINT32(s, NDR_SYNTAX_IF_VERSION); + opnum = TsProxyCreateTunnelOpnum; + } + break; - if (!buffer) - return FALSE; + case TSG_PACKET_TYPE_REAUTH: + { + PTSG_PACKET_REAUTH packetReauth = tsgPacket->tsgPacket.packetReauth; + PTSG_PACKET_VERSIONCAPS packetVersionCaps = packetReauth->tsgInitialPacket.packetVersionCaps; + PTSG_CAPABILITY_NAP tsgCapNap = &packetVersionCaps->tsgCaps->tsgPacket.tsgCapNap; + Stream_Write_UINT32(s, tsgPacket->packetId); /* PacketId (4 bytes) */ + Stream_Write_UINT32(s, tsgPacket->packetId); /* SwitchValue (4 bytes) */ + Stream_Write_UINT32(s, 0x00020000); /* PacketReauthPtr (4 bytes) */ + Stream_Write_UINT32(s, 0); /* ??? (4 bytes) */ + Stream_Write_UINT64(s, packetReauth->tunnelContext); /* TunnelContext (8 bytes) */ + Stream_Write_UINT32(s, TSG_PACKET_TYPE_VERSIONCAPS); /* PacketId (4 bytes) */ + Stream_Write_UINT32(s, TSG_PACKET_TYPE_VERSIONCAPS); /* SwitchValue (4 bytes) */ + Stream_Write_UINT32(s, 0x00020004); /* PacketVersionCapsPtr (4 bytes) */ + Stream_Write_UINT16(s, packetVersionCaps->tsgHeader.ComponentId); /* ComponentId (2 bytes) */ + Stream_Write_UINT16(s, packetVersionCaps->tsgHeader.PacketId); /* PacketId (2 bytes) */ + Stream_Write_UINT32(s, 0x00020008); /* TsgCapsPtr (4 bytes) */ + Stream_Write_UINT32(s, packetVersionCaps->numCapabilities); /* NumCapabilities (4 bytes) */ + Stream_Write_UINT16(s, packetVersionCaps->majorVersion); /* MajorVersion (2 bytes) */ + Stream_Write_UINT16(s, packetVersionCaps->minorVersion); /* MinorVersion (2 bytes) */ + Stream_Write_UINT16(s, + packetVersionCaps->quarantineCapabilities); /* QuarantineCapabilities (2 bytes) */ + /* 4-byte alignment (30 + 2) */ + Stream_Write_UINT16(s, 0x0000); /* pad (2 bytes) */ + Stream_Write_UINT32(s, packetVersionCaps->numCapabilities); /* MaxCount (4 bytes) */ + Stream_Write_UINT32(s, packetVersionCaps->tsgCaps->capabilityType); /* CapabilityType (4 bytes) */ + Stream_Write_UINT32(s, packetVersionCaps->tsgCaps->capabilityType); /* SwitchValue (4 bytes) */ + Stream_Write_UINT32(s, tsgCapNap->capabilities); /* capabilities (4 bytes) */ + opnum = TsProxyCreateTunnelOpnum; + } + break; - *((UINT32*) &buffer[0]) = tsgPacket->packetId; /* PacketId (4 bytes) */ - *((UINT32*) &buffer[4]) = tsgPacket->packetId; /* SwitchValue (4 bytes) */ - *((UINT32*) &buffer[8]) = 0x00020000; /* PacketVersionCapsPtr (4 bytes) */ - *((UINT16*) &buffer[12]) = packetVersionCaps->tsgHeader.ComponentId; /* ComponentId (2 bytes) */ - *((UINT16*) &buffer[14]) = packetVersionCaps->tsgHeader.PacketId; /* PacketId (2 bytes) */ - *((UINT32*) &buffer[16]) = 0x00020004; /* TsgCapsPtr (4 bytes) */ - *((UINT32*) &buffer[20]) = packetVersionCaps->numCapabilities; /* NumCapabilities (4 bytes) */ - *((UINT16*) &buffer[24]) = packetVersionCaps->majorVersion; /* MajorVersion (2 bytes) */ - *((UINT16*) &buffer[26]) = packetVersionCaps->minorVersion; /* MinorVersion (2 bytes) */ - *((UINT16*) &buffer[28]) = - packetVersionCaps->quarantineCapabilities; /* QuarantineCapabilities (2 bytes) */ - /* 4-byte alignment (30 + 2) */ - *((UINT16*) &buffer[30]) = 0x0000; /* pad (2 bytes) */ - *((UINT32*) &buffer[32]) = packetVersionCaps->numCapabilities; /* MaxCount (4 bytes) */ - *((UINT32*) &buffer[36]) = - packetVersionCaps->tsgCaps->capabilityType; /* CapabilityType (4 bytes) */ - *((UINT32*) &buffer[40]) = packetVersionCaps->tsgCaps->capabilityType; /* SwitchValue (4 bytes) */ - *((UINT32*) &buffer[44]) = tsgCapNap->capabilities; /* capabilities (4 bytes) */ - offset = 48; - /** - * The following 60-byte structure is apparently undocumented, - * but parts of it can be matched to known C706 data structures. - */ - /* - * 8-byte constant (8A E3 13 71 02 F4 36 71) also observed here: - * http://lists.samba.org/archive/cifs-protocol/2010-July/001543.html - */ - buffer[offset + 0] = 0x8A; - buffer[offset + 1] = 0xE3; - buffer[offset + 2] = 0x13; - buffer[offset + 3] = 0x71; - buffer[offset + 4] = 0x02; - buffer[offset + 5] = 0xF4; - buffer[offset + 6] = 0x36; - buffer[offset + 7] = 0x71; - *((UINT32*) &buffer[offset + 8]) = 0x00040001; /* 1.4 (version?) */ - *((UINT32*) &buffer[offset + 12]) = 0x00000001; /* 1 (element count?) */ - /* p_cont_list_t */ - buffer[offset + 16] = 2; /* ncontext_elem */ - buffer[offset + 17] = 0x40; /* reserved1 */ - *((UINT16*) &buffer[offset + 18]) = 0x0028; /* reserved2 */ - /* p_syntax_id_t */ - CopyMemory(&buffer[offset + 20], &TSGU_UUID, sizeof(p_uuid_t)); - *((UINT32*) &buffer[offset + 36]) = TSGU_SYNTAX_IF_VERSION; - /* p_syntax_id_t */ - CopyMemory(&buffer[offset + 40], &NDR_UUID, sizeof(p_uuid_t)); - *((UINT32*) &buffer[offset + 56]) = NDR_SYNTAX_IF_VERSION; - status = rpc_client_write_call(rpc, buffer, length, TsProxyCreateTunnelOpnum); - free(buffer); - - if (status <= 0) - return FALSE; - } - else if (tsgPacket->packetId == TSG_PACKET_TYPE_REAUTH) - { - PTSG_PACKET_REAUTH packetReauth = tsgPacket->tsgPacket.packetReauth; - PTSG_PACKET_VERSIONCAPS packetVersionCaps = packetReauth->tsgInitialPacket.packetVersionCaps; - PTSG_CAPABILITY_NAP tsgCapNap = &packetVersionCaps->tsgCaps->tsgPacket.tsgCapNap; - length = 72; - buffer = (BYTE*) malloc(length); - - if (!buffer) - return FALSE; - - *((UINT32*) &buffer[0]) = tsgPacket->packetId; /* PacketId (4 bytes) */ - *((UINT32*) &buffer[4]) = tsgPacket->packetId; /* SwitchValue (4 bytes) */ - *((UINT32*) &buffer[8]) = 0x00020000; /* PacketReauthPtr (4 bytes) */ - *((UINT32*) &buffer[12]) = 0; /* ??? (4 bytes) */ - *((UINT64*) &buffer[16]) = packetReauth->tunnelContext; /* TunnelContext (8 bytes) */ - offset = 24; - *((UINT32*) &buffer[offset + 0]) = TSG_PACKET_TYPE_VERSIONCAPS; /* PacketId (4 bytes) */ - *((UINT32*) &buffer[offset + 4]) = TSG_PACKET_TYPE_VERSIONCAPS; /* SwitchValue (4 bytes) */ - *((UINT32*) &buffer[offset + 8]) = 0x00020004; /* PacketVersionCapsPtr (4 bytes) */ - *((UINT16*) &buffer[offset + 12]) = - packetVersionCaps->tsgHeader.ComponentId; /* ComponentId (2 bytes) */ - *((UINT16*) &buffer[offset + 14]) = packetVersionCaps->tsgHeader.PacketId; /* PacketId (2 bytes) */ - *((UINT32*) &buffer[offset + 16]) = 0x00020008; /* TsgCapsPtr (4 bytes) */ - *((UINT32*) &buffer[offset + 20]) = - packetVersionCaps->numCapabilities; /* NumCapabilities (4 bytes) */ - *((UINT16*) &buffer[offset + 24]) = packetVersionCaps->majorVersion; /* MajorVersion (2 bytes) */ - *((UINT16*) &buffer[offset + 26]) = packetVersionCaps->minorVersion; /* MinorVersion (2 bytes) */ - *((UINT16*) &buffer[offset + 28]) = - packetVersionCaps->quarantineCapabilities; /* QuarantineCapabilities (2 bytes) */ - /* 4-byte alignment (30 + 2) */ - *((UINT16*) &buffer[offset + 30]) = 0x0000; /* pad (2 bytes) */ - *((UINT32*) &buffer[offset + 32]) = packetVersionCaps->numCapabilities; /* MaxCount (4 bytes) */ - *((UINT32*) &buffer[offset + 36]) = - packetVersionCaps->tsgCaps->capabilityType; /* CapabilityType (4 bytes) */ - *((UINT32*) &buffer[offset + 40]) = - packetVersionCaps->tsgCaps->capabilityType; /* SwitchValue (4 bytes) */ - *((UINT32*) &buffer[offset + 44]) = tsgCapNap->capabilities; /* capabilities (4 bytes) */ - status = rpc_client_write_call(rpc, buffer, length, TsProxyCreateTunnelOpnum); - free(buffer); - - if (status <= 0) - return FALSE; + default: + write = FALSE; + break; } - return TRUE; + rc = TRUE; + + if (write) + return rpc_client_write_call(rpc, s, opnum); + + Stream_Free(s, TRUE); + return rc; } static BOOL TsProxyCreateTunnelReadResponse(rdpTsg* tsg, RPC_PDU* pdu, CONTEXT_HANDLE* tunnelContext, UINT32* tunnelId) { - BYTE* buffer; + BOOL rc = FALSE; UINT32 count; - UINT32 length; - UINT32 offset; UINT32 Pointer; PTSG_PACKET packet; UINT32 SwitchValue; UINT32 MessageSwitchValue = 0; UINT32 IsMessagePresent; UINT32 MsgBytes; - PTSG_PACKET_CAPABILITIES tsgCaps; - PTSG_PACKET_VERSIONCAPS versionCaps; - PTSG_PACKET_CAPS_RESPONSE packetCapsResponse; - PTSG_PACKET_QUARENC_RESPONSE packetQuarEncResponse; + PTSG_PACKET_CAPABILITIES tsgCaps = NULL; + PTSG_PACKET_VERSIONCAPS versionCaps = NULL; + PTSG_PACKET_CAPS_RESPONSE packetCapsResponse = NULL; + PTSG_PACKET_QUARENC_RESPONSE packetQuarEncResponse = NULL; WLog_DBG(TAG, "TsProxyCreateTunnelReadResponse"); if (!pdu) return FALSE; - length = Stream_Length(pdu->s); - buffer = Stream_Buffer(pdu->s); - - if (!(pdu->Flags & RPC_PDU_FLAG_STUB)) - buffer = &buffer[24]; - packet = (PTSG_PACKET) calloc(1, sizeof(TSG_PACKET)); if (!packet) return FALSE; - offset = 4; /* PacketPtr (4 bytes) */ - packet->packetId = *((UINT32*) &buffer[offset]); /* PacketId (4 bytes) */ - SwitchValue = *((UINT32*) &buffer[offset + 4]); /* SwitchValue (4 bytes) */ + if (Stream_GetRemainingLength(pdu->s) < 12) + goto fail; + + Stream_Seek_UINT32(pdu->s); /* PacketPtr (4 bytes) */ + Stream_Read_UINT32(pdu->s, packet->packetId); /* PacketId (4 bytes) */ + Stream_Read_UINT32(pdu->s, SwitchValue); /* SwitchValue (4 bytes) */ if ((packet->packetId == TSG_PACKET_TYPE_CAPS_RESPONSE) && (SwitchValue == TSG_PACKET_TYPE_CAPS_RESPONSE)) @@ -347,177 +514,188 @@ static BOOL TsProxyCreateTunnelReadResponse(rdpTsg* tsg, RPC_PDU* pdu, packetCapsResponse = (PTSG_PACKET_CAPS_RESPONSE) calloc(1, sizeof(TSG_PACKET_CAPS_RESPONSE)); if (!packetCapsResponse) - { - free(packet); - return FALSE; - } + goto fail; packet->tsgPacket.packetCapsResponse = packetCapsResponse; - /* PacketQuarResponsePtr (4 bytes) */ - packetCapsResponse->pktQuarEncResponse.flags = *((UINT32*) &buffer[offset + - 12]); /* Flags (4 bytes) */ - packetCapsResponse->pktQuarEncResponse.certChainLen = *((UINT32*) &buffer[offset + - 16]); /* CertChainLength (4 bytes) */ - /* CertChainDataPtr (4 bytes) */ - CopyMemory(&packetCapsResponse->pktQuarEncResponse.nonce, &buffer[offset + 24], - 16); /* Nonce (16 bytes) */ - offset += 40; - Pointer = *((UINT32*) &buffer[offset]); /* VersionCapsPtr (4 bytes) */ - offset += 4; + + if (Stream_GetRemainingLength(pdu->s) < 32) + goto fail; + + Stream_Seek_UINT32(pdu->s); /* PacketQuarResponsePtr (4 bytes) */ + Stream_Read_UINT32(pdu->s, packetCapsResponse->pktQuarEncResponse.flags); /* Flags (4 bytes) */ + Stream_Read_UINT32(pdu->s, + packetCapsResponse->pktQuarEncResponse.certChainLen); /* CertChainLength (4 bytes) */ + Stream_Seek_UINT32(pdu->s); /* CertChainDataPtr (4 bytes) */ + Stream_Read(pdu->s, &packetCapsResponse->pktQuarEncResponse.nonce, 16); /* Nonce (16 bytes) */ + Stream_Read_UINT32(pdu->s, Pointer); /* VersionCapsPtr (4 bytes) */ if ((Pointer == 0x0002000C) || (Pointer == 0x00020008)) { - offset += 4; /* MsgId (4 bytes) */ - offset += 4; /* MsgType (4 bytes) */ - IsMessagePresent = *((UINT32*) &buffer[offset]); /* IsMessagePresent (4 bytes) */ - offset += 4; - MessageSwitchValue = *((UINT32*) &buffer[offset]); /* MessageSwitchValue (4 bytes) */ - offset += 4; + if (Stream_GetRemainingLength(pdu->s) < 16) + goto fail; + + Stream_Seek_UINT32(pdu->s); /* MsgId (4 bytes) */ + Stream_Seek_UINT32(pdu->s); /* MsgType (4 bytes) */ + Stream_Read_UINT32(pdu->s, IsMessagePresent); /* IsMessagePresent (4 bytes) */ + Stream_Read_UINT32(pdu->s, MessageSwitchValue); /* MessageSwitchValue (4 bytes) */ } if (packetCapsResponse->pktQuarEncResponse.certChainLen > 0) { - Pointer = *((UINT32*) &buffer[offset]); /* MsgPtr (4 bytes): 0x00020014 */ - offset += 4; - offset += 4; /* MaxCount (4 bytes) */ - offset += 4; /* Offset (4 bytes) */ - count = *((UINT32*) &buffer[offset]); /* ActualCount (4 bytes) */ - offset += 4; + if (Stream_GetRemainingLength(pdu->s) < 16) + goto fail; + + Stream_Read_UINT32(pdu->s, Pointer); /* MsgPtr (4 bytes): 0x00020014 */ + Stream_Seek_UINT32(pdu->s); /* MaxCount (4 bytes) */ + Stream_Seek_UINT32(pdu->s); /* Offset (4 bytes) */ + Stream_Read_UINT32(pdu->s, count); /* ActualCount (4 bytes) */ + /* * CertChainData is a wide character string, and the count is * given in characters excluding the null terminator, therefore: * size = (count * 2) */ - offset += (count * 2); /* CertChainData */ + if (!Stream_SafeSeek(pdu->s, count * 2)) /* CertChainData */ + goto fail; + /* 4-byte alignment */ - rpc_offset_align(&offset, 4); + { + UINT32 offset = Stream_Pointer(pdu->s); + + if (!Stream_SafeSeek(pdu->s, rpc_offset_align(&offset, 4))) + goto fail; + } } else { - Pointer = *((UINT32*) &buffer[offset]); /* Ptr (4 bytes) */ - offset += 4; + if (Stream_GetRemainingLength(pdu->s) < 4) + goto fail; + + Stream_Read_UINT32(pdu->s, Pointer); /* Ptr (4 bytes) */ } versionCaps = (PTSG_PACKET_VERSIONCAPS) calloc(1, sizeof(TSG_PACKET_VERSIONCAPS)); if (!versionCaps) - { - free(packetCapsResponse); - free(packet); - return FALSE; - } + goto fail; packetCapsResponse->pktQuarEncResponse.versionCaps = versionCaps; - versionCaps->tsgHeader.ComponentId = *((UINT16*) &buffer[offset]); /* ComponentId (2 bytes) */ - versionCaps->tsgHeader.PacketId = *((UINT16*) &buffer[offset + 2]); /* PacketId (2 bytes) */ - offset += 4; + + if (Stream_GetRemainingLength(pdu->s) < 18) + goto fail; + + Stream_Read_UINT16(pdu->s, versionCaps->tsgHeader.ComponentId); /* ComponentId (2 bytes) */ + Stream_Read_UINT16(pdu->s, versionCaps->tsgHeader.PacketId); /* PacketId (2 bytes) */ if (versionCaps->tsgHeader.ComponentId != TS_GATEWAY_TRANSPORT) { WLog_ERR(TAG, "Unexpected ComponentId: 0x%04"PRIX16", Expected TS_GATEWAY_TRANSPORT", versionCaps->tsgHeader.ComponentId); - free(packetCapsResponse); - free(versionCaps); - free(packet); - return FALSE; + goto fail; } - Pointer = *((UINT32*) &buffer[offset]); /* TsgCapsPtr (4 bytes) */ - versionCaps->numCapabilities = *((UINT32*) &buffer[offset + 4]); /* NumCapabilities (4 bytes) */ - versionCaps->majorVersion = *((UINT16*) &buffer[offset + 8]); /* MajorVersion (2 bytes) */ - versionCaps->minorVersion = *((UINT16*) &buffer[offset + 10]); /* MinorVersion (2 bytes) */ - versionCaps->quarantineCapabilities = *((UINT16*) &buffer[offset + - 12]); /* QuarantineCapabilities (2 bytes) */ - offset += 14; + Stream_Read_UINT32(pdu->s, Pointer); /* TsgCapsPtr (4 bytes) */ + Stream_Read_UINT32(pdu->s, versionCaps->numCapabilities); /* NumCapabilities (4 bytes) */ + Stream_Read_UINT16(pdu->s, versionCaps->majorVersion); /* MajorVersion (2 bytes) */ + Stream_Read_UINT16(pdu->s, versionCaps->minorVersion); /* MinorVersion (2 bytes) */ + Stream_Read_UINT16(pdu->s, + versionCaps->quarantineCapabilities); /* QuarantineCapabilities (2 bytes) */ /* 4-byte alignment */ - rpc_offset_align(&offset, 4); + { + UINT32 offset = Stream_Pointer(pdu->s); + + if (!Stream_SafeSeek(pdu->s, rpc_offset_align(&offset, 4))) + goto fail; + } tsgCaps = (PTSG_PACKET_CAPABILITIES) calloc(1, sizeof(TSG_PACKET_CAPABILITIES)); if (!tsgCaps) - { - free(packetCapsResponse); - free(versionCaps); - free(packet); - return FALSE; - } + goto fail; versionCaps->tsgCaps = tsgCaps; - offset += 4; /* MaxCount (4 bytes) */ - tsgCaps->capabilityType = *((UINT32*) &buffer[offset]); /* CapabilityType (4 bytes) */ - SwitchValue = *((UINT32*) &buffer[offset + 4]); /* SwitchValue (4 bytes) */ - offset += 8; + + if (Stream_GetRemainingLength(pdu->s) < 16) + goto fail; + + Stream_Seek_UINT32(pdu->s); /* MaxCount (4 bytes) */ + Stream_Read_UINT32(pdu->s, tsgCaps->capabilityType); /* CapabilityType (4 bytes) */ + Stream_Read_UINT32(pdu->s, SwitchValue); /* SwitchValue (4 bytes) */ if ((SwitchValue != TSG_CAPABILITY_TYPE_NAP) || (tsgCaps->capabilityType != TSG_CAPABILITY_TYPE_NAP)) { WLog_ERR(TAG, "Unexpected CapabilityType: 0x%08"PRIX32", Expected TSG_CAPABILITY_TYPE_NAP", tsgCaps->capabilityType); - free(tsgCaps); - free(versionCaps); - free(packetCapsResponse); - free(packet); - return FALSE; + goto fail; } - tsgCaps->tsgPacket.tsgCapNap.capabilities = *((UINT32*) - &buffer[offset]); /* Capabilities (4 bytes) */ - offset += 4; + Stream_Read_UINT32(pdu->s, tsgCaps->tsgPacket.tsgCapNap.capabilities); /* Capabilities (4 bytes) */ switch (MessageSwitchValue) { case TSG_ASYNC_MESSAGE_CONSENT_MESSAGE: case TSG_ASYNC_MESSAGE_SERVICE_MESSAGE: - offset += 4; /* IsDisplayMandatory (4 bytes) */ - offset += 4; /* IsConsent Mandatory (4 bytes) */ - MsgBytes = *((UINT32*) &buffer[offset]); - offset += 4; - Pointer = *((UINT32*) &buffer[offset]); - offset += 4; + if (Stream_GetRemainingLength(pdu->s) < 16) + goto fail; + + Stream_Seek_UINT32(pdu->s); /* IsDisplayMandatory (4 bytes) */ + Stream_Seek_UINT32(pdu->s); /* IsConsent Mandatory (4 bytes) */ + Stream_Read_UINT32(pdu->s, MsgBytes); + Stream_Read_UINT32(pdu->s, Pointer); if (Pointer) { - offset += 4; /* MaxCount (4 bytes) */ - offset += 4; /* Offset (4 bytes) */ - offset += 4; /* Length (4 bytes) */ + if (Stream_GetRemainingLength(pdu->s) < 12) + goto fail; + + Stream_Seek_UINT32(pdu->s); /* MaxCount (4 bytes) */ + Stream_Seek_UINT32(pdu->s); /* Offset (4 bytes) */ + Stream_Seek_UINT32(pdu->s); /* Length (4 bytes) */ } if (MsgBytes > TSG_MESSAGING_MAX_MESSAGE_LENGTH) { WLog_ERR(TAG, "Out of Spec Message Length %"PRIu32"", MsgBytes); - free(tsgCaps); - free(versionCaps); - free(packetCapsResponse); - free(packet); - return FALSE; + goto fail; } - offset += MsgBytes; + if (!Stream_SafeSeek(pdu->s, MsgBytes)) + goto fail; + break; case TSG_ASYNC_MESSAGE_REAUTH: - rpc_offset_align(&offset, 8); - offset += 8; /* TunnelContext (8 bytes) */ + { + UINT32 offset = Stream_Pointer(pdu->s); + + if (!Stream_SafeSeek(pdu->s, rpc_offset_align(&offset, 8) || + (Stream_GetRemainingLength(pdu->s) < 8))) + goto fail; + + Stream_Seek_UINT64(pdu->s); /* TunnelContext (8 bytes) */ + } break; default: WLog_ERR(TAG, "Unexpected Message Type: 0x%"PRIX32"", MessageSwitchValue); - free(tsgCaps); - free(versionCaps); - free(packetCapsResponse); - free(packet); - return FALSE; + goto fail; + } + + { + UINT32 offset = Stream_GetPosition(pdu->s); + + if (!Stream_SafeSeek(pdu->s, rpc_offset_align(&offset, 4))) + goto fail; } - rpc_offset_align(&offset, 4); /* TunnelContext (20 bytes) */ - CopyMemory(&tunnelContext->ContextType, &buffer[offset], 4); /* ContextType (4 bytes) */ - CopyMemory(&tunnelContext->ContextUuid, &buffer[offset + 4], 16); /* ContextUuid (16 bytes) */ - offset += 20; - *tunnelId = *((UINT32*) &buffer[offset]); /* TunnelId (4 bytes) */ + if (Stream_GetRemainingLength(pdu->s) < 24) + goto fail; + + Stream_Read_UINT32(pdu->s, tunnelContext->ContextType); /* ContextType (4 bytes) */ + Stream_Read(pdu->s, tunnelContext->ContextUuid, 16); /* ContextUuid (16 bytes) */ + Stream_Read_UINT32(pdu->s, *tunnelId); /* TunnelId (4 bytes) */ /* ReturnValue (4 bytes) */ - free(tsgCaps); - free(versionCaps); - free(packetCapsResponse); } else if ((packet->packetId == TSG_PACKET_TYPE_QUARENC_RESPONSE) && (SwitchValue == TSG_PACKET_TYPE_QUARENC_RESPONSE)) @@ -526,98 +704,114 @@ static BOOL TsProxyCreateTunnelReadResponse(rdpTsg* tsg, RPC_PDU* pdu, sizeof(TSG_PACKET_QUARENC_RESPONSE)); if (!packetQuarEncResponse) - { - free(packet); - return FALSE; - } + goto fail; packet->tsgPacket.packetQuarEncResponse = packetQuarEncResponse; - /* PacketQuarResponsePtr (4 bytes) */ - packetQuarEncResponse->flags = *((UINT32*) &buffer[offset + 12]); /* Flags (4 bytes) */ - packetQuarEncResponse->certChainLen = *((UINT32*) &buffer[offset + - 16]); /* CertChainLength (4 bytes) */ - /* CertChainDataPtr (4 bytes) */ - CopyMemory(&packetQuarEncResponse->nonce, &buffer[offset + 24], 16); /* Nonce (16 bytes) */ - offset += 40; + + if (Stream_GetRemainingLength(pdu->s) < 32) + goto fail; + + Stream_Seek_UINT32(pdu->s); /* PacketQuarResponsePtr (4 bytes) */ + Stream_Read_UINT32(pdu->s, packetQuarEncResponse->flags); /* Flags (4 bytes) */ + Stream_Read_UINT32(pdu->s, packetQuarEncResponse->certChainLen); /* CertChainLength (4 bytes) */ + Stream_Seek_UINT32(pdu->s); /* CertChainDataPtr (4 bytes) */ + Stream_Read(pdu->s, &packetQuarEncResponse->nonce, 16); /* Nonce (16 bytes) */ if (packetQuarEncResponse->certChainLen > 0) { - Pointer = *((UINT32*) &buffer[offset]); /* Ptr (4 bytes): 0x0002000C */ - offset += 4; - offset += 4; /* MaxCount (4 bytes) */ - offset += 4; /* Offset (4 bytes) */ - count = *((UINT32*) &buffer[offset]); /* ActualCount (4 bytes) */ - offset += 4; + if (Stream_GetRemainingLength(pdu->s) < 16) + goto fail; + + Stream_Read_UINT32(pdu->s, Pointer); /* Ptr (4 bytes): 0x0002000C */ + Stream_Seek_UINT32(pdu->s); /* MaxCount (4 bytes) */ + Stream_Seek_UINT32(pdu->s); /* Offset (4 bytes) */ + Stream_Read_UINT32(pdu->s, count); /* ActualCount (4 bytes) */ + /* * CertChainData is a wide character string, and the count is * given in characters excluding the null terminator, therefore: * size = (count * 2) */ - offset += (count * 2); /* CertChainData */ + if (!Stream_SafeSeek(pdu->s, count * 2)) /* CertChainData */ + goto fail; + /* 4-byte alignment */ - rpc_offset_align(&offset, 4); + { + UINT32 offset = Stream_GetPosition(pdu->s); + + if (!Stream_SafeSeek(pdu->s, rpc_offset_align(&offset, 4))) + goto fail; + } } else { - Pointer = *((UINT32*) &buffer[offset]); /* Ptr (4 bytes): 0x00020008 */ - offset += 4; + if (Stream_GetRemainingLength(pdu->s) < 4) + goto fail; + + Stream_Read_UINT32(pdu->s, Pointer); /* Ptr (4 bytes): 0x00020008 */ } versionCaps = (PTSG_PACKET_VERSIONCAPS) calloc(1, sizeof(TSG_PACKET_VERSIONCAPS)); if (!versionCaps) - { - free(packetQuarEncResponse); - free(packet); - return FALSE; - } + goto fail; packetQuarEncResponse->versionCaps = versionCaps; - versionCaps->tsgHeader.ComponentId = *((UINT16*) &buffer[offset]); /* ComponentId (2 bytes) */ - versionCaps->tsgHeader.PacketId = *((UINT16*) &buffer[offset + 2]); /* PacketId (2 bytes) */ - offset += 4; + + if (Stream_GetRemainingLength(pdu->s) < 18) + goto fail; + + Stream_Read_UINT16(pdu->s, versionCaps->tsgHeader.ComponentId); /* ComponentId (2 bytes) */ + Stream_Read_UINT16(pdu->s, versionCaps->tsgHeader.PacketId); /* PacketId (2 bytes) */ if (versionCaps->tsgHeader.ComponentId != TS_GATEWAY_TRANSPORT) { WLog_ERR(TAG, "Unexpected ComponentId: 0x%04"PRIX16", Expected TS_GATEWAY_TRANSPORT", versionCaps->tsgHeader.ComponentId); - free(versionCaps); - free(packetQuarEncResponse); - free(packet); - return FALSE; + goto fail; } - Pointer = *((UINT32*) &buffer[offset]); /* TsgCapsPtr (4 bytes) */ - versionCaps->numCapabilities = *((UINT32*) &buffer[offset + 4]); /* NumCapabilities (4 bytes) */ - versionCaps->majorVersion = *((UINT16*) &buffer[offset + 8]); /* MajorVersion (2 bytes) */ - versionCaps->majorVersion = *((UINT16*) &buffer[offset + 10]); /* MinorVersion (2 bytes) */ - versionCaps->quarantineCapabilities = *((UINT16*) &buffer[offset + - 12]); /* QuarantineCapabilities (2 bytes) */ - offset += 14; + Stream_Read_UINT32(pdu->s, Pointer); /* TsgCapsPtr (4 bytes) */ + Stream_Read_UINT32(pdu->s, versionCaps->numCapabilities); /* NumCapabilities (4 bytes) */ + Stream_Read_UINT16(pdu->s, versionCaps->majorVersion); /* MajorVersion (2 bytes) */ + Stream_Read_UINT16(pdu->s, versionCaps->majorVersion); /* MinorVersion (2 bytes) */ + Stream_Read_UINT16(pdu->s, + versionCaps->quarantineCapabilities); /* QuarantineCapabilities (2 bytes) */ /* 4-byte alignment */ - rpc_offset_align(&offset, 4); + { + UINT32 offset = Stream_GetPosition(pdu->s); + + if (!Stream_SafeSeek(pdu->s, rpc_offset_align(&offset, 4))) + goto fail; + } + + if (Stream_GetRemainingLength(pdu->s) < 36) + goto fail; + /* Not sure exactly what this is */ - offset += 4; /* 0x00000001 (4 bytes) */ - offset += 4; /* 0x00000001 (4 bytes) */ - offset += 4; /* 0x00000001 (4 bytes) */ - offset += 4; /* 0x00000002 (4 bytes) */ + Stream_Seek_UINT32(pdu->s); /* 0x00000001 (4 bytes) */ + Stream_Seek_UINT32(pdu->s); /* 0x00000001 (4 bytes) */ + Stream_Seek_UINT32(pdu->s); /* 0x00000001 (4 bytes) */ + Stream_Seek_UINT32(pdu->s); /* 0x00000002 (4 bytes) */ /* TunnelContext (20 bytes) */ - CopyMemory(&tunnelContext->ContextType, &buffer[offset], 4); /* ContextType (4 bytes) */ - CopyMemory(&tunnelContext->ContextUuid, &buffer[offset + 4], 16); /* ContextUuid (16 bytes) */ - offset += 20; - free(versionCaps); - free(packetQuarEncResponse); + Stream_Read_UINT32(pdu->s, tunnelContext->ContextType); /* ContextType (4 bytes) */ + Stream_Read(pdu->s, tunnelContext->ContextUuid, 16); /* ContextUuid (16 bytes) */ } else { WLog_ERR(TAG, "Unexpected PacketId: 0x%08"PRIX32", Expected TSG_PACKET_TYPE_CAPS_RESPONSE " "or TSG_PACKET_TYPE_QUARENC_RESPONSE", packet->packetId); - free(packet); - return FALSE; + goto fail; } + rc = TRUE; +fail: + free(packetQuarEncResponse); + free(packetCapsResponse); + free(versionCaps); + free(tsgCaps); free(packet); - return TRUE; + return rc; } /** @@ -634,170 +828,145 @@ static BOOL TsProxyCreateTunnelReadResponse(rdpTsg* tsg, RPC_PDU* pdu, static BOOL TsProxyAuthorizeTunnelWriteRequest(rdpTsg* tsg, CONTEXT_HANDLE* tunnelContext) { UINT32 pad; - int status; - BYTE* buffer; - UINT32 count; - UINT32 length; + wStream* s; + size_t count; UINT32 offset; - rdpRpc* rpc = tsg->rpc; - WLog_DBG(TAG, "TsProxyAuthorizeTunnelWriteRequest"); - count = _wcslen(tsg->MachineName) + 1; - offset = 64 + (count * 2); - rpc_offset_align(&offset, 4); - offset += 4; - length = offset; - buffer = (BYTE*) malloc(length); + rdpRpc* rpc; - if (!buffer) + if (!tsg || !tsg->rpc || !tunnelContext || !tsg->MachineName) + return FALSE; + + count = _wcslen(tsg->MachineName) + 1; + rpc = tsg->rpc; + WLog_DBG(TAG, "TsProxyAuthorizeTunnelWriteRequest"); + s = Stream_New(NULL, 1024); + + if (!s) return FALSE; /* TunnelContext (20 bytes) */ - CopyMemory(&buffer[0], &tunnelContext->ContextType, 4); /* ContextType (4 bytes) */ - CopyMemory(&buffer[4], &tunnelContext->ContextUuid, 16); /* ContextUuid (16 bytes) */ + Stream_Write_UINT32(s, tunnelContext->ContextType); /* ContextType (4 bytes) */ + Stream_Write(s, &tunnelContext->ContextUuid, 16); /* ContextUuid (16 bytes) */ /* 4-byte alignment */ - *((UINT32*) &buffer[20]) = TSG_PACKET_TYPE_QUARREQUEST; /* PacketId (4 bytes) */ - *((UINT32*) &buffer[24]) = TSG_PACKET_TYPE_QUARREQUEST; /* SwitchValue (4 bytes) */ - *((UINT32*) &buffer[28]) = 0x00020000; /* PacketQuarRequestPtr (4 bytes) */ - *((UINT32*) &buffer[32]) = 0x00000000; /* Flags (4 bytes) */ - *((UINT32*) &buffer[36]) = 0x00020004; /* MachineNamePtr (4 bytes) */ - *((UINT32*) &buffer[40]) = count; /* NameLength (4 bytes) */ - *((UINT32*) &buffer[44]) = 0x00020008; /* DataPtr (4 bytes) */ - *((UINT32*) &buffer[48]) = 0; /* DataLength (4 bytes) */ + Stream_Write_UINT32(s, TSG_PACKET_TYPE_QUARREQUEST); /* PacketId (4 bytes) */ + Stream_Write_UINT32(s, TSG_PACKET_TYPE_QUARREQUEST); /* SwitchValue (4 bytes) */ + Stream_Write_UINT32(s, 0x00020000); /* PacketQuarRequestPtr (4 bytes) */ + Stream_Write_UINT32(s, 0x00000000); /* Flags (4 bytes) */ + Stream_Write_UINT32(s, 0x00020004); /* MachineNamePtr (4 bytes) */ + Stream_Write_UINT32(s, count); /* NameLength (4 bytes) */ + Stream_Write_UINT32(s, 0x00020008); /* DataPtr (4 bytes) */ + Stream_Write_UINT32(s, 0); /* DataLength (4 bytes) */ /* MachineName */ - *((UINT32*) &buffer[52]) = count; /* MaxCount (4 bytes) */ - *((UINT32*) &buffer[56]) = 0; /* Offset (4 bytes) */ - *((UINT32*) &buffer[60]) = count; /* ActualCount (4 bytes) */ - CopyMemory(&buffer[64], tsg->MachineName, count * 2); /* Array */ - offset = 64 + (count * 2); + Stream_Write_UINT32(s, count); /* MaxCount (4 bytes) */ + Stream_Write_UINT32(s, 0); /* Offset (4 bytes) */ + Stream_Write_UINT32(s, count); /* ActualCount (4 bytes) */ + Stream_Write_UTF16_String(s, tsg->MachineName, count); /* Array */ /* 4-byte alignment */ + offset = Stream_GetPosition(s); pad = rpc_offset_align(&offset, 4); - ZeroMemory(&buffer[offset - pad], pad); - *((UINT32*) &buffer[offset]) = 0x00000000; /* MaxCount (4 bytes) */ - offset += 4; - status = rpc_client_write_call(rpc, buffer, length, TsProxyAuthorizeTunnelOpnum); - free(buffer); - - if (status <= 0) - return FALSE; - - return TRUE; + Stream_Zero(s, pad); + Stream_Write_UINT32(s, 0x00000000); /* MaxCount (4 bytes) */ + Stream_SealLength(s); + return rpc_client_write_call(rpc, s, TsProxyAuthorizeTunnelOpnum); } static BOOL TsProxyAuthorizeTunnelReadResponse(rdpTsg* tsg, RPC_PDU* pdu) { - BYTE* buffer; - UINT32 length; - UINT32 offset; + BOOL rc = FALSE; UINT32 Pointer; UINT32 SizeValue; UINT32 SwitchValue; UINT32 idleTimeout; - PTSG_PACKET packet; - PTSG_PACKET_RESPONSE packetResponse; + PTSG_PACKET packet = NULL; + PTSG_PACKET_RESPONSE packetResponse = NULL; WLog_DBG(TAG, "TsProxyAuthorizeTunnelReadResponse"); if (!pdu) return FALSE; - length = Stream_Length(pdu->s); - buffer = Stream_Buffer(pdu->s); - - if (!(pdu->Flags & RPC_PDU_FLAG_STUB)) - buffer = &buffer[24]; - packet = (PTSG_PACKET) calloc(1, sizeof(TSG_PACKET)); if (!packet) return FALSE; - offset = 4; /* PacketPtr (4 bytes) */ - packet->packetId = *((UINT32*) &buffer[offset]); /* PacketId (4 bytes) */ - SwitchValue = *((UINT32*) &buffer[offset + 4]); /* SwitchValue (4 bytes) */ + if (Stream_GetRemainingLength(pdu->s) < 68) + goto fail; + + Stream_Seek_UINT32(pdu->s); /* PacketPtr (4 bytes) */ + Stream_Read_UINT32(pdu->s, packet->packetId); /* PacketId (4 bytes) */ + Stream_Read_UINT32(pdu->s, SwitchValue); /* SwitchValue (4 bytes) */ if (packet->packetId == E_PROXY_NAP_ACCESSDENIED) { WLog_ERR(TAG, "status: E_PROXY_NAP_ACCESSDENIED (0x%08X)", E_PROXY_NAP_ACCESSDENIED); WLog_ERR(TAG, "Ensure that the Gateway Connection Authorization Policy is correct"); - free(packet); - return FALSE; + goto fail; } if ((packet->packetId != TSG_PACKET_TYPE_RESPONSE) || (SwitchValue != TSG_PACKET_TYPE_RESPONSE)) { WLog_ERR(TAG, "Unexpected PacketId: 0x%08"PRIX32", Expected TSG_PACKET_TYPE_RESPONSE", packet->packetId); - free(packet); - return FALSE; + goto fail; } packetResponse = (PTSG_PACKET_RESPONSE) calloc(1, sizeof(TSG_PACKET_RESPONSE)); if (!packetResponse) - { - free(packet); - return FALSE; - } + goto fail; packet->tsgPacket.packetResponse = packetResponse; - Pointer = *((UINT32*) &buffer[offset + 8]); /* PacketResponsePtr (4 bytes) */ - packetResponse->flags = *((UINT32*) &buffer[offset + 12]); /* Flags (4 bytes) */ + Stream_Read_UINT32(pdu->s, Pointer); /* PacketResponsePtr (4 bytes) */ + Stream_Read_UINT32(pdu->s, packetResponse->flags); /* Flags (4 bytes) */ if (packetResponse->flags != TSG_PACKET_TYPE_QUARREQUEST) { WLog_ERR(TAG, "Unexpected Packet Response Flags: 0x%08"PRIX32", Expected TSG_PACKET_TYPE_QUARREQUEST", packetResponse->flags); - free(packet); - free(packetResponse); - return FALSE; + goto fail; } - /* Reserved (4 bytes) */ - Pointer = *((UINT32*) &buffer[offset + 20]); /* ResponseDataPtr (4 bytes) */ - packetResponse->responseDataLen = *((UINT32*) &buffer[offset + - 24]); /* ResponseDataLength (4 bytes) */ - packetResponse->redirectionFlags.enableAllRedirections = *((UINT32*) &buffer[offset + - 28]); /* EnableAllRedirections (4 bytes) */ - packetResponse->redirectionFlags.disableAllRedirections = *((UINT32*) &buffer[offset + - 32]); /* DisableAllRedirections (4 bytes) */ - packetResponse->redirectionFlags.driveRedirectionDisabled = *((UINT32*) &buffer[offset + - 36]); /* DriveRedirectionDisabled (4 bytes) */ - packetResponse->redirectionFlags.printerRedirectionDisabled = *((UINT32*) &buffer[offset + - 40]); /* PrinterRedirectionDisabled (4 bytes) */ - packetResponse->redirectionFlags.portRedirectionDisabled = *((UINT32*) &buffer[offset + - 44]); /* PortRedirectionDisabled (4 bytes) */ - packetResponse->redirectionFlags.reserved = *((UINT32*) &buffer[offset + - 48]); /* Reserved (4 bytes) */ - packetResponse->redirectionFlags.clipboardRedirectionDisabled = *((UINT32*) &buffer[offset + - 52]); /* ClipboardRedirectionDisabled (4 bytes) */ - packetResponse->redirectionFlags.pnpRedirectionDisabled = *((UINT32*) &buffer[offset + - 56]); /* PnpRedirectionDisabled (4 bytes) */ - offset += 60; - SizeValue = *((UINT32*) &buffer[offset]); /* (4 bytes) */ - offset += 4; + Stream_Seek_UINT32(pdu->s); /* Reserved (4 bytes) */ + Stream_Read_UINT32(pdu->s, Pointer); /* ResponseDataPtr (4 bytes) */ + Stream_Read_UINT32(pdu->s, packetResponse->responseDataLen); /* ResponseDataLength (4 bytes) */ + Stream_Read_UINT32(pdu->s, + packetResponse->redirectionFlags.enableAllRedirections); /* EnableAllRedirections (4 bytes) */ + Stream_Read_UINT32(pdu->s, + packetResponse->redirectionFlags.disableAllRedirections); /* DisableAllRedirections (4 bytes) */ + Stream_Read_UINT32(pdu->s, + packetResponse->redirectionFlags.driveRedirectionDisabled); /* DriveRedirectionDisabled (4 bytes) */ + Stream_Read_UINT32(pdu->s, + packetResponse->redirectionFlags.printerRedirectionDisabled); /* PrinterRedirectionDisabled (4 bytes) */ + Stream_Read_UINT32(pdu->s, + packetResponse->redirectionFlags.portRedirectionDisabled); /* PortRedirectionDisabled (4 bytes) */ + Stream_Read_UINT32(pdu->s, packetResponse->redirectionFlags.reserved); /* Reserved (4 bytes) */ + Stream_Read_UINT32(pdu->s, + packetResponse->redirectionFlags.clipboardRedirectionDisabled); /* ClipboardRedirectionDisabled (4 bytes) */ + Stream_Read_UINT32(pdu->s, + packetResponse->redirectionFlags.pnpRedirectionDisabled); /* PnpRedirectionDisabled (4 bytes) */ + Stream_Read_UINT32(pdu->s, SizeValue); /* (4 bytes) */ if (SizeValue != packetResponse->responseDataLen) { WLog_ERR(TAG, "Unexpected size value: %"PRIu32", expected: %"PRIu32"", SizeValue, packetResponse->responseDataLen); - free(packetResponse); - free(packet); - return FALSE; + goto fail; } + if (Stream_GetRemainingLength(pdu->s) < SizeValue) + goto fail; + if (SizeValue == 4) - { - idleTimeout = *((UINT32*) &buffer[offset]); - offset += 4; - } + Stream_Read_UINT32(pdu->s, idleTimeout); else - { - offset += SizeValue; /* ResponseData */ - } + Stream_Seek(pdu->s, SizeValue); /* ResponseData */ + rc = TRUE; +fail: free(packetResponse); free(packet); - return TRUE; + return rc; } /** @@ -814,46 +983,39 @@ static BOOL TsProxyAuthorizeTunnelReadResponse(rdpTsg* tsg, RPC_PDU* pdu) static BOOL TsProxyMakeTunnelCallWriteRequest(rdpTsg* tsg, CONTEXT_HANDLE* tunnelContext, UINT32 procId) { - int status; - BYTE* buffer; - UINT32 length; - rdpRpc* rpc = tsg->rpc; - WLog_DBG(TAG, "TsProxyMakeTunnelCallWriteRequest"); - length = 40; - buffer = (BYTE*) malloc(length); + wStream* s; + rdpRpc* rpc; - if (!buffer) + if (!tsg || !tsg->rpc || !tunnelContext) + return FALSE; + + rpc = tsg->rpc; + WLog_DBG(TAG, "TsProxyMakeTunnelCallWriteRequest"); + s = Stream_New(NULL, 40); + + if (!s) return FALSE; /* TunnelContext (20 bytes) */ - CopyMemory(&buffer[0], &tunnelContext->ContextType, 4); /* ContextType (4 bytes) */ - CopyMemory(&buffer[4], &tunnelContext->ContextUuid, 16); /* ContextUuid (16 bytes) */ - *((UINT32*) &buffer[20]) = procId; /* ProcId (4 bytes) */ + Stream_Write_UINT32(s, tunnelContext->ContextType); /* ContextType (4 bytes) */ + Stream_Write(s, tunnelContext->ContextUuid, 16); /* ContextUuid (16 bytes) */ + Stream_Write_UINT32(s, procId); /* ProcId (4 bytes) */ /* 4-byte alignment */ - *((UINT32*) &buffer[24]) = TSG_PACKET_TYPE_MSGREQUEST_PACKET; /* PacketId (4 bytes) */ - *((UINT32*) &buffer[28]) = TSG_PACKET_TYPE_MSGREQUEST_PACKET; /* SwitchValue (4 bytes) */ - *((UINT32*) &buffer[32]) = 0x00020000; /* PacketMsgRequestPtr (4 bytes) */ - *((UINT32*) &buffer[36]) = 0x00000001; /* MaxMessagesPerBatch (4 bytes) */ - status = rpc_client_write_call(rpc, buffer, length, TsProxyMakeTunnelCallOpnum); - free(buffer); - - if (status <= 0) - return FALSE; - - return TRUE; + Stream_Write_UINT32(s, TSG_PACKET_TYPE_MSGREQUEST_PACKET); /* PacketId (4 bytes) */ + Stream_Write_UINT32(s, TSG_PACKET_TYPE_MSGREQUEST_PACKET); /* SwitchValue (4 bytes) */ + Stream_Write_UINT32(s, 0x00020000); /* PacketMsgRequestPtr (4 bytes) */ + Stream_Write_UINT32(s, 0x00000001); /* MaxMessagesPerBatch (4 bytes) */ + return rpc_client_write_call(rpc, s, TsProxyMakeTunnelCallOpnum); } static BOOL TsProxyMakeTunnelCallReadResponse(rdpTsg* tsg, RPC_PDU* pdu) { - BYTE* buffer; - UINT32 length; - UINT32 offset; + BOOL rc = FALSE; UINT32 Pointer; UINT32 MaxCount; UINT32 ActualCount; UINT32 SwitchValue; PTSG_PACKET packet; - BOOL status = TRUE; char* messageText = NULL; PTSG_PACKET_MSG_RESPONSE packetMsgResponse = NULL; PTSG_PACKET_STRING_MESSAGE packetStringMessage = NULL; @@ -865,44 +1027,37 @@ static BOOL TsProxyMakeTunnelCallReadResponse(rdpTsg* tsg, RPC_PDU* pdu) if (!pdu) return FALSE; - length = Stream_Length(pdu->s); - buffer = Stream_Buffer(pdu->s); - - if (!(pdu->Flags & RPC_PDU_FLAG_STUB)) - buffer = &buffer[24]; - packet = (PTSG_PACKET) calloc(1, sizeof(TSG_PACKET)); if (!packet) return FALSE; - offset = 4; /* PacketPtr (4 bytes) */ - packet->packetId = *((UINT32*) &buffer[offset]); /* PacketId (4 bytes) */ - SwitchValue = *((UINT32*) &buffer[offset + 4]); /* SwitchValue (4 bytes) */ + if (Stream_GetRemainingLength(pdu->s) < 32) + goto fail; + + Stream_Seek_UINT32(pdu->s); /* PacketPtr (4 bytes) */ + Stream_Read_UINT32(pdu->s, packet->packetId); /* PacketId (4 bytes) */ + Stream_Read_UINT32(pdu->s, SwitchValue); /* SwitchValue (4 bytes) */ if ((packet->packetId != TSG_PACKET_TYPE_MESSAGE_PACKET) || (SwitchValue != TSG_PACKET_TYPE_MESSAGE_PACKET)) { WLog_ERR(TAG, "Unexpected PacketId: 0x%08"PRIX32", Expected TSG_PACKET_TYPE_MESSAGE_PACKET", packet->packetId); - free(packet); - return FALSE; + goto fail; } packetMsgResponse = (PTSG_PACKET_MSG_RESPONSE) calloc(1, sizeof(TSG_PACKET_MSG_RESPONSE)); if (!packetMsgResponse) - { - free(packet); - return FALSE; - } + goto fail; packet->tsgPacket.packetMsgResponse = packetMsgResponse; - Pointer = *((UINT32*) &buffer[offset + 8]); /* PacketMsgResponsePtr (4 bytes) */ - packetMsgResponse->msgID = *((UINT32*) &buffer[offset + 12]); /* MsgId (4 bytes) */ - packetMsgResponse->msgType = *((UINT32*) &buffer[offset + 16]); /* MsgType (4 bytes) */ - packetMsgResponse->isMsgPresent = *((INT32*) &buffer[offset + 20]); /* IsMsgPresent (4 bytes) */ - SwitchValue = *((UINT32*) &buffer[offset + 24]); /* SwitchValue (4 bytes) */ + Stream_Read_UINT32(pdu->s, Pointer); /* PacketMsgResponsePtr (4 bytes) */ + Stream_Read_UINT32(pdu->s, packetMsgResponse->msgID); /* MsgId (4 bytes) */ + Stream_Read_UINT32(pdu->s, packetMsgResponse->msgType); /* MsgType (4 bytes) */ + Stream_Read_INT32(pdu->s, packetMsgResponse->isMsgPresent); /* IsMsgPresent (4 bytes) */ + Stream_Read_UINT32(pdu->s, SwitchValue); /* SwitchValue (4 bytes) */ switch (SwitchValue) { @@ -910,24 +1065,30 @@ static BOOL TsProxyMakeTunnelCallReadResponse(rdpTsg* tsg, RPC_PDU* pdu) packetStringMessage = (PTSG_PACKET_STRING_MESSAGE) calloc(1, sizeof(TSG_PACKET_STRING_MESSAGE)); if (!packetStringMessage) - { - status = FALSE; - goto out; - } + goto fail; packetMsgResponse->messagePacket.consentMessage = packetStringMessage; - Pointer = *((UINT32*) &buffer[offset + 28]); /* ConsentMessagePtr (4 bytes) */ - packetStringMessage->isDisplayMandatory = *((INT32*) &buffer[offset + - 32]); /* IsDisplayMandatory (4 bytes) */ - packetStringMessage->isConsentMandatory = *((INT32*) &buffer[offset + - 36]); /* IsConsentMandatory (4 bytes) */ - packetStringMessage->msgBytes = *((UINT32*) &buffer[offset + 40]); /* MsgBytes (4 bytes) */ - Pointer = *((UINT32*) &buffer[offset + 44]); /* MsgPtr (4 bytes) */ - MaxCount = *((UINT32*) &buffer[offset + 48]); /* MaxCount (4 bytes) */ - /* Offset (4 bytes) */ - ActualCount = *((UINT32*) &buffer[offset + 56]); /* ActualCount (4 bytes) */ - ConvertFromUnicode(CP_UTF8, 0, (WCHAR*) &buffer[offset + 60], ActualCount, &messageText, 0, NULL, + + if (Stream_GetRemainingLength(pdu->s) < 32) + goto fail; + + Stream_Read_UINT32(pdu->s, Pointer); /* ConsentMessagePtr (4 bytes) */ + Stream_Read_INT32(pdu->s, + packetStringMessage->isDisplayMandatory); /* IsDisplayMandatory (4 bytes) */ + Stream_Read_INT32(pdu->s, + packetStringMessage->isConsentMandatory); /* IsConsentMandatory (4 bytes) */ + Stream_Read_UINT32(pdu->s, packetStringMessage->msgBytes); /* MsgBytes (4 bytes) */ + Stream_Read_UINT32(pdu->s, Pointer); /* MsgPtr (4 bytes) */ + Stream_Read_UINT32(pdu->s, MaxCount); /* MaxCount (4 bytes) */ + Stream_Seek_UINT32(pdu->s); /* Offset (4 bytes) */ + Stream_Read_UINT32(pdu->s, ActualCount); /* ActualCount (4 bytes) */ + + if (Stream_GetRemainingLength(pdu->s) < ActualCount * 2) + goto fail; + + ConvertFromUnicode(CP_UTF8, 0, (WCHAR*) Stream_Pointer(pdu->s), ActualCount, &messageText, 0, NULL, NULL); + Stream_Seek(pdu->s, ActualCount * 2); WLog_INFO(TAG, "Consent Message: %s", messageText); free(messageText); break; @@ -936,24 +1097,30 @@ static BOOL TsProxyMakeTunnelCallReadResponse(rdpTsg* tsg, RPC_PDU* pdu) packetStringMessage = (PTSG_PACKET_STRING_MESSAGE) calloc(1, sizeof(TSG_PACKET_STRING_MESSAGE)); if (!packetStringMessage) - { - status = FALSE; - goto out; - } + goto fail; packetMsgResponse->messagePacket.serviceMessage = packetStringMessage; - Pointer = *((UINT32*) &buffer[offset + 28]); /* ServiceMessagePtr (4 bytes) */ - packetStringMessage->isDisplayMandatory = *((INT32*) &buffer[offset + - 32]); /* IsDisplayMandatory (4 bytes) */ - packetStringMessage->isConsentMandatory = *((INT32*) &buffer[offset + - 36]); /* IsConsentMandatory (4 bytes) */ - packetStringMessage->msgBytes = *((UINT32*) &buffer[offset + 40]); /* MsgBytes (4 bytes) */ - Pointer = *((UINT32*) &buffer[offset + 44]); /* MsgPtr (4 bytes) */ - MaxCount = *((UINT32*) &buffer[offset + 48]); /* MaxCount (4 bytes) */ - /* Offset (4 bytes) */ - ActualCount = *((UINT32*) &buffer[offset + 56]); /* ActualCount (4 bytes) */ - ConvertFromUnicode(CP_UTF8, 0, (WCHAR*) &buffer[offset + 60], ActualCount, &messageText, 0, NULL, + + if (Stream_GetRemainingLength(pdu->s) < 32) + goto fail; + + Stream_Read_UINT32(pdu->s, Pointer); /* ServiceMessagePtr (4 bytes) */ + Stream_Read_INT32(pdu->s, + packetStringMessage->isDisplayMandatory); /* IsDisplayMandatory (4 bytes) */ + Stream_Read_INT32(pdu->s, + packetStringMessage->isConsentMandatory); /* IsConsentMandatory (4 bytes) */ + Stream_Read_UINT32(pdu->s, packetStringMessage->msgBytes); /* MsgBytes (4 bytes) */ + Stream_Read_UINT32(pdu->s, Pointer); /* MsgPtr (4 bytes) */ + Stream_Read_UINT32(pdu->s, MaxCount); /* MaxCount (4 bytes) */ + Stream_Seek_UINT32(pdu->s); /* Offset (4 bytes) */ + Stream_Read_UINT32(pdu->s, ActualCount); /* ActualCount (4 bytes) */ + + if (Stream_GetRemainingLength(pdu->s) < ActualCount * 2) + goto fail; + + ConvertFromUnicode(CP_UTF8, 0, (WCHAR*) Stream_Pointer(pdu->s), ActualCount, &messageText, 0, NULL, NULL); + Stream_Seek(pdu->s, ActualCount * 2); WLog_INFO(TAG, "Service Message: %s", messageText); free(messageText); break; @@ -962,40 +1129,32 @@ static BOOL TsProxyMakeTunnelCallReadResponse(rdpTsg* tsg, RPC_PDU* pdu) packetReauthMessage = (PTSG_PACKET_REAUTH_MESSAGE) calloc(1, sizeof(TSG_PACKET_REAUTH_MESSAGE)); if (!packetReauthMessage) - { - status = FALSE; - goto out; - } + goto fail; packetMsgResponse->messagePacket.reauthMessage = packetReauthMessage; - Pointer = *((UINT32*) &buffer[offset + 28]); /* ReauthMessagePtr (4 bytes) */ - /* alignment pad (4 bytes) */ - packetReauthMessage->tunnelContext = *((UINT64*) &buffer[offset + - 36]); /* TunnelContext (8 bytes) */ - /* ReturnValue (4 bytes) */ + + if (Stream_GetRemainingLength(pdu->s) < 20) + goto fail; + + Stream_Read_UINT32(pdu->s, Pointer); /* ReauthMessagePtr (4 bytes) */ + Stream_Seek_UINT32(pdu->s); /* alignment pad (4 bytes) */ + Stream_Read_UINT64(pdu->s, packetReauthMessage->tunnelContext); /* TunnelContext (8 bytes) */ + Stream_Seek_UINT32(pdu->s); /* ReturnValue (4 bytes) */ tsg->ReauthTunnelContext = packetReauthMessage->tunnelContext; break; default: WLog_ERR(TAG, "unexpected message type: %"PRIu32"", SwitchValue); - status = FALSE; - break; + goto fail; } -out: - - if (packet) - { - if (packet->tsgPacket.packetMsgResponse) - { - free(packet->tsgPacket.packetMsgResponse->messagePacket.reauthMessage); - free(packet->tsgPacket.packetMsgResponse); - } - - free(packet); - } - - return status; + rc = TRUE; +fail: + free(packetStringMessage); + free(packetReauthMessage); + free(packetMsgResponse); + free(packet); + return rc; } /** @@ -1011,70 +1170,63 @@ out: static BOOL TsProxyCreateChannelWriteRequest(rdpTsg* tsg, CONTEXT_HANDLE* tunnelContext) { - int status; - UINT32 count; - BYTE* buffer; - UINT32 length; - rdpRpc* rpc = tsg->rpc; - count = _wcslen(tsg->Hostname) + 1; + size_t count; + wStream* s; + rdpRpc* rpc; WLog_DBG(TAG, "TsProxyCreateChannelWriteRequest"); - length = 60 + (count * 2); - buffer = (BYTE*) malloc(length); - if (!buffer) + if (!tsg || !tsg->rpc || !tunnelContext || !tsg->Hostname) + return FALSE; + + rpc = tsg->rpc; + count = _wcslen(tsg->Hostname) + 1; + s = Stream_New(NULL, 60 + count * 2); + + if (!s) return FALSE; /* TunnelContext (20 bytes) */ - CopyMemory(&buffer[0], &tunnelContext->ContextType, 4); /* ContextType (4 bytes) */ - CopyMemory(&buffer[4], &tunnelContext->ContextUuid, 16); /* ContextUuid (16 bytes) */ + Stream_Write_UINT32(s, tunnelContext->ContextType); /* ContextType (4 bytes) */ + Stream_Write(s, tunnelContext->ContextUuid, 16); /* ContextUuid (16 bytes) */ /* TSENDPOINTINFO */ - *((UINT32*) &buffer[20]) = 0x00020000; /* ResourceNamePtr (4 bytes) */ - *((UINT32*) &buffer[24]) = 0x00000001; /* NumResourceNames (4 bytes) */ - *((UINT32*) &buffer[28]) = 0x00000000; /* AlternateResourceNamesPtr (4 bytes) */ - *((UINT16*) &buffer[32]) = 0x0000; /* NumAlternateResourceNames (2 bytes) */ - *((UINT16*) &buffer[34]) = 0x0000; /* Pad (2 bytes) */ + Stream_Write_UINT32(s, 0x00020000); /* ResourceNamePtr (4 bytes) */ + Stream_Write_UINT32(s, 0x00000001); /* NumResourceNames (4 bytes) */ + Stream_Write_UINT32(s, 0x00000000); /* AlternateResourceNamesPtr (4 bytes) */ + Stream_Write_UINT16(s, 0x0000); /* NumAlternateResourceNames (2 bytes) */ + Stream_Write_UINT16(s, 0x0000); /* Pad (2 bytes) */ /* Port (4 bytes) */ - *((UINT16*) &buffer[36]) = 0x0003; /* ProtocolId (RDP = 3) (2 bytes) */ - *((UINT16*) &buffer[38]) = tsg->Port; /* PortNumber (0xD3D = 3389) (2 bytes) */ - *((UINT32*) &buffer[40]) = 0x00000001; /* NumResourceNames (4 bytes) */ - *((UINT32*) &buffer[44]) = 0x00020004; /* ResourceNamePtr (4 bytes) */ - *((UINT32*) &buffer[48]) = count; /* MaxCount (4 bytes) */ - *((UINT32*) &buffer[52]) = 0; /* Offset (4 bytes) */ - *((UINT32*) &buffer[56]) = count; /* ActualCount (4 bytes) */ - CopyMemory(&buffer[60], tsg->Hostname, count * 2); /* Array */ - status = rpc_client_write_call(rpc, buffer, length, TsProxyCreateChannelOpnum); - free(buffer); - - if (status <= 0) - return FALSE; - - return TRUE; + Stream_Write_UINT16(s, 0x0003); /* ProtocolId (RDP = 3) (2 bytes) */ + Stream_Write_UINT16(s, tsg->Port); /* PortNumber (0xD3D = 3389) (2 bytes) */ + Stream_Write_UINT32(s, 0x00000001); /* NumResourceNames (4 bytes) */ + Stream_Write_UINT32(s, 0x00020004); /* ResourceNamePtr (4 bytes) */ + Stream_Write_UINT32(s, count); /* MaxCount (4 bytes) */ + Stream_Write_UINT32(s, 0); /* Offset (4 bytes) */ + Stream_Write_UINT32(s, count); /* ActualCount (4 bytes) */ + Stream_Write_UTF16_String(s, tsg->Hostname, count); /* Array */ + return rpc_client_write_call(rpc, s, TsProxyCreateChannelOpnum); } static BOOL TsProxyCreateChannelReadResponse(rdpTsg* tsg, RPC_PDU* pdu, CONTEXT_HANDLE* channelContext, UINT32* channelId) { - BYTE* buffer; - UINT32 offset; + BOOL rc = FALSE; WLog_DBG(TAG, "TsProxyCreateChannelReadResponse"); if (!pdu) return FALSE; - buffer = Stream_Buffer(pdu->s); + if (Stream_GetRemainingLength(pdu->s) < 28) + goto fail; - if (!(pdu->Flags & RPC_PDU_FLAG_STUB)) - buffer = &buffer[24]; - - offset = 0; /* ChannelContext (20 bytes) */ - CopyMemory(&channelContext->ContextType, &buffer[offset], 4); /* ContextType (4 bytes) */ - CopyMemory(&channelContext->ContextUuid, &buffer[offset + 4], 16); /* ContextUuid (16 bytes) */ - offset += 20; - *channelId = *((UINT32*) &buffer[offset]); /* ChannelId (4 bytes) */ - /* ReturnValue (4 bytes) */ - return TRUE; + Stream_Read_UINT32(pdu->s, channelContext->ContextType); /* ContextType (4 bytes) */ + Stream_Read(pdu->s, channelContext->ContextUuid, 16); /* ContextUuid (16 bytes) */ + Stream_Read_UINT32(pdu->s, *channelId); /* ChannelId (4 bytes) */ + Stream_Seek_UINT32(pdu->s); /* ReturnValue (4 bytes) */ + rc = TRUE; +fail: + return rc; } /** @@ -1085,51 +1237,43 @@ static BOOL TsProxyCreateChannelReadResponse(rdpTsg* tsg, RPC_PDU* pdu, static BOOL TsProxyCloseChannelWriteRequest(rdpTsg* tsg, CONTEXT_HANDLE* context) { - int status; - BYTE* buffer; - UINT32 length; - rdpRpc* rpc = tsg->rpc; + wStream* s; + rdpRpc* rpc; WLog_DBG(TAG, "TsProxyCloseChannelWriteRequest"); - if (!context) + if (!tsg || !tsg->rpc || !context) return FALSE; - length = 20; - buffer = (BYTE*) malloc(length); + rpc = tsg->rpc; + s = Stream_New(NULL, 20); - if (!buffer) + if (!s) return FALSE; /* ChannelContext (20 bytes) */ - CopyMemory(&buffer[0], &context->ContextType, 4); /* ContextType (4 bytes) */ - CopyMemory(&buffer[4], &context->ContextUuid, 16); /* ContextUuid (16 bytes) */ - status = rpc_client_write_call(rpc, buffer, length, TsProxyCloseChannelOpnum); - free(buffer); - - if (status <= 0) - return FALSE; - - return TRUE; + Stream_Write_UINT32(s, context->ContextType); /* ContextType (4 bytes) */ + Stream_Write(s, context->ContextUuid, 16); /* ContextUuid (16 bytes) */ + return rpc_client_write_call(rpc, s, TsProxyCloseChannelOpnum); } static BOOL TsProxyCloseChannelReadResponse(rdpTsg* tsg, RPC_PDU* pdu, CONTEXT_HANDLE* context) { - BYTE* buffer; + BOOL rc = FALSE; WLog_DBG(TAG, "TsProxyCloseChannelReadResponse"); if (!pdu) return FALSE; - buffer = Stream_Buffer(pdu->s); - - if (!(pdu->Flags & RPC_PDU_FLAG_STUB)) - buffer = &buffer[24]; + if (Stream_GetRemainingLength(pdu->s) < 24) + goto fail; /* ChannelContext (20 bytes) */ - CopyMemory(&context->ContextType, &buffer[0], 4); /* ContextType (4 bytes) */ - CopyMemory(&context->ContextUuid, &buffer[4], 16); /* ContextUuid (16 bytes) */ - /* ReturnValue (4 bytes) */ - return TRUE; + Stream_Read_UINT32(pdu->s, context->ContextType); /* ContextType (4 bytes) */ + Stream_Read(pdu->s, context->ContextUuid, 16); /* ContextUuid (16 bytes) */ + Stream_Seek_UINT32(pdu->s); /* ReturnValue (4 bytes) */ + rc = TRUE; +fail: + return rc; } /** @@ -1140,47 +1284,43 @@ static BOOL TsProxyCloseChannelReadResponse(rdpTsg* tsg, RPC_PDU* pdu, CONTEXT_H static BOOL TsProxyCloseTunnelWriteRequest(rdpTsg* tsg, CONTEXT_HANDLE* context) { - int status; - BYTE* buffer; - UINT32 length; - rdpRpc* rpc = tsg->rpc; + wStream* s; + rdpRpc* rpc; WLog_DBG(TAG, "TsProxyCloseTunnelWriteRequest"); - length = 20; - buffer = (BYTE*) malloc(length); - if (!buffer) + if (!tsg || !tsg->rpc || !context) + return FALSE; + + rpc = tsg->rpc; + s = Stream_New(NULL, 20); + + if (!s) return FALSE; /* TunnelContext (20 bytes) */ - CopyMemory(&buffer[0], &context->ContextType, 4); /* ContextType (4 bytes) */ - CopyMemory(&buffer[4], &context->ContextUuid, 16); /* ContextUuid (16 bytes) */ - status = rpc_client_write_call(rpc, buffer, length, TsProxyCloseTunnelOpnum); - free(buffer); - - if (status <= 0) - return FALSE; - - return TRUE; + Stream_Write_UINT32(s, context->ContextType); /* ContextType (4 bytes) */ + Stream_Write(s, context->ContextUuid, 16); /* ContextUuid (16 bytes) */ + return rpc_client_write_call(rpc, s, TsProxyCloseTunnelOpnum); } static BOOL TsProxyCloseTunnelReadResponse(rdpTsg* tsg, RPC_PDU* pdu, CONTEXT_HANDLE* context) { - BYTE* buffer; + BOOL rc = FALSE; WLog_DBG(TAG, "TsProxyCloseTunnelReadResponse"); - if (!pdu) + if (!pdu || !context) return FALSE; - buffer = Stream_Buffer(pdu->s); - - if (!(pdu->Flags & RPC_PDU_FLAG_STUB)) - buffer = &buffer[24]; + if (Stream_GetRemainingLength(pdu->s) < 24) + goto fail; /* TunnelContext (20 bytes) */ - CopyMemory(&context->ContextType, &buffer[0], 4); /* ContextType (4 bytes) */ - CopyMemory(&context->ContextUuid, &buffer[4], 16); /* ContextUuid (16 bytes) */ - /* ReturnValue (4 bytes) */ - return TRUE; + Stream_Read_UINT32(pdu->s, context->ContextType); /* ContextType (4 bytes) */ + Stream_Read(pdu->s, context->ContextUuid, 16); /* ContextUuid (16 bytes) */ + Stream_Seek_UINT32(pdu->s); /* ReturnValue (4 bytes) */ + rc = TRUE; +fail: + return rc; } /** @@ -1193,31 +1333,27 @@ static BOOL TsProxyCloseTunnelReadResponse(rdpTsg* tsg, RPC_PDU* pdu, CONTEXT_HA static BOOL TsProxySetupReceivePipeWriteRequest(rdpTsg* tsg, CONTEXT_HANDLE* channelContext) { - int status; - BYTE* buffer; - UINT32 length; - rdpRpc* rpc = tsg->rpc; + wStream* s; + rdpRpc* rpc; WLog_DBG(TAG, "TsProxySetupReceivePipeWriteRequest"); - length = 20; - buffer = (BYTE*) malloc(length); - if (!buffer) + if (!tsg || !tsg->rpc || !channelContext) + return FALSE; + + rpc = tsg->rpc; + s = Stream_New(NULL, 20); + + if (!s) return FALSE; /* ChannelContext (20 bytes) */ - CopyMemory(&buffer[0], &channelContext->ContextType, 4); /* ContextType (4 bytes) */ - CopyMemory(&buffer[4], &channelContext->ContextUuid, 16); /* ContextUuid (16 bytes) */ - status = rpc_client_write_call(rpc, buffer, length, TsProxySetupReceivePipeOpnum); - free(buffer); - - if (status <= 0) - return FALSE; - - return TRUE; + Stream_Write_UINT32(s, channelContext->ContextType); /* ContextType (4 bytes) */ + Stream_Write(s, channelContext->ContextUuid, 16); /* ContextUuid (16 bytes) */ + return rpc_client_write_call(rpc, s, TsProxySetupReceivePipeOpnum); } -static int tsg_transition_to_state(rdpTsg* tsg, TSG_STATE state) +static BOOL tsg_transition_to_state(rdpTsg* tsg, TSG_STATE state) { const char* str = "TSG_STATE_UNKNOWN"; @@ -1256,16 +1392,19 @@ static int tsg_transition_to_state(rdpTsg* tsg, TSG_STATE state) break; } - tsg->state = state; WLog_DBG(TAG, "%s", str); - return 1; + return tsg_set_state(tsg, state); } -int tsg_proxy_begin(rdpTsg* tsg) +BOOL tsg_proxy_begin(rdpTsg* tsg) { TSG_PACKET tsgPacket; PTSG_CAPABILITY_NAP tsgCapNap; PTSG_PACKET_VERSIONCAPS packetVersionCaps; + + if (!tsg) + return FALSE; + packetVersionCaps = &tsg->packetVersionCaps; packetVersionCaps->tsgCaps = &tsg->tsgCaps; tsgCapNap = &tsg->tsgCaps.tsgPacket.tsgCapNap; @@ -1295,22 +1434,29 @@ int tsg_proxy_begin(rdpTsg* tsg) if (!TsProxyCreateTunnelWriteRequest(tsg, &tsgPacket)) { WLog_ERR(TAG, "TsProxyCreateTunnel failure"); - tsg->state = TSG_STATE_FINAL; - return -1; + tsg_transition_to_state(tsg, TSG_STATE_FINAL); + return FALSE; } - tsg_transition_to_state(tsg, TSG_STATE_INITIAL); - return 1; + return tsg_transition_to_state(tsg, TSG_STATE_INITIAL); } -static int tsg_proxy_reauth(rdpTsg* tsg) +static BOOL tsg_proxy_reauth(rdpTsg* tsg) { TSG_PACKET tsgPacket; PTSG_PACKET_REAUTH packetReauth; PTSG_PACKET_VERSIONCAPS packetVersionCaps; + + if (!tsg) + return FALSE; + tsg->reauthSequence = TRUE; packetReauth = &tsg->packetReauth; packetVersionCaps = &tsg->packetVersionCaps; + + if (!packetReauth || !packetVersionCaps) + return FALSE; + tsgPacket.packetId = TSG_PACKET_TYPE_REAUTH; tsgPacket.tsgPacket.packetReauth = &tsg->packetReauth; packetReauth->tunnelContext = tsg->ReauthTunnelContext; @@ -1320,26 +1466,38 @@ static int tsg_proxy_reauth(rdpTsg* tsg) if (!TsProxyCreateTunnelWriteRequest(tsg, &tsgPacket)) { WLog_ERR(TAG, "TsProxyCreateTunnel failure"); - tsg->state = TSG_STATE_FINAL; - return -1; + tsg_transition_to_state(tsg, TSG_STATE_FINAL); + return FALSE; } if (!TsProxyMakeTunnelCallWriteRequest(tsg, &tsg->TunnelContext, TSG_TUNNEL_CALL_ASYNC_MSG_REQUEST)) { WLog_ERR(TAG, "TsProxyMakeTunnelCall failure"); - tsg->state = TSG_STATE_FINAL; - return -1; + tsg_transition_to_state(tsg, TSG_STATE_FINAL); + return FALSE; } - tsg_transition_to_state(tsg, TSG_STATE_INITIAL); - return 1; + return tsg_transition_to_state(tsg, TSG_STATE_INITIAL); } -int tsg_recv_pdu(rdpTsg* tsg, RPC_PDU* pdu) +BOOL tsg_recv_pdu(rdpTsg* tsg, RPC_PDU* pdu) { - int status = -1; + BOOL rc = FALSE; RpcClientCall* call; - rdpRpc* rpc = tsg->rpc; + rdpRpc* rpc; + + if (!tsg || !tsg->rpc || !pdu) + return FALSE; + + rpc = tsg->rpc; + Stream_SealLength(pdu->s); + Stream_SetPosition(pdu->s, 0); + + if (!(pdu->Flags & RPC_PDU_FLAG_STUB)) + { + if (!Stream_SafeSeek(pdu->s, 24)) + return FALSE; + } switch (tsg->state) { @@ -1351,18 +1509,19 @@ int tsg_recv_pdu(rdpTsg* tsg, RPC_PDU* pdu) if (!TsProxyCreateTunnelReadResponse(tsg, pdu, TunnelContext, &tsg->TunnelId)) { WLog_ERR(TAG, "TsProxyCreateTunnelReadResponse failure"); - return -1; + return FALSE; } - tsg_transition_to_state(tsg, TSG_STATE_CONNECTED); + if (!tsg_transition_to_state(tsg, TSG_STATE_CONNECTED)) + return FALSE; if (!TsProxyAuthorizeTunnelWriteRequest(tsg, TunnelContext)) { WLog_ERR(TAG, "TsProxyAuthorizeTunnel failure"); - return -1; + return FALSE; } - status = 1; + rc = TRUE; } break; @@ -1374,27 +1533,28 @@ int tsg_recv_pdu(rdpTsg* tsg, RPC_PDU* pdu) if (!TsProxyAuthorizeTunnelReadResponse(tsg, pdu)) { WLog_ERR(TAG, "TsProxyAuthorizeTunnelReadResponse failure"); - return -1; + return FALSE; } - tsg_transition_to_state(tsg, TSG_STATE_AUTHORIZED); + if (!tsg_transition_to_state(tsg, TSG_STATE_AUTHORIZED)) + return FALSE; if (!tsg->reauthSequence) { if (!TsProxyMakeTunnelCallWriteRequest(tsg, TunnelContext, TSG_TUNNEL_CALL_ASYNC_MSG_REQUEST)) { WLog_ERR(TAG, "TsProxyMakeTunnelCall failure"); - return -1; + return FALSE; } } if (!TsProxyCreateChannelWriteRequest(tsg, TunnelContext)) { WLog_ERR(TAG, "TsProxyCreateChannel failure"); - return -1; + return FALSE; } - status = 1; + rc = TRUE; } break; @@ -1402,17 +1562,17 @@ int tsg_recv_pdu(rdpTsg* tsg, RPC_PDU* pdu) call = rpc_client_call_find_by_id(rpc->client, pdu->CallId); if (!call) - return -1; + return FALSE; if (call->OpNum == TsProxyMakeTunnelCallOpnum) { if (!TsProxyMakeTunnelCallReadResponse(tsg, pdu)) { WLog_ERR(TAG, "TsProxyMakeTunnelCallReadResponse failure"); - return -1; + return FALSE; } - status = 1; + rc = TRUE; } else if (call->OpNum == TsProxyCreateChannelOpnum) { @@ -1421,7 +1581,7 @@ int tsg_recv_pdu(rdpTsg* tsg, RPC_PDU* pdu) if (!TsProxyCreateChannelReadResponse(tsg, pdu, &ChannelContext, &tsg->ChannelId)) { WLog_ERR(TAG, "TsProxyCreateChannelReadResponse failure"); - return -1; + return FALSE; } if (!tsg->reauthSequence) @@ -1429,14 +1589,15 @@ int tsg_recv_pdu(rdpTsg* tsg, RPC_PDU* pdu) else CopyMemory(&tsg->NewChannelContext, &ChannelContext, sizeof(CONTEXT_HANDLE)); - tsg_transition_to_state(tsg, TSG_STATE_CHANNEL_CREATED); + if (!tsg_transition_to_state(tsg, TSG_STATE_CHANNEL_CREATED)) + return FALSE; if (!tsg->reauthSequence) { if (!TsProxySetupReceivePipeWriteRequest(tsg, &tsg->ChannelContext)) { WLog_ERR(TAG, "TsProxySetupReceivePipe failure"); - return -1; + return FALSE; } } else @@ -1444,19 +1605,18 @@ int tsg_recv_pdu(rdpTsg* tsg, RPC_PDU* pdu) if (!TsProxyCloseChannelWriteRequest(tsg, &tsg->NewChannelContext)) { WLog_ERR(TAG, "TsProxyCloseChannelWriteRequest failure"); - return -1; + return FALSE; } if (!TsProxyCloseTunnelWriteRequest(tsg, &tsg->NewTunnelContext)) { WLog_ERR(TAG, "TsProxyCloseTunnelWriteRequest failure"); - return -1; + return FALSE; } } - tsg_transition_to_state(tsg, TSG_STATE_PIPE_CREATED); + rc = tsg_transition_to_state(tsg, TSG_STATE_PIPE_CREATED); tsg->reauthSequence = FALSE; - status = 1; } else { @@ -1472,20 +1632,20 @@ int tsg_recv_pdu(rdpTsg* tsg, RPC_PDU* pdu) call = rpc_client_call_find_by_id(rpc->client, pdu->CallId); if (!call) - return -1; + return FALSE; if (call->OpNum == TsProxyMakeTunnelCallOpnum) { if (!TsProxyMakeTunnelCallReadResponse(tsg, pdu)) { WLog_ERR(TAG, "TsProxyMakeTunnelCallReadResponse failure"); - return -1; + return FALSE; } - if (tsg->ReauthTunnelContext) - tsg_proxy_reauth(tsg); + rc = TRUE; - status = 1; + if (tsg->ReauthTunnelContext) + rc = tsg_proxy_reauth(tsg); } else if (call->OpNum == TsProxyCloseChannelOpnum) { @@ -1494,10 +1654,10 @@ int tsg_recv_pdu(rdpTsg* tsg, RPC_PDU* pdu) if (!TsProxyCloseChannelReadResponse(tsg, pdu, &ChannelContext)) { WLog_ERR(TAG, "TsProxyCloseChannelReadResponse failure"); - return -1; + return FALSE; } - status = 1; + rc = TRUE; } else if (call->OpNum == TsProxyCloseTunnelOpnum) { @@ -1506,10 +1666,10 @@ int tsg_recv_pdu(rdpTsg* tsg, RPC_PDU* pdu) if (!TsProxyCloseTunnelReadResponse(tsg, pdu, &TunnelContext)) { WLog_ERR(TAG, "TsProxyCloseTunnelReadResponse failure"); - return -1; + return FALSE; } - status = 1; + rc = TRUE; } break; @@ -1524,7 +1684,8 @@ int tsg_recv_pdu(rdpTsg* tsg, RPC_PDU* pdu) return FALSE; } - tsg_transition_to_state(tsg, TSG_STATE_CHANNEL_CLOSE_PENDING); + if (!tsg_transition_to_state(tsg, TSG_STATE_CHANNEL_CLOSE_PENDING)) + return FALSE; if (!TsProxyCloseChannelWriteRequest(tsg, NULL)) { @@ -1539,7 +1700,7 @@ int tsg_recv_pdu(rdpTsg* tsg, RPC_PDU* pdu) return FALSE; } - status = 1; + rc = TRUE; } break; @@ -1553,8 +1714,7 @@ int tsg_recv_pdu(rdpTsg* tsg, RPC_PDU* pdu) return FALSE; } - tsg_transition_to_state(tsg, TSG_STATE_FINAL); - status = 1; + rc = tsg_transition_to_state(tsg, TSG_STATE_FINAL); } break; @@ -1562,23 +1722,18 @@ int tsg_recv_pdu(rdpTsg* tsg, RPC_PDU* pdu) break; } - return status; + return rc; } -int tsg_check_event_handles(rdpTsg* tsg) +BOOL tsg_check_event_handles(rdpTsg* tsg) { - int status; - status = rpc_client_in_channel_recv(tsg->rpc); + if (!rpc_client_in_channel_recv(tsg->rpc)) + return FALSE; - if (status < 0) - return -1; + if (!rpc_client_out_channel_recv(tsg->rpc)) + return FALSE; - status = rpc_client_out_channel_recv(tsg->rpc); - - if (status < 0) - return -1; - - return status; + return TRUE; } DWORD tsg_get_event_handles(rdpTsg* tsg, HANDLE* events, DWORD count) @@ -1671,8 +1826,11 @@ BOOL tsg_connect(rdpTsg* tsg, const char* hostname, UINT16 port, int timeout) if (!settings->GatewayPort) settings->GatewayPort = 443; - tsg_set_hostname(tsg, hostname); - tsg_set_machine_name(tsg, settings->ComputerName); + if (!tsg_set_hostname(tsg, hostname)) + return FALSE; + + if (!tsg_set_machine_name(tsg, settings->ComputerName)) + return FALSE; if (!rpc_connect(rpc, timeout)) { @@ -1689,7 +1847,7 @@ BOOL tsg_connect(rdpTsg* tsg, const char* hostname, UINT16 port, int timeout) { WaitForMultipleObjects(nCount, events, FALSE, 250); - if (tsg_check_event_handles(tsg) < 0) + if (!tsg_check_event_handles(tsg)) { WLog_ERR(TAG, "tsg_check failure"); transport->layer = TRANSPORT_LAYER_CLOSED; @@ -1736,7 +1894,7 @@ BOOL tsg_disconnect(rdpTsg* tsg) if (!TsProxyCloseChannelWriteRequest(tsg, &tsg->ChannelContext)) return FALSE; - tsg->state = TSG_STATE_CHANNEL_CLOSE_PENDING; + return tsg_transition_to_state(tsg, TSG_STATE_CHANNEL_CLOSE_PENDING); } return TRUE; @@ -1756,7 +1914,7 @@ static int tsg_read(rdpTsg* tsg, BYTE* data, UINT32 length) rdpRpc* rpc; int status = 0; - if (!tsg) + if (!tsg || !data) return -1; rpc = tsg->rpc; @@ -1790,7 +1948,7 @@ static int tsg_read(rdpTsg* tsg, BYTE* data, UINT32 length) { while (WaitForSingleObject(rpc->client->PipeEvent, 0) != WAIT_OBJECT_0) { - if (tsg_check_event_handles(tsg) < 0) + if (!tsg_check_event_handles(tsg)) return -1; WaitForSingleObject(rpc->client->PipeEvent, 100); @@ -1802,10 +1960,13 @@ static int tsg_read(rdpTsg* tsg, BYTE* data, UINT32 length) return status; } -static int tsg_write(rdpTsg* tsg, BYTE* data, UINT32 length) +static int tsg_write(rdpTsg* tsg, const BYTE* data, UINT32 length) { int status; + if (!tsg || !data || !tsg->rpc || !tsg->rpc->transport) + return -1; + if (tsg->rpc->transport->layer == TRANSPORT_LAYER_CLOSED) { WLog_ERR(TAG, "error, connection lost"); @@ -1845,19 +2006,13 @@ void tsg_free(rdpTsg* tsg) { if (tsg) { - if (tsg->rpc) - { - rpc_free(tsg->rpc); - tsg->rpc = NULL; - } - + rpc_free(tsg->rpc); free(tsg->Hostname); free(tsg->MachineName); free(tsg); } } - static int transport_bio_tsg_write(BIO* bio, const char* buf, int num) { int status; @@ -1887,6 +2042,13 @@ static int transport_bio_tsg_read(BIO* bio, char* buf, int size) { int status; rdpTsg* tsg = (rdpTsg*) BIO_get_data(bio); + + if (!tsg || (size < 0)) + { + BIO_clear_flags(bio, BIO_FLAGS_SHOULD_RETRY); + return -1; + } + BIO_clear_flags(bio, BIO_FLAGS_READ); status = tsg_read(tsg, (BYTE*) buf, size); @@ -1926,57 +2088,71 @@ static long transport_bio_tsg_ctrl(BIO* bio, int cmd, long arg1, void* arg2) RpcInChannel* inChannel = connection->DefaultInChannel; RpcOutChannel* outChannel = connection->DefaultOutChannel; - if (cmd == BIO_CTRL_FLUSH) + switch (cmd) { - (void)BIO_flush(inChannel->common.tls->bio); - (void)BIO_flush(outChannel->common.tls->bio); - status = 1; - } - else if (cmd == BIO_C_GET_EVENT) - { - if (arg2) - { - *((HANDLE*) arg2) = tsg->rpc->client->PipeEvent; + case BIO_CTRL_FLUSH: + (void)BIO_flush(inChannel->common.tls->bio); + (void)BIO_flush(outChannel->common.tls->bio); status = 1; - } - } - else if (cmd == BIO_C_SET_NONBLOCK) - { - status = 1; - } - else if (cmd == BIO_C_READ_BLOCKED) - { - BIO* bio = outChannel->common.bio; - status = BIO_read_blocked(bio); - } - else if (cmd == BIO_C_WRITE_BLOCKED) - { - BIO* bio = inChannel->common.bio; - status = BIO_write_blocked(bio); - } - else if (cmd == BIO_C_WAIT_READ) - { - int timeout = (int) arg1; - BIO* bio = outChannel->common.bio; + break; - if (BIO_read_blocked(bio)) - return BIO_wait_read(bio, timeout); - else if (BIO_write_blocked(bio)) - return BIO_wait_write(bio, timeout); - else - status = 1; - } - else if (cmd == BIO_C_WAIT_WRITE) - { - int timeout = (int) arg1; - BIO* bio = inChannel->common.bio; + case BIO_C_GET_EVENT: + if (arg2) + { + *((HANDLE*) arg2) = tsg->rpc->client->PipeEvent; + status = 1; + } - if (BIO_write_blocked(bio)) - status = BIO_wait_write(bio, timeout); - else if (BIO_read_blocked(bio)) - status = BIO_wait_read(bio, timeout); - else + break; + + case BIO_C_SET_NONBLOCK: status = 1; + break; + + case BIO_C_READ_BLOCKED: + { + BIO* bio = outChannel->common.bio; + status = BIO_read_blocked(bio); + } + break; + + case BIO_C_WRITE_BLOCKED: + { + BIO* bio = inChannel->common.bio; + status = BIO_write_blocked(bio); + } + break; + + case BIO_C_WAIT_READ: + { + int timeout = (int) arg1; + BIO* bio = outChannel->common.bio; + + if (BIO_read_blocked(bio)) + return BIO_wait_read(bio, timeout); + else if (BIO_write_blocked(bio)) + return BIO_wait_write(bio, timeout); + else + status = 1; + } + break; + + case BIO_C_WAIT_WRITE: + { + int timeout = (int) arg1; + BIO* bio = inChannel->common.bio; + + if (BIO_write_blocked(bio)) + status = BIO_wait_write(bio, timeout); + else if (BIO_read_blocked(bio)) + status = BIO_wait_read(bio, timeout); + else + status = 1; + } + break; + + default: + break; } return status; @@ -2014,3 +2190,28 @@ BIO_METHOD* BIO_s_tsg(void) return bio_methods; } + +TSG_STATE tsg_get_state(rdpTsg* tsg) +{ + if (!tsg) + return TSG_STATE_INITIAL; + + return tsg->state; +} + +BIO* tsg_get_bio(rdpTsg* tsg) +{ + if (!tsg) + return NULL; + + return tsg->bio; +} + +BOOL tsg_set_state(rdpTsg* tsg, TSG_STATE state) +{ + if (!tsg) + return FALSE; + + tsg->state = state; + return TRUE; +} diff --git a/libfreerdp/core/gateway/tsg.h b/libfreerdp/core/gateway/tsg.h index ac0fb281b..cef0df98a 100644 --- a/libfreerdp/core/gateway/tsg.h +++ b/libfreerdp/core/gateway/tsg.h @@ -31,16 +31,10 @@ typedef struct rdp_tsg rdpTsg; #include #include #include -#include -#include -#include #include -#include #include -#include - enum _TSG_STATE { TSG_STATE_INITIAL, @@ -54,8 +48,6 @@ enum _TSG_STATE }; typedef enum _TSG_STATE TSG_STATE; -typedef WCHAR* RESOURCENAME; - #define TsProxyCreateTunnelOpnum 1 #define TsProxyAuthorizeTunnelOpnum 2 #define TsProxyMakeTunnelCallOpnum 3 @@ -68,15 +60,6 @@ typedef WCHAR* RESOURCENAME; #define MAX_RESOURCE_NAMES 50 -typedef struct _tsendpointinfo -{ - RESOURCENAME* resourceName; - UINT32 numResourceNames; - RESOURCENAME* alternateResourceNames; - UINT16 numAlternateResourceNames; - UINT32 Port; -} TSENDPOINTINFO, *PTSENDPOINTINFO; - #define TS_GATEWAY_TRANSPORT 0x5452 #define TSG_PACKET_TYPE_HEADER 0x00004844 @@ -130,198 +113,24 @@ typedef struct _tsendpointinfo #define E_PROXY_REAUTH_NAP_FAILED 0x00005A00 #define E_PROXY_CONNECTIONABORTED 0x000004D4 -typedef struct _TSG_PACKET_HEADER -{ - UINT16 ComponentId; - UINT16 PacketId; -} TSG_PACKET_HEADER, *PTSG_PACKET_HEADER; +FREERDP_LOCAL rdpTsg* tsg_new(rdpTransport* transport); +FREERDP_LOCAL void tsg_free(rdpTsg* tsg); -typedef struct _TSG_CAPABILITY_NAP -{ - UINT32 capabilities; -} TSG_CAPABILITY_NAP, *PTSG_CAPABILITY_NAP; - -typedef union -{ - TSG_CAPABILITY_NAP tsgCapNap; -} TSG_CAPABILITIES_UNION, *PTSG_CAPABILITIES_UNION; - -typedef struct _TSG_PACKET_CAPABILITIES -{ - UINT32 capabilityType; - TSG_CAPABILITIES_UNION tsgPacket; -} TSG_PACKET_CAPABILITIES, *PTSG_PACKET_CAPABILITIES; - -typedef struct _TSG_PACKET_VERSIONCAPS -{ - TSG_PACKET_HEADER tsgHeader; - PTSG_PACKET_CAPABILITIES tsgCaps; - UINT32 numCapabilities; - UINT16 majorVersion; - UINT16 minorVersion; - UINT16 quarantineCapabilities; -} TSG_PACKET_VERSIONCAPS, *PTSG_PACKET_VERSIONCAPS; - -typedef struct _TSG_PACKET_QUARCONFIGREQUEST -{ - UINT32 flags; -} TSG_PACKET_QUARCONFIGREQUEST, *PTSG_PACKET_QUARCONFIGREQUEST; - -typedef struct _TSG_PACKET_QUARREQUEST -{ - UINT32 flags; - WCHAR* machineName; - UINT32 nameLength; - BYTE* data; - UINT32 dataLen; -} TSG_PACKET_QUARREQUEST, *PTSG_PACKET_QUARREQUEST; - -typedef struct _TSG_REDIRECTION_FLAGS -{ - BOOL enableAllRedirections; - BOOL disableAllRedirections; - BOOL driveRedirectionDisabled; - BOOL printerRedirectionDisabled; - BOOL portRedirectionDisabled; - BOOL reserved; - BOOL clipboardRedirectionDisabled; - BOOL pnpRedirectionDisabled; -} TSG_REDIRECTION_FLAGS, *PTSG_REDIRECTION_FLAGS; - -typedef struct _TSG_PACKET_RESPONSE -{ - UINT32 flags; - UINT32 reserved; - BYTE* responseData; - UINT32 responseDataLen; - TSG_REDIRECTION_FLAGS redirectionFlags; -} TSG_PACKET_RESPONSE, *PTSG_PACKET_RESPONSE; - -typedef struct _TSG_PACKET_QUARENC_RESPONSE -{ - UINT32 flags; - UINT32 certChainLen; - WCHAR* certChainData; - GUID nonce; - PTSG_PACKET_VERSIONCAPS versionCaps; -} TSG_PACKET_QUARENC_RESPONSE, *PTSG_PACKET_QUARENC_RESPONSE; - -typedef struct TSG_PACKET_STRING_MESSAGE -{ - INT32 isDisplayMandatory; - INT32 isConsentMandatory; - UINT32 msgBytes; - WCHAR* msgBuffer; -} TSG_PACKET_STRING_MESSAGE, *PTSG_PACKET_STRING_MESSAGE; - -typedef struct TSG_PACKET_REAUTH_MESSAGE -{ - UINT64 tunnelContext; -} TSG_PACKET_REAUTH_MESSAGE, *PTSG_PACKET_REAUTH_MESSAGE; - -typedef union -{ - PTSG_PACKET_STRING_MESSAGE consentMessage; - PTSG_PACKET_STRING_MESSAGE serviceMessage; - PTSG_PACKET_REAUTH_MESSAGE reauthMessage; -} TSG_PACKET_TYPE_MESSAGE_UNION, *PTSG_PACKET_TYPE_MESSAGE_UNION; - -typedef struct _TSG_PACKET_MSG_RESPONSE -{ - UINT32 msgID; - UINT32 msgType; - INT32 isMsgPresent; - TSG_PACKET_TYPE_MESSAGE_UNION messagePacket; -} TSG_PACKET_MSG_RESPONSE, *PTSG_PACKET_MSG_RESPONSE; - -typedef struct TSG_PACKET_CAPS_RESPONSE -{ - TSG_PACKET_QUARENC_RESPONSE pktQuarEncResponse; - TSG_PACKET_MSG_RESPONSE pktConsentMessage; -} TSG_PACKET_CAPS_RESPONSE, *PTSG_PACKET_CAPS_RESPONSE; - -typedef struct TSG_PACKET_MSG_REQUEST -{ - UINT32 maxMessagesPerBatch; -} TSG_PACKET_MSG_REQUEST, *PTSG_PACKET_MSG_REQUEST; - -typedef struct _TSG_PACKET_AUTH -{ - TSG_PACKET_VERSIONCAPS tsgVersionCaps; - UINT32 cookieLen; - BYTE* cookie; -} TSG_PACKET_AUTH, *PTSG_PACKET_AUTH; - -typedef union -{ - PTSG_PACKET_VERSIONCAPS packetVersionCaps; - PTSG_PACKET_AUTH packetAuth; -} TSG_INITIAL_PACKET_TYPE_UNION, *PTSG_INITIAL_PACKET_TYPE_UNION; - -typedef struct TSG_PACKET_REAUTH -{ - UINT64 tunnelContext; - UINT32 packetId; - TSG_INITIAL_PACKET_TYPE_UNION tsgInitialPacket; -} TSG_PACKET_REAUTH, *PTSG_PACKET_REAUTH; - -typedef union -{ - PTSG_PACKET_HEADER packetHeader; - PTSG_PACKET_VERSIONCAPS packetVersionCaps; - PTSG_PACKET_QUARCONFIGREQUEST packetQuarConfigRequest; - PTSG_PACKET_QUARREQUEST packetQuarRequest; - PTSG_PACKET_RESPONSE packetResponse; - PTSG_PACKET_QUARENC_RESPONSE packetQuarEncResponse; - PTSG_PACKET_CAPS_RESPONSE packetCapsResponse; - PTSG_PACKET_MSG_REQUEST packetMsgRequest; - PTSG_PACKET_MSG_RESPONSE packetMsgResponse; - PTSG_PACKET_AUTH packetAuth; - PTSG_PACKET_REAUTH packetReauth; -} TSG_PACKET_TYPE_UNION; - -typedef struct _TSG_PACKET -{ - UINT32 packetId; - TSG_PACKET_TYPE_UNION tsgPacket; -} TSG_PACKET, *PTSG_PACKET; - -struct rdp_tsg -{ - BIO* bio; - rdpRpc* rpc; - UINT16 Port; - LPWSTR Hostname; - LPWSTR MachineName; - TSG_STATE state; - UINT32 TunnelId; - UINT32 ChannelId; - BOOL reauthSequence; - rdpSettings* settings; - rdpTransport* transport; - UINT64 ReauthTunnelContext; - CONTEXT_HANDLE TunnelContext; - CONTEXT_HANDLE ChannelContext; - CONTEXT_HANDLE NewTunnelContext; - CONTEXT_HANDLE NewChannelContext; - TSG_PACKET_REAUTH packetReauth; - TSG_PACKET_CAPABILITIES tsgCaps; - TSG_PACKET_VERSIONCAPS packetVersionCaps; -}; - -FREERDP_LOCAL int tsg_proxy_begin(rdpTsg* tsg); +FREERDP_LOCAL BOOL tsg_proxy_begin(rdpTsg* tsg); FREERDP_LOCAL BOOL tsg_connect(rdpTsg* tsg, const char* hostname, UINT16 port, int timeout); FREERDP_LOCAL BOOL tsg_disconnect(rdpTsg* tsg); -FREERDP_LOCAL int tsg_recv_pdu(rdpTsg* tsg, RPC_PDU* pdu); +FREERDP_LOCAL BOOL tsg_recv_pdu(rdpTsg* tsg, RPC_PDU* pdu); -FREERDP_LOCAL int tsg_check_event_handles(rdpTsg* tsg); +FREERDP_LOCAL BOOL tsg_check_event_handles(rdpTsg* tsg); FREERDP_LOCAL DWORD tsg_get_event_handles(rdpTsg* tsg, HANDLE* events, DWORD count); -FREERDP_LOCAL rdpTsg* tsg_new(rdpTransport* transport); -FREERDP_LOCAL void tsg_free(rdpTsg* tsg); +FREERDP_LOCAL TSG_STATE tsg_get_state(rdpTsg* tsg); +FREERDP_LOCAL BOOL tsg_set_state(rdpTsg* tsg, TSG_STATE state); + +FREERDP_LOCAL BIO* tsg_get_bio(rdpTsg* tsg); #endif /* FREERDP_LIB_CORE_GATEWAY_TSG_H */ diff --git a/libfreerdp/core/rdp.c b/libfreerdp/core/rdp.c index d9c8bcea7..0c8ea678d 100644 --- a/libfreerdp/core/rdp.c +++ b/libfreerdp/core/rdp.c @@ -1513,16 +1513,15 @@ int rdp_check_fds(rdpRdp* rdp) if (transport->tsg) { rdpTsg* tsg = transport->tsg; - status = tsg_check_event_handles(tsg); - if (status < 0) + if (!tsg_check_event_handles(tsg)) { - WLog_ERR(TAG, "rdp_check_fds: tsg_check_event_handles() - %i", status); + WLog_ERR(TAG, "rdp_check_fds: tsg_check_event_handles()"); return -1; } - if (tsg->state != TSG_STATE_PIPE_CREATED) - return status; + if (tsg_get_state(tsg) != TSG_STATE_PIPE_CREATED) + return 1; } status = transport_check_fds(transport); diff --git a/libfreerdp/core/transport.c b/libfreerdp/core/transport.c index dd3c3e01c..bdcf77a66 100644 --- a/libfreerdp/core/transport.c +++ b/libfreerdp/core/transport.c @@ -393,7 +393,7 @@ BOOL transport_connect(rdpTransport* transport, const char* hostname, if (status) { - transport->frontBio = transport->tsg->bio; + transport->frontBio = tsg_get_bio(transport->tsg); transport->layer = TRANSPORT_LAYER_TSG; status = TRUE; } From 99eb9f7ec9575e678a4c021e766614ff20399367 Mon Sep 17 00:00:00 2001 From: Armin Novak Date: Fri, 28 Sep 2018 12:08:27 +0200 Subject: [PATCH 08/13] Refactored and simplified RPC signature functions. --- libfreerdp/core/gateway/rpc_client.c | 4 +- libfreerdp/core/gateway/rts.c | 45 +++++--- libfreerdp/core/gateway/rts.h | 3 +- libfreerdp/core/gateway/rts_signature.c | 143 +++++++++++++----------- libfreerdp/core/gateway/rts_signature.h | 117 ++++++++++--------- 5 files changed, 166 insertions(+), 146 deletions(-) diff --git a/libfreerdp/core/gateway/rpc_client.c b/libfreerdp/core/gateway/rpc_client.c index af814e093..c8dc8f21b 100644 --- a/libfreerdp/core/gateway/rpc_client.c +++ b/libfreerdp/core/gateway/rpc_client.c @@ -189,7 +189,7 @@ static int rpc_client_recv_pdu(rdpRpc* rpc, RPC_PDU* pdu) case VIRTUAL_CONNECTION_STATE_WAIT_A3W: rts = (rpcconn_rts_hdr_t*) Stream_Buffer(pdu->s); - if (!rts_match_pdu_signature(rpc, &RTS_PDU_CONN_A3_SIGNATURE, rts)) + if (!rts_match_pdu_signature(&RTS_PDU_CONN_A3_SIGNATURE, rts)) { WLog_ERR(TAG, "unexpected RTS PDU: Expected CONN/A3"); return -1; @@ -211,7 +211,7 @@ static int rpc_client_recv_pdu(rdpRpc* rpc, RPC_PDU* pdu) case VIRTUAL_CONNECTION_STATE_WAIT_C2: rts = (rpcconn_rts_hdr_t*) Stream_Buffer(pdu->s); - if (!rts_match_pdu_signature(rpc, &RTS_PDU_CONN_C2_SIGNATURE, rts)) + if (!rts_match_pdu_signature(&RTS_PDU_CONN_C2_SIGNATURE, rts)) { WLog_ERR(TAG, "unexpected RTS PDU: Expected CONN/C2"); return -1; diff --git a/libfreerdp/core/gateway/rts.c b/libfreerdp/core/gateway/rts.c index a0df68962..5e7ecfae4 100644 --- a/libfreerdp/core/gateway/rts.c +++ b/libfreerdp/core/gateway/rts.c @@ -240,7 +240,7 @@ static int rts_empty_command_write(BYTE* buffer) return 4; } -static int rts_padding_command_read(rdpRpc* rpc, BYTE* buffer, UINT32 length) +static SSIZE_T rts_padding_command_read(const BYTE* buffer, size_t length) { UINT32 ConformanceCount; ConformanceCount = *((UINT32*) &buffer[0]); /* ConformanceCount (4 bytes) */ @@ -290,7 +290,7 @@ static int rts_ance_command_write(BYTE* buffer) return 4; } -static int rts_client_address_command_read(rdpRpc* rpc, BYTE* buffer, UINT32 length) +static SSIZE_T rts_client_address_command_read(const BYTE* buffer, size_t length) { UINT32 AddressType; AddressType = *((UINT32*) &buffer[0]); /* AddressType (4 bytes) */ @@ -647,7 +647,7 @@ static int rts_send_ping_pdu(rdpRpc* rpc) return (status > 0) ? 1 : -1; } -int rts_command_length(rdpRpc* rpc, UINT32 CommandType, BYTE* buffer, UINT32 length) +SSIZE_T rts_command_length(UINT32 CommandType, const BYTE* buffer, size_t length) { int CommandLength = 0; @@ -686,7 +686,7 @@ int rts_command_length(rdpRpc* rpc, UINT32 CommandType, BYTE* buffer, UINT32 len break; case RTS_CMD_PADDING: /* variable-size */ - CommandLength = rts_padding_command_read(rpc, buffer, length); + CommandLength = rts_padding_command_read(buffer, length); break; case RTS_CMD_NEGATIVE_ANCE: @@ -698,7 +698,7 @@ int rts_command_length(rdpRpc* rpc, UINT32 CommandType, BYTE* buffer, UINT32 len break; case RTS_CMD_CLIENT_ADDRESS: /* variable-size */ - CommandLength = rts_client_address_command_read(rpc, buffer, length); + CommandLength = rts_client_address_command_read(buffer, length); break; case RTS_CMD_ASSOCIATION_GROUP_ID: @@ -716,7 +716,6 @@ int rts_command_length(rdpRpc* rpc, UINT32 CommandType, BYTE* buffer, UINT32 len default: WLog_ERR(TAG, "Error: Unknown RTS Command Type: 0x%"PRIx32"", CommandType); return -1; - break; } return CommandLength; @@ -889,20 +888,32 @@ int rts_recv_out_of_sequence_pdu(rdpRpc* rpc, BYTE* buffer, UINT32 length) UINT32 SignatureId; rpcconn_rts_hdr_t* rts; RtsPduSignature signature; - RpcVirtualConnection* connection = rpc->VirtualConnection; - rts = (rpcconn_rts_hdr_t*) buffer; - rts_extract_pdu_signature(rpc, &signature, rts); - SignatureId = rts_identify_pdu_signature(rpc, &signature, NULL); + RpcVirtualConnection* connection; - if (rts_match_pdu_signature(rpc, &RTS_PDU_FLOW_CONTROL_ACK_SIGNATURE, rts)) + if (!rpc || !buffer) + return -1; + + connection = rpc->VirtualConnection; + + if (!connection) + return -1; + + rts = (rpcconn_rts_hdr_t*) buffer; + + if (!rts_extract_pdu_signature(&signature, rts)) + return -1; + + SignatureId = rts_identify_pdu_signature(&signature, NULL); + + if (rts_match_pdu_signature(&RTS_PDU_FLOW_CONTROL_ACK_SIGNATURE, rts)) { status = rts_recv_flow_control_ack_pdu(rpc, buffer, length); } - else if (rts_match_pdu_signature(rpc, &RTS_PDU_FLOW_CONTROL_ACK_WITH_DESTINATION_SIGNATURE, rts)) + else if (rts_match_pdu_signature(&RTS_PDU_FLOW_CONTROL_ACK_WITH_DESTINATION_SIGNATURE, rts)) { status = rts_recv_flow_control_ack_with_destination_pdu(rpc, buffer, length); } - else if (rts_match_pdu_signature(rpc, &RTS_PDU_PING_SIGNATURE, rts)) + else if (rts_match_pdu_signature(&RTS_PDU_PING_SIGNATURE, rts)) { status = rts_send_ping_pdu(rpc); } @@ -910,21 +921,21 @@ int rts_recv_out_of_sequence_pdu(rdpRpc* rpc, BYTE* buffer, UINT32 length) { if (connection->DefaultOutChannel->State == CLIENT_OUT_CHANNEL_STATE_OPENED) { - if (rts_match_pdu_signature(rpc, &RTS_PDU_OUT_R1_A2_SIGNATURE, rts)) + if (rts_match_pdu_signature(&RTS_PDU_OUT_R1_A2_SIGNATURE, rts)) { status = rts_recv_OUT_R1_A2_pdu(rpc, buffer, length); } } else if (connection->DefaultOutChannel->State == CLIENT_OUT_CHANNEL_STATE_OPENED_A6W) { - if (rts_match_pdu_signature(rpc, &RTS_PDU_OUT_R2_A6_SIGNATURE, rts)) + if (rts_match_pdu_signature(&RTS_PDU_OUT_R2_A6_SIGNATURE, rts)) { status = rts_recv_OUT_R2_A6_pdu(rpc, buffer, length); } } else if (connection->DefaultOutChannel->State == CLIENT_OUT_CHANNEL_STATE_OPENED_B3W) { - if (rts_match_pdu_signature(rpc, &RTS_PDU_OUT_R2_B3_SIGNATURE, rts)) + if (rts_match_pdu_signature(&RTS_PDU_OUT_R2_B3_SIGNATURE, rts)) { status = rts_recv_OUT_R2_B3_pdu(rpc, buffer, length); } @@ -934,7 +945,7 @@ int rts_recv_out_of_sequence_pdu(rdpRpc* rpc, BYTE* buffer, UINT32 length) if (status < 0) { WLog_ERR(TAG, "error parsing RTS PDU with signature id: 0x%08"PRIX32"", SignatureId); - rts_print_pdu_signature(rpc, &signature); + rts_print_pdu_signature(&signature); } return status; diff --git a/libfreerdp/core/gateway/rts.h b/libfreerdp/core/gateway/rts.h index c40a5c4df..67ebf7f8f 100644 --- a/libfreerdp/core/gateway/rts.h +++ b/libfreerdp/core/gateway/rts.h @@ -79,8 +79,7 @@ FREERDP_LOCAL void rts_generate_cookie(BYTE* cookie); -FREERDP_LOCAL int rts_command_length(rdpRpc* rpc, UINT32 CommandType, - BYTE* buffer, UINT32 length); +FREERDP_LOCAL SSIZE_T rts_command_length(UINT32 CommandType, const BYTE* buffer, size_t length); FREERDP_LOCAL int rts_send_CONN_A1_pdu(rdpRpc* rpc); FREERDP_LOCAL int rts_recv_CONN_A3_pdu(rdpRpc* rpc, BYTE* buffer, diff --git a/libfreerdp/core/gateway/rts_signature.c b/libfreerdp/core/gateway/rts_signature.c index a0fe3b42f..365e8844f 100644 --- a/libfreerdp/core/gateway/rts_signature.c +++ b/libfreerdp/core/gateway/rts_signature.c @@ -23,191 +23,191 @@ #define TAG FREERDP_TAG("core.gateway.rts") -RtsPduSignature RTS_PDU_CONN_A1_SIGNATURE = { RTS_FLAG_NONE, 4, +const RtsPduSignature RTS_PDU_CONN_A1_SIGNATURE = { RTS_FLAG_NONE, 4, { RTS_CMD_VERSION, RTS_CMD_COOKIE, RTS_CMD_COOKIE, RTS_CMD_RECEIVE_WINDOW_SIZE, 0, 0, 0, 0 } }; -RtsPduSignature RTS_PDU_CONN_A2_SIGNATURE = { RTS_FLAG_OUT_CHANNEL, 5, +const RtsPduSignature RTS_PDU_CONN_A2_SIGNATURE = { RTS_FLAG_OUT_CHANNEL, 5, { RTS_CMD_VERSION, RTS_CMD_COOKIE, RTS_CMD_COOKIE, RTS_CMD_CHANNEL_LIFETIME, RTS_CMD_RECEIVE_WINDOW_SIZE, 0, 0, 0 } }; -RtsPduSignature RTS_PDU_CONN_A3_SIGNATURE = { RTS_FLAG_NONE, 1, +const RtsPduSignature RTS_PDU_CONN_A3_SIGNATURE = { RTS_FLAG_NONE, 1, { RTS_CMD_CONNECTION_TIMEOUT, 0, 0, 0, 0, 0, 0, 0 } }; -RtsPduSignature RTS_PDU_CONN_B1_SIGNATURE = { RTS_FLAG_NONE, 6, +const RtsPduSignature RTS_PDU_CONN_B1_SIGNATURE = { RTS_FLAG_NONE, 6, { RTS_CMD_VERSION, RTS_CMD_COOKIE, RTS_CMD_COOKIE, RTS_CMD_CHANNEL_LIFETIME, RTS_CMD_CLIENT_KEEPALIVE, RTS_CMD_ASSOCIATION_GROUP_ID, 0, 0 } }; -RtsPduSignature RTS_PDU_CONN_B2_SIGNATURE = { RTS_FLAG_IN_CHANNEL, 7, +const RtsPduSignature RTS_PDU_CONN_B2_SIGNATURE = { RTS_FLAG_IN_CHANNEL, 7, { RTS_CMD_VERSION, RTS_CMD_COOKIE, RTS_CMD_COOKIE, RTS_CMD_RECEIVE_WINDOW_SIZE, RTS_CMD_CONNECTION_TIMEOUT, RTS_CMD_ASSOCIATION_GROUP_ID, RTS_CMD_CLIENT_ADDRESS, 0 } }; -RtsPduSignature RTS_PDU_CONN_B3_SIGNATURE = { RTS_FLAG_NONE, 2, +const RtsPduSignature RTS_PDU_CONN_B3_SIGNATURE = { RTS_FLAG_NONE, 2, { RTS_CMD_RECEIVE_WINDOW_SIZE, RTS_CMD_VERSION, 0, 0, 0, 0, 0, 0 } }; -RtsPduSignature RTS_PDU_CONN_C1_SIGNATURE = { RTS_FLAG_NONE, 3, +const RtsPduSignature RTS_PDU_CONN_C1_SIGNATURE = { RTS_FLAG_NONE, 3, { RTS_CMD_VERSION, RTS_CMD_RECEIVE_WINDOW_SIZE, RTS_CMD_CONNECTION_TIMEOUT, 0, 0, 0, 0, 0 } }; -RtsPduSignature RTS_PDU_CONN_C2_SIGNATURE = { RTS_FLAG_NONE, 3, +const RtsPduSignature RTS_PDU_CONN_C2_SIGNATURE = { RTS_FLAG_NONE, 3, { RTS_CMD_VERSION, RTS_CMD_RECEIVE_WINDOW_SIZE, RTS_CMD_CONNECTION_TIMEOUT, 0, 0, 0, 0, 0 } }; -RtsPduSignature RTS_PDU_IN_R1_A1_SIGNATURE = { RTS_FLAG_RECYCLE_CHANNEL, 4, +const RtsPduSignature RTS_PDU_IN_R1_A1_SIGNATURE = { RTS_FLAG_RECYCLE_CHANNEL, 4, { RTS_CMD_VERSION, RTS_CMD_COOKIE, RTS_CMD_COOKIE, RTS_CMD_COOKIE, 0, 0, 0, 0 } }; -RtsPduSignature RTS_PDU_IN_R1_A2_SIGNATURE = { RTS_FLAG_NONE, 4, +const RtsPduSignature RTS_PDU_IN_R1_A2_SIGNATURE = { RTS_FLAG_NONE, 4, { RTS_CMD_VERSION, RTS_CMD_COOKIE, RTS_CMD_COOKIE, RTS_CMD_COOKIE, RTS_CMD_RECEIVE_WINDOW_SIZE, RTS_CMD_CONNECTION_TIMEOUT, 0, 0 } }; -RtsPduSignature RTS_PDU_IN_R1_A3_SIGNATURE = { RTS_FLAG_NONE, 4, +const RtsPduSignature RTS_PDU_IN_R1_A3_SIGNATURE = { RTS_FLAG_NONE, 4, { RTS_CMD_DESTINATION, RTS_CMD_VERSION, RTS_CMD_RECEIVE_WINDOW_SIZE, RTS_CMD_CONNECTION_TIMEOUT, 0, 0, 0, 0 } }; -RtsPduSignature RTS_PDU_IN_R1_A4_SIGNATURE = { RTS_FLAG_NONE, 4, +const RtsPduSignature RTS_PDU_IN_R1_A4_SIGNATURE = { RTS_FLAG_NONE, 4, { RTS_CMD_DESTINATION, RTS_CMD_VERSION, RTS_CMD_RECEIVE_WINDOW_SIZE, RTS_CMD_CONNECTION_TIMEOUT, 0, 0, 0, 0 } }; -RtsPduSignature RTS_PDU_IN_R1_A5_SIGNATURE = { RTS_FLAG_NONE, 1, +const RtsPduSignature RTS_PDU_IN_R1_A5_SIGNATURE = { RTS_FLAG_NONE, 1, { RTS_CMD_COOKIE, 0, 0, 0, 0, 0, 0, 0 } }; -RtsPduSignature RTS_PDU_IN_R1_A6_SIGNATURE = { RTS_FLAG_NONE, 1, +const RtsPduSignature RTS_PDU_IN_R1_A6_SIGNATURE = { RTS_FLAG_NONE, 1, { RTS_CMD_COOKIE, 0, 0, 0, 0, 0, 0, 0 } }; -RtsPduSignature RTS_PDU_IN_R1_B1_SIGNATURE = { RTS_FLAG_NONE, 1, +const RtsPduSignature RTS_PDU_IN_R1_B1_SIGNATURE = { RTS_FLAG_NONE, 1, { RTS_CMD_EMPTY, 0, 0, 0, 0, 0, 0, 0 } }; -RtsPduSignature RTS_PDU_IN_R1_B2_SIGNATURE = { RTS_FLAG_NONE, 1, +const RtsPduSignature RTS_PDU_IN_R1_B2_SIGNATURE = { RTS_FLAG_NONE, 1, { RTS_CMD_RECEIVE_WINDOW_SIZE, 0, 0, 0, 0, 0, 0, 0 } }; -RtsPduSignature RTS_PDU_IN_R2_A1_SIGNATURE = { RTS_FLAG_RECYCLE_CHANNEL, 4, +const RtsPduSignature RTS_PDU_IN_R2_A1_SIGNATURE = { RTS_FLAG_RECYCLE_CHANNEL, 4, { RTS_CMD_VERSION, RTS_CMD_COOKIE, RTS_CMD_COOKIE, RTS_CMD_COOKIE, 0, 0, 0, 0 } }; -RtsPduSignature RTS_PDU_IN_R2_A2_SIGNATURE = { RTS_FLAG_NONE, 1, +const RtsPduSignature RTS_PDU_IN_R2_A2_SIGNATURE = { RTS_FLAG_NONE, 1, { RTS_CMD_COOKIE, 0, 0, 0, 0, 0, 0, 0 } }; -RtsPduSignature RTS_PDU_IN_R2_A3_SIGNATURE = { RTS_FLAG_NONE, 1, +const RtsPduSignature RTS_PDU_IN_R2_A3_SIGNATURE = { RTS_FLAG_NONE, 1, { RTS_CMD_DESTINATION, 0, 0, 0, 0, 0, 0, 0 } }; -RtsPduSignature RTS_PDU_IN_R2_A4_SIGNATURE = { RTS_FLAG_NONE, 1, +const RtsPduSignature RTS_PDU_IN_R2_A4_SIGNATURE = { RTS_FLAG_NONE, 1, { RTS_CMD_DESTINATION, 0, 0, 0, 0, 0, 0, 0 } }; -RtsPduSignature RTS_PDU_IN_R2_A5_SIGNATURE = { RTS_FLAG_NONE, 1, +const RtsPduSignature RTS_PDU_IN_R2_A5_SIGNATURE = { RTS_FLAG_NONE, 1, { RTS_CMD_COOKIE, 0, 0, 0, 0, 0, 0, 0 } }; -RtsPduSignature RTS_PDU_OUT_R1_A1_SIGNATURE = { RTS_FLAG_RECYCLE_CHANNEL, 1, +const RtsPduSignature RTS_PDU_OUT_R1_A1_SIGNATURE = { RTS_FLAG_RECYCLE_CHANNEL, 1, { RTS_CMD_DESTINATION, 0, 0, 0, 0, 0, 0, 0 } }; -RtsPduSignature RTS_PDU_OUT_R1_A2_SIGNATURE = { RTS_FLAG_RECYCLE_CHANNEL, 1, +const RtsPduSignature RTS_PDU_OUT_R1_A2_SIGNATURE = { RTS_FLAG_RECYCLE_CHANNEL, 1, { RTS_CMD_DESTINATION, 0, 0, 0, 0, 0, 0, 0 } }; -RtsPduSignature RTS_PDU_OUT_R1_A3_SIGNATURE = { RTS_FLAG_RECYCLE_CHANNEL, 5, +const RtsPduSignature RTS_PDU_OUT_R1_A3_SIGNATURE = { RTS_FLAG_RECYCLE_CHANNEL, 5, { RTS_CMD_VERSION, RTS_CMD_COOKIE, RTS_CMD_COOKIE, RTS_CMD_COOKIE, RTS_CMD_RECEIVE_WINDOW_SIZE, 0, 0, 0 } }; -RtsPduSignature RTS_PDU_OUT_R1_A4_SIGNATURE = { RTS_FLAG_RECYCLE_CHANNEL | RTS_FLAG_OUT_CHANNEL, 7, +const RtsPduSignature RTS_PDU_OUT_R1_A4_SIGNATURE = { RTS_FLAG_RECYCLE_CHANNEL | RTS_FLAG_OUT_CHANNEL, 7, { RTS_CMD_VERSION, RTS_CMD_COOKIE, RTS_CMD_COOKIE, RTS_CMD_COOKIE, RTS_CMD_CHANNEL_LIFETIME, RTS_CMD_RECEIVE_WINDOW_SIZE, RTS_CMD_CONNECTION_TIMEOUT, 0 } }; -RtsPduSignature RTS_PDU_OUT_R1_A5_SIGNATURE = { RTS_FLAG_OUT_CHANNEL, 3, +const RtsPduSignature RTS_PDU_OUT_R1_A5_SIGNATURE = { RTS_FLAG_OUT_CHANNEL, 3, { RTS_CMD_DESTINATION, RTS_CMD_VERSION, RTS_CMD_CONNECTION_TIMEOUT, 0, 0, 0, 0, 0 } }; -RtsPduSignature RTS_PDU_OUT_R1_A6_SIGNATURE = { RTS_FLAG_OUT_CHANNEL, 3, +const RtsPduSignature RTS_PDU_OUT_R1_A6_SIGNATURE = { RTS_FLAG_OUT_CHANNEL, 3, { RTS_CMD_DESTINATION, RTS_CMD_VERSION, RTS_CMD_CONNECTION_TIMEOUT, 0, 0, 0, 0, 0 } }; -RtsPduSignature RTS_PDU_OUT_R1_A7_SIGNATURE = { RTS_FLAG_OUT_CHANNEL, 2, +const RtsPduSignature RTS_PDU_OUT_R1_A7_SIGNATURE = { RTS_FLAG_OUT_CHANNEL, 2, { RTS_CMD_DESTINATION, RTS_CMD_COOKIE, 0, 0, 0, 0, 0, 0 } }; -RtsPduSignature RTS_PDU_OUT_R1_A8_SIGNATURE = { RTS_FLAG_OUT_CHANNEL, 2, +const RtsPduSignature RTS_PDU_OUT_R1_A8_SIGNATURE = { RTS_FLAG_OUT_CHANNEL, 2, { RTS_CMD_DESTINATION, RTS_CMD_COOKIE, 0, 0, 0, 0, 0, 0 } }; -RtsPduSignature RTS_PDU_OUT_R1_A9_SIGNATURE = { RTS_FLAG_NONE, 1, +const RtsPduSignature RTS_PDU_OUT_R1_A9_SIGNATURE = { RTS_FLAG_NONE, 1, { RTS_CMD_ANCE, 0, 0, 0, 0, 0, 0, 0 } }; -RtsPduSignature RTS_PDU_OUT_R1_A10_SIGNATURE = { RTS_FLAG_NONE, 1, +const RtsPduSignature RTS_PDU_OUT_R1_A10_SIGNATURE = { RTS_FLAG_NONE, 1, { RTS_CMD_ANCE, 0, 0, 0, 0, 0, 0, 0 } }; -RtsPduSignature RTS_PDU_OUT_R1_A11_SIGNATURE = { RTS_FLAG_NONE, 1, +const RtsPduSignature RTS_PDU_OUT_R1_A11_SIGNATURE = { RTS_FLAG_NONE, 1, { RTS_CMD_ANCE, 0, 0, 0, 0, 0, 0, 0 } }; -RtsPduSignature RTS_PDU_OUT_R2_A1_SIGNATURE = { RTS_FLAG_RECYCLE_CHANNEL, 1, +const RtsPduSignature RTS_PDU_OUT_R2_A1_SIGNATURE = { RTS_FLAG_RECYCLE_CHANNEL, 1, { RTS_CMD_DESTINATION, 0, 0, 0, 0, 0, 0, 0 } }; -RtsPduSignature RTS_PDU_OUT_R2_A2_SIGNATURE = { RTS_FLAG_RECYCLE_CHANNEL, 1, +const RtsPduSignature RTS_PDU_OUT_R2_A2_SIGNATURE = { RTS_FLAG_RECYCLE_CHANNEL, 1, { RTS_CMD_DESTINATION, 0, 0, 0, 0, 0, 0, 0 } }; -RtsPduSignature RTS_PDU_OUT_R2_A3_SIGNATURE = { RTS_FLAG_RECYCLE_CHANNEL, 5, +const RtsPduSignature RTS_PDU_OUT_R2_A3_SIGNATURE = { RTS_FLAG_RECYCLE_CHANNEL, 5, { RTS_CMD_VERSION, RTS_CMD_COOKIE, RTS_CMD_COOKIE, RTS_CMD_COOKIE, RTS_CMD_RECEIVE_WINDOW_SIZE, 0, 0, 0 } }; -RtsPduSignature RTS_PDU_OUT_R2_A4_SIGNATURE = { RTS_FLAG_NONE, 1, +const RtsPduSignature RTS_PDU_OUT_R2_A4_SIGNATURE = { RTS_FLAG_NONE, 1, { RTS_CMD_COOKIE, 0, 0, 0, 0, 0, 0, 0 } }; -RtsPduSignature RTS_PDU_OUT_R2_A5_SIGNATURE = { RTS_FLAG_NONE, 2, +const RtsPduSignature RTS_PDU_OUT_R2_A5_SIGNATURE = { RTS_FLAG_NONE, 2, { RTS_CMD_DESTINATION, RTS_CMD_ANCE, 0, 0, 0, 0, 0, 0 } }; -RtsPduSignature RTS_PDU_OUT_R2_A6_SIGNATURE = { RTS_FLAG_NONE, 2, +const RtsPduSignature RTS_PDU_OUT_R2_A6_SIGNATURE = { RTS_FLAG_NONE, 2, { RTS_CMD_DESTINATION, RTS_CMD_ANCE, 0, 0, 0, 0, 0, 0 } }; -RtsPduSignature RTS_PDU_OUT_R2_A7_SIGNATURE = { RTS_FLAG_NONE, 3, +const RtsPduSignature RTS_PDU_OUT_R2_A7_SIGNATURE = { RTS_FLAG_NONE, 3, { RTS_CMD_DESTINATION, RTS_CMD_COOKIE, RTS_CMD_VERSION, 0, 0, 0, 0, 0 } }; -RtsPduSignature RTS_PDU_OUT_R2_A8_SIGNATURE = { RTS_FLAG_OUT_CHANNEL, 2, +const RtsPduSignature RTS_PDU_OUT_R2_A8_SIGNATURE = { RTS_FLAG_OUT_CHANNEL, 2, { RTS_CMD_DESTINATION, RTS_CMD_COOKIE, 0, 0, 0, 0, 0, 0 } }; -RtsPduSignature RTS_PDU_OUT_R2_B1_SIGNATURE = { RTS_FLAG_NONE, 1, +const RtsPduSignature RTS_PDU_OUT_R2_B1_SIGNATURE = { RTS_FLAG_NONE, 1, { RTS_CMD_ANCE, 0, 0, 0, 0, 0, 0, 0 } }; -RtsPduSignature RTS_PDU_OUT_R2_B2_SIGNATURE = { RTS_FLAG_NONE, 1, +const RtsPduSignature RTS_PDU_OUT_R2_B2_SIGNATURE = { RTS_FLAG_NONE, 1, { RTS_CMD_NEGATIVE_ANCE, 0, 0, 0, 0, 0, 0, 0 } }; -RtsPduSignature RTS_PDU_OUT_R2_B3_SIGNATURE = { RTS_FLAG_EOF, 1, +const RtsPduSignature RTS_PDU_OUT_R2_B3_SIGNATURE = { RTS_FLAG_EOF, 1, { RTS_CMD_ANCE, 0, 0, 0, 0, 0, 0, 0 } }; -RtsPduSignature RTS_PDU_OUT_R2_C1_SIGNATURE = { RTS_FLAG_PING, 1, +const RtsPduSignature RTS_PDU_OUT_R2_C1_SIGNATURE = { RTS_FLAG_PING, 1, { 0, 0, 0, 0, 0, 0, 0, 0 } }; -RtsPduSignature RTS_PDU_KEEP_ALIVE_SIGNATURE = { RTS_FLAG_OTHER_CMD, 1, +const RtsPduSignature RTS_PDU_KEEP_ALIVE_SIGNATURE = { RTS_FLAG_OTHER_CMD, 1, { RTS_CMD_CLIENT_KEEPALIVE, 0, 0, 0, 0, 0, 0, 0 } }; -RtsPduSignature RTS_PDU_PING_TRAFFIC_SENT_NOTIFY_SIGNATURE = { RTS_FLAG_OTHER_CMD, 1, +const RtsPduSignature RTS_PDU_PING_TRAFFIC_SENT_NOTIFY_SIGNATURE = { RTS_FLAG_OTHER_CMD, 1, { RTS_CMD_PING_TRAFFIC_SENT_NOTIFY, 0, 0, 0, 0, 0, 0, 0 } }; -RtsPduSignature RTS_PDU_ECHO_SIGNATURE = { RTS_FLAG_ECHO, 0, +const RtsPduSignature RTS_PDU_ECHO_SIGNATURE = { RTS_FLAG_ECHO, 0, { 0, 0, 0, 0, 0, 0, 0, 0 } }; -RtsPduSignature RTS_PDU_PING_SIGNATURE = { RTS_FLAG_PING, 0, +const RtsPduSignature RTS_PDU_PING_SIGNATURE = { RTS_FLAG_PING, 0, { 0, 0, 0, 0, 0, 0, 0, 0 } }; -RtsPduSignature RTS_PDU_FLOW_CONTROL_ACK_SIGNATURE = { RTS_FLAG_OTHER_CMD, 1, +const RtsPduSignature RTS_PDU_FLOW_CONTROL_ACK_SIGNATURE = { RTS_FLAG_OTHER_CMD, 1, { RTS_CMD_FLOW_CONTROL_ACK, 0, 0, 0, 0, 0, 0, 0 } }; -RtsPduSignature RTS_PDU_FLOW_CONTROL_ACK_WITH_DESTINATION_SIGNATURE = { RTS_FLAG_OTHER_CMD, 2, +const RtsPduSignature RTS_PDU_FLOW_CONTROL_ACK_WITH_DESTINATION_SIGNATURE = { RTS_FLAG_OTHER_CMD, 2, { RTS_CMD_DESTINATION, RTS_CMD_FLOW_CONTROL_ACK, 0, 0, 0, 0, 0, 0 } }; @@ -274,26 +274,30 @@ static const RTS_PDU_SIGNATURE_ENTRY RTS_PDU_SIGNATURE_TABLE[] = { RTS_PDU_FLOW_CONTROL_ACK, TRUE, &RTS_PDU_FLOW_CONTROL_ACK_SIGNATURE, "FlowControlAck" }, { RTS_PDU_FLOW_CONTROL_ACK_WITH_DESTINATION, TRUE, &RTS_PDU_FLOW_CONTROL_ACK_WITH_DESTINATION_SIGNATURE, "FlowControlAckWithDestination" }, - { 0, 0, NULL } + { 0, FALSE, NULL, NULL } }; -BOOL rts_match_pdu_signature(rdpRpc* rpc, RtsPduSignature* signature, rpcconn_rts_hdr_t* rts) +BOOL rts_match_pdu_signature(const RtsPduSignature* signature, + const rpcconn_rts_hdr_t* rts) { - int i; + UINT16 i; int status; - BYTE* buffer; + const BYTE* buffer; UINT32 length; UINT32 offset; UINT32 CommandType; UINT32 CommandLength; + if (!signature || !rts) + return FALSE; + if (rts->Flags != signature->Flags) return FALSE; if (rts->NumberOfCommands != signature->NumberOfCommands) return FALSE; - buffer = (BYTE*) rts; + buffer = (const BYTE*) rts; offset = RTS_PDU_HEADER_LENGTH; length = rts->frag_length - offset; @@ -305,7 +309,7 @@ BOOL rts_match_pdu_signature(rdpRpc* rpc, RtsPduSignature* signature, rpcconn_rt if (CommandType != signature->CommandTypes[i]) return FALSE; - status = rts_command_length(rpc, CommandType, &buffer[offset], length); + status = rts_command_length(CommandType, &buffer[offset], length); if (status < 0) return FALSE; @@ -318,7 +322,7 @@ BOOL rts_match_pdu_signature(rdpRpc* rpc, RtsPduSignature* signature, rpcconn_rt return TRUE; } -int rts_extract_pdu_signature(rdpRpc* rpc, RtsPduSignature* signature, rpcconn_rts_hdr_t* rts) +BOOL rts_extract_pdu_signature(RtsPduSignature* signature, const rpcconn_rts_hdr_t* rts) { int i; int status; @@ -327,6 +331,10 @@ int rts_extract_pdu_signature(rdpRpc* rpc, RtsPduSignature* signature, rpcconn_r UINT32 offset; UINT32 CommandType; UINT32 CommandLength; + + if (!signature || !rts) + return FALSE; + signature->Flags = rts->Flags; signature->NumberOfCommands = rts->NumberOfCommands; buffer = (BYTE*) rts; @@ -338,7 +346,7 @@ int rts_extract_pdu_signature(rdpRpc* rpc, RtsPduSignature* signature, rpcconn_r CommandType = *((UINT32*) &buffer[offset]); /* CommandType (4 bytes) */ offset += 4; signature->CommandTypes[i] = CommandType; - status = rts_command_length(rpc, CommandType, &buffer[offset], length); + status = rts_command_length(CommandType, &buffer[offset], length); if (status < 0) return FALSE; @@ -348,18 +356,17 @@ int rts_extract_pdu_signature(rdpRpc* rpc, RtsPduSignature* signature, rpcconn_r length = rts->frag_length - offset; } - return 0; + return TRUE; } -UINT32 rts_identify_pdu_signature(rdpRpc* rpc, RtsPduSignature* signature, +UINT32 rts_identify_pdu_signature(const RtsPduSignature* signature, const RTS_PDU_SIGNATURE_ENTRY** entry) { - int i, j; - RtsPduSignature* pSignature; + size_t i, j; for (i = 0; RTS_PDU_SIGNATURE_TABLE[i].SignatureId != 0; i++) { - pSignature = RTS_PDU_SIGNATURE_TABLE[i].Signature; + const RtsPduSignature* pSignature = RTS_PDU_SIGNATURE_TABLE[i].Signature; if (!RTS_PDU_SIGNATURE_TABLE[i].SignatureClient) continue; @@ -385,16 +392,20 @@ UINT32 rts_identify_pdu_signature(rdpRpc* rpc, RtsPduSignature* signature, return 0; } -int rts_print_pdu_signature(rdpRpc* rpc, RtsPduSignature* signature) +BOOL rts_print_pdu_signature(const RtsPduSignature* signature) { UINT32 SignatureId; const RTS_PDU_SIGNATURE_ENTRY* entry; + + if (!signature) + return FALSE; + WLog_INFO(TAG, "RTS PDU Signature: Flags: 0x%04"PRIX16" NumberOfCommands: %"PRIu16"", signature->Flags, signature->NumberOfCommands); - SignatureId = rts_identify_pdu_signature(rpc, signature, &entry); + SignatureId = rts_identify_pdu_signature(signature, &entry); if (SignatureId) WLog_ERR(TAG, "Identified %s RTS PDU", entry->PduName); - return 0; + return TRUE; } diff --git a/libfreerdp/core/gateway/rts_signature.h b/libfreerdp/core/gateway/rts_signature.h index 91c30bbaa..5318a4491 100644 --- a/libfreerdp/core/gateway/rts_signature.h +++ b/libfreerdp/core/gateway/rts_signature.h @@ -39,7 +39,7 @@ struct _RTS_PDU_SIGNATURE_ENTRY { UINT32 SignatureId; BOOL SignatureClient; - RtsPduSignature* Signature; + const RtsPduSignature* Signature; const char* PduName; }; @@ -117,75 +117,74 @@ struct _RTS_PDU_SIGNATURE_ENTRY #define RTS_PDU_FLOW_CONTROL_ACK (RTS_PDU_OUT_OF_SEQUENCE | 0x00000005) #define RTS_PDU_FLOW_CONTROL_ACK_WITH_DESTINATION (RTS_PDU_OUT_OF_SEQUENCE | 0x00000006) -FREERDP_LOCAL extern RtsPduSignature RTS_PDU_CONN_A1_SIGNATURE; -FREERDP_LOCAL extern RtsPduSignature RTS_PDU_CONN_A2_SIGNATURE; -FREERDP_LOCAL extern RtsPduSignature RTS_PDU_CONN_A3_SIGNATURE; +FREERDP_LOCAL extern const RtsPduSignature RTS_PDU_CONN_A1_SIGNATURE; +FREERDP_LOCAL extern const RtsPduSignature RTS_PDU_CONN_A2_SIGNATURE; +FREERDP_LOCAL extern const RtsPduSignature RTS_PDU_CONN_A3_SIGNATURE; -FREERDP_LOCAL extern RtsPduSignature RTS_PDU_CONN_B1_SIGNATURE; -FREERDP_LOCAL extern RtsPduSignature RTS_PDU_CONN_B2_SIGNATURE; -FREERDP_LOCAL extern RtsPduSignature RTS_PDU_CONN_B3_SIGNATURE; +FREERDP_LOCAL extern const RtsPduSignature RTS_PDU_CONN_B1_SIGNATURE; +FREERDP_LOCAL extern const RtsPduSignature RTS_PDU_CONN_B2_SIGNATURE; +FREERDP_LOCAL extern const RtsPduSignature RTS_PDU_CONN_B3_SIGNATURE; -FREERDP_LOCAL extern RtsPduSignature RTS_PDU_CONN_C1_SIGNATURE; -FREERDP_LOCAL extern RtsPduSignature RTS_PDU_CONN_C2_SIGNATURE; +FREERDP_LOCAL extern const RtsPduSignature RTS_PDU_CONN_C1_SIGNATURE; +FREERDP_LOCAL extern const RtsPduSignature RTS_PDU_CONN_C2_SIGNATURE; -FREERDP_LOCAL extern RtsPduSignature RTS_PDU_IN_R1_A1_SIGNATURE; -FREERDP_LOCAL extern RtsPduSignature RTS_PDU_IN_R1_A2_SIGNATURE; -FREERDP_LOCAL extern RtsPduSignature RTS_PDU_IN_R1_A3_SIGNATURE; -FREERDP_LOCAL extern RtsPduSignature RTS_PDU_IN_R1_A4_SIGNATURE; -FREERDP_LOCAL extern RtsPduSignature RTS_PDU_IN_R1_A5_SIGNATURE; -FREERDP_LOCAL extern RtsPduSignature RTS_PDU_IN_R1_A6_SIGNATURE; +FREERDP_LOCAL extern const RtsPduSignature RTS_PDU_IN_R1_A1_SIGNATURE; +FREERDP_LOCAL extern const RtsPduSignature RTS_PDU_IN_R1_A2_SIGNATURE; +FREERDP_LOCAL extern const RtsPduSignature RTS_PDU_IN_R1_A3_SIGNATURE; +FREERDP_LOCAL extern const RtsPduSignature RTS_PDU_IN_R1_A4_SIGNATURE; +FREERDP_LOCAL extern const RtsPduSignature RTS_PDU_IN_R1_A5_SIGNATURE; +FREERDP_LOCAL extern const RtsPduSignature RTS_PDU_IN_R1_A6_SIGNATURE; -FREERDP_LOCAL extern RtsPduSignature RTS_PDU_IN_R1_B1_SIGNATURE; -FREERDP_LOCAL extern RtsPduSignature RTS_PDU_IN_R1_B2_SIGNATURE; +FREERDP_LOCAL extern const RtsPduSignature RTS_PDU_IN_R1_B1_SIGNATURE; +FREERDP_LOCAL extern const RtsPduSignature RTS_PDU_IN_R1_B2_SIGNATURE; -FREERDP_LOCAL extern RtsPduSignature RTS_PDU_IN_R2_A1_SIGNATURE; -FREERDP_LOCAL extern RtsPduSignature RTS_PDU_IN_R2_A2_SIGNATURE; -FREERDP_LOCAL extern RtsPduSignature RTS_PDU_IN_R2_A3_SIGNATURE; -FREERDP_LOCAL extern RtsPduSignature RTS_PDU_IN_R2_A4_SIGNATURE; -FREERDP_LOCAL extern RtsPduSignature RTS_PDU_IN_R2_A5_SIGNATURE; +FREERDP_LOCAL extern const RtsPduSignature RTS_PDU_IN_R2_A1_SIGNATURE; +FREERDP_LOCAL extern const RtsPduSignature RTS_PDU_IN_R2_A2_SIGNATURE; +FREERDP_LOCAL extern const RtsPduSignature RTS_PDU_IN_R2_A3_SIGNATURE; +FREERDP_LOCAL extern const RtsPduSignature RTS_PDU_IN_R2_A4_SIGNATURE; +FREERDP_LOCAL extern const RtsPduSignature RTS_PDU_IN_R2_A5_SIGNATURE; -FREERDP_LOCAL extern RtsPduSignature RTS_PDU_OUT_R1_A1_SIGNATURE; -FREERDP_LOCAL extern RtsPduSignature RTS_PDU_OUT_R1_A2_SIGNATURE; -FREERDP_LOCAL extern RtsPduSignature RTS_PDU_OUT_R1_A3_SIGNATURE; -FREERDP_LOCAL extern RtsPduSignature RTS_PDU_OUT_R1_A4_SIGNATURE; -FREERDP_LOCAL extern RtsPduSignature RTS_PDU_OUT_R1_A5_SIGNATURE; -FREERDP_LOCAL extern RtsPduSignature RTS_PDU_OUT_R1_A6_SIGNATURE; -FREERDP_LOCAL extern RtsPduSignature RTS_PDU_OUT_R1_A7_SIGNATURE; -FREERDP_LOCAL extern RtsPduSignature RTS_PDU_OUT_R1_A8_SIGNATURE; -FREERDP_LOCAL extern RtsPduSignature RTS_PDU_OUT_R1_A9_SIGNATURE; -FREERDP_LOCAL extern RtsPduSignature RTS_PDU_OUT_R1_A10_SIGNATURE; -FREERDP_LOCAL extern RtsPduSignature RTS_PDU_OUT_R1_A11_SIGNATURE; +FREERDP_LOCAL extern const RtsPduSignature RTS_PDU_OUT_R1_A1_SIGNATURE; +FREERDP_LOCAL extern const RtsPduSignature RTS_PDU_OUT_R1_A2_SIGNATURE; +FREERDP_LOCAL extern const RtsPduSignature RTS_PDU_OUT_R1_A3_SIGNATURE; +FREERDP_LOCAL extern const RtsPduSignature RTS_PDU_OUT_R1_A4_SIGNATURE; +FREERDP_LOCAL extern const RtsPduSignature RTS_PDU_OUT_R1_A5_SIGNATURE; +FREERDP_LOCAL extern const RtsPduSignature RTS_PDU_OUT_R1_A6_SIGNATURE; +FREERDP_LOCAL extern const RtsPduSignature RTS_PDU_OUT_R1_A7_SIGNATURE; +FREERDP_LOCAL extern const RtsPduSignature RTS_PDU_OUT_R1_A8_SIGNATURE; +FREERDP_LOCAL extern const RtsPduSignature RTS_PDU_OUT_R1_A9_SIGNATURE; +FREERDP_LOCAL extern const RtsPduSignature RTS_PDU_OUT_R1_A10_SIGNATURE; +FREERDP_LOCAL extern const RtsPduSignature RTS_PDU_OUT_R1_A11_SIGNATURE; -FREERDP_LOCAL extern RtsPduSignature RTS_PDU_OUT_R2_A1_SIGNATURE; -FREERDP_LOCAL extern RtsPduSignature RTS_PDU_OUT_R2_A2_SIGNATURE; -FREERDP_LOCAL extern RtsPduSignature RTS_PDU_OUT_R2_A3_SIGNATURE; -FREERDP_LOCAL extern RtsPduSignature RTS_PDU_OUT_R2_A4_SIGNATURE; -FREERDP_LOCAL extern RtsPduSignature RTS_PDU_OUT_R2_A5_SIGNATURE; -FREERDP_LOCAL extern RtsPduSignature RTS_PDU_OUT_R2_A6_SIGNATURE; -FREERDP_LOCAL extern RtsPduSignature RTS_PDU_OUT_R2_A7_SIGNATURE; -FREERDP_LOCAL extern RtsPduSignature RTS_PDU_OUT_R2_A8_SIGNATURE; +FREERDP_LOCAL extern const RtsPduSignature RTS_PDU_OUT_R2_A1_SIGNATURE; +FREERDP_LOCAL extern const RtsPduSignature RTS_PDU_OUT_R2_A2_SIGNATURE; +FREERDP_LOCAL extern const RtsPduSignature RTS_PDU_OUT_R2_A3_SIGNATURE; +FREERDP_LOCAL extern const RtsPduSignature RTS_PDU_OUT_R2_A4_SIGNATURE; +FREERDP_LOCAL extern const RtsPduSignature RTS_PDU_OUT_R2_A5_SIGNATURE; +FREERDP_LOCAL extern const RtsPduSignature RTS_PDU_OUT_R2_A6_SIGNATURE; +FREERDP_LOCAL extern const RtsPduSignature RTS_PDU_OUT_R2_A7_SIGNATURE; +FREERDP_LOCAL extern const RtsPduSignature RTS_PDU_OUT_R2_A8_SIGNATURE; -FREERDP_LOCAL extern RtsPduSignature RTS_PDU_OUT_R2_B1_SIGNATURE; -FREERDP_LOCAL extern RtsPduSignature RTS_PDU_OUT_R2_B2_SIGNATURE; -FREERDP_LOCAL extern RtsPduSignature RTS_PDU_OUT_R2_B3_SIGNATURE; +FREERDP_LOCAL extern const RtsPduSignature RTS_PDU_OUT_R2_B1_SIGNATURE; +FREERDP_LOCAL extern const RtsPduSignature RTS_PDU_OUT_R2_B2_SIGNATURE; +FREERDP_LOCAL extern const RtsPduSignature RTS_PDU_OUT_R2_B3_SIGNATURE; -FREERDP_LOCAL extern RtsPduSignature RTS_PDU_OUT_R2_C1_SIGNATURE; +FREERDP_LOCAL extern const RtsPduSignature RTS_PDU_OUT_R2_C1_SIGNATURE; -FREERDP_LOCAL extern RtsPduSignature RTS_PDU_KEEP_ALIVE_SIGNATURE; -FREERDP_LOCAL extern RtsPduSignature RTS_PDU_PING_TRAFFIC_SENT_NOTIFY_SIGNATURE; -FREERDP_LOCAL extern RtsPduSignature RTS_PDU_ECHO_SIGNATURE; -FREERDP_LOCAL extern RtsPduSignature RTS_PDU_PING_SIGNATURE; -FREERDP_LOCAL extern RtsPduSignature RTS_PDU_FLOW_CONTROL_ACK_SIGNATURE; -FREERDP_LOCAL extern RtsPduSignature +FREERDP_LOCAL extern const RtsPduSignature RTS_PDU_KEEP_ALIVE_SIGNATURE; +FREERDP_LOCAL extern const RtsPduSignature RTS_PDU_PING_TRAFFIC_SENT_NOTIFY_SIGNATURE; +FREERDP_LOCAL extern const RtsPduSignature RTS_PDU_ECHO_SIGNATURE; +FREERDP_LOCAL extern const RtsPduSignature RTS_PDU_PING_SIGNATURE; +FREERDP_LOCAL extern const RtsPduSignature RTS_PDU_FLOW_CONTROL_ACK_SIGNATURE; +FREERDP_LOCAL extern const RtsPduSignature RTS_PDU_FLOW_CONTROL_ACK_WITH_DESTINATION_SIGNATURE; -FREERDP_LOCAL BOOL rts_match_pdu_signature(rdpRpc* rpc, - RtsPduSignature* signature, rpcconn_rts_hdr_t* rts); -FREERDP_LOCAL int rts_extract_pdu_signature(rdpRpc* rpc, - RtsPduSignature* signature, rpcconn_rts_hdr_t* rts); -FREERDP_LOCAL UINT32 rts_identify_pdu_signature(rdpRpc* rpc, - RtsPduSignature* signature, const RTS_PDU_SIGNATURE_ENTRY** entry); -FREERDP_LOCAL int rts_print_pdu_signature(rdpRpc* rpc, - RtsPduSignature* signature); +FREERDP_LOCAL BOOL rts_match_pdu_signature(const RtsPduSignature* signature, + const rpcconn_rts_hdr_t* rts); +FREERDP_LOCAL BOOL rts_extract_pdu_signature(RtsPduSignature* signature, + const rpcconn_rts_hdr_t* rts); +FREERDP_LOCAL UINT32 rts_identify_pdu_signature(const RtsPduSignature* signature, + const RTS_PDU_SIGNATURE_ENTRY** entry); +FREERDP_LOCAL BOOL rts_print_pdu_signature(const RtsPduSignature* signature); #endif /* FREERDP_LIB_CORE_GATEWAY_RTS_SIGNATURE_H */ From fc9ff6d2fc8de861442af3b34354830dba1b234f Mon Sep 17 00:00:00 2001 From: Armin Novak Date: Fri, 28 Sep 2018 12:29:29 +0200 Subject: [PATCH 09/13] Made gateway NTLM self contained. --- libfreerdp/core/gateway/ncacn_http.c | 18 ++-- libfreerdp/core/gateway/ntlm.c | 128 ++++++++++++++++++++++++++- libfreerdp/core/gateway/ntlm.h | 67 +++++--------- libfreerdp/core/gateway/rdg.c | 6 +- libfreerdp/core/gateway/rpc_bind.c | 33 ++++--- libfreerdp/core/gateway/rpc_client.c | 24 +++-- 6 files changed, 191 insertions(+), 85 deletions(-) diff --git a/libfreerdp/core/gateway/ncacn_http.c b/libfreerdp/core/gateway/ncacn_http.c index 2b9ef8535..b91f28eb5 100644 --- a/libfreerdp/core/gateway/ncacn_http.c +++ b/libfreerdp/core/gateway/ncacn_http.c @@ -76,6 +76,7 @@ BOOL rpc_ncacn_http_send_in_channel_request(RpcChannel* inChannel) BOOL continueNeeded; rdpNtlm* ntlm; HttpContext* http; + const SecBuffer* buffer; if (!inChannel || !inChannel->ntlm || !inChannel->http) return FALSE; @@ -84,7 +85,8 @@ BOOL rpc_ncacn_http_send_in_channel_request(RpcChannel* inChannel) http = inChannel->http; continueNeeded = ntlm_authenticate(ntlm); contentLength = (continueNeeded) ? 0 : 0x40000000; - s = rpc_ntlm_http_request(http, "RPC_IN_DATA", contentLength, &ntlm->outputBuffer[0]); + buffer = ntlm_client_get_output_buffer(ntlm); + s = rpc_ntlm_http_request(http, "RPC_IN_DATA", contentLength, buffer); if (!s) return -1; @@ -112,10 +114,7 @@ BOOL rpc_ncacn_http_recv_in_channel_response(RpcChannel* inChannel, crypto_base64_decode(token64, strlen(token64), &ntlmTokenData, &ntlmTokenLength); if (ntlmTokenData && ntlmTokenLength) - { - ntlm->inputBuffer[0].pvBuffer = ntlmTokenData; - ntlm->inputBuffer[0].cbBuffer = ntlmTokenLength; - } + return ntlm_client_set_input_buffer(ntlm, FALSE, ntlmTokenData, ntlmTokenLength); return TRUE; } @@ -213,6 +212,7 @@ BOOL rpc_ncacn_http_send_out_channel_request(RpcChannel* outChannel, BOOL continueNeeded; rdpNtlm* ntlm; HttpContext* http; + const SecBuffer* buffer; if (!outChannel || !outChannel->ntlm || !outChannel->http) return FALSE; @@ -226,7 +226,8 @@ BOOL rpc_ncacn_http_send_out_channel_request(RpcChannel* outChannel, else contentLength = (continueNeeded) ? 0 : 120; - s = rpc_ntlm_http_request(http, "RPC_OUT_DATA", contentLength, &ntlm->outputBuffer[0]); + buffer = ntlm_client_get_output_buffer(ntlm); + s = rpc_ntlm_http_request(http, "RPC_OUT_DATA", contentLength, buffer); if (!s) return -1; @@ -256,10 +257,7 @@ BOOL rpc_ncacn_http_recv_out_channel_response(RpcChannel* outChannel, crypto_base64_decode(token64, strlen(token64), &ntlmTokenData, &ntlmTokenLength); if (ntlmTokenData && ntlmTokenLength) - { - ntlm->inputBuffer[0].pvBuffer = ntlmTokenData; - ntlm->inputBuffer[0].cbBuffer = ntlmTokenLength; - } + return ntlm_client_set_input_buffer(ntlm, FALSE, ntlmTokenData, ntlmTokenLength); return TRUE; } diff --git a/libfreerdp/core/gateway/ntlm.c b/libfreerdp/core/gateway/ntlm.c index 99306db3b..0879ad008 100644 --- a/libfreerdp/core/gateway/ntlm.c +++ b/libfreerdp/core/gateway/ntlm.c @@ -35,7 +35,46 @@ #define TAG FREERDP_TAG("core.gateway.ntlm") -BOOL ntlm_client_init(rdpNtlm* ntlm, BOOL http, char* user, char* domain, char* password, +struct rdp_ntlm +{ + BOOL http; + CtxtHandle context; + ULONG cbMaxToken; + ULONG fContextReq; + ULONG pfContextAttr; + TimeStamp expiration; + PSecBuffer pBuffer; + SecBuffer inputBuffer[2]; + SecBuffer outputBuffer[2]; + BOOL haveContext; + BOOL haveInputBuffer; + LPTSTR ServicePrincipalName; + SecBufferDesc inputBufferDesc; + SecBufferDesc outputBufferDesc; + CredHandle credentials; + BOOL confidentiality; + SecPkgInfo* pPackageInfo; + SecurityFunctionTable* table; + SEC_WINNT_AUTH_IDENTITY identity; + SecPkgContext_Sizes ContextSizes; + SecPkgContext_Bindings* Bindings; +}; + +static ULONG cast_from_size_(size_t size, const char* fkt, const char* file, int line) +{ + if (size > ULONG_MAX) + { + WLog_ERR(TAG, "[%s %s:%d] Size %"PRIdz" is larger than INT_MAX %lu", fkt, file, line, size, + ULONG_MAX); + return 0; + } + + return (ULONG) size; +} + +#define cast_from_size(size) cast_from_size_(size, __FUNCTION__, __FILE__, __LINE__) + +BOOL ntlm_client_init(rdpNtlm* ntlm, BOOL http, LPCTSTR user, LPCTSTR domain, LPCTSTR password, SecPkgContext_Bindings* Bindings) { SECURITY_STATUS status; @@ -96,7 +135,7 @@ BOOL ntlm_client_init(rdpNtlm* ntlm, BOOL http, char* user, char* domain, char* return TRUE; } -BOOL ntlm_client_make_spn(rdpNtlm* ntlm, LPCTSTR ServiceClass, char* hostname) +BOOL ntlm_client_make_spn(rdpNtlm* ntlm, LPCTSTR ServiceClass, LPCTSTR hostname) { BOOL status = FALSE; DWORD SpnLength = 0; @@ -335,3 +374,88 @@ void ntlm_free(rdpNtlm* ntlm) ntlm_client_uninit(ntlm); free(ntlm); } + +SSIZE_T ntlm_client_get_context_max_size(rdpNtlm* ntlm) +{ + if (!ntlm) + return -1; + + if (ntlm->ContextSizes.cbMaxSignature > UINT16_MAX) + { + WLog_ERR(TAG, "QueryContextAttributes SECPKG_ATTR_SIZES ContextSizes.cbMaxSignature > 0xFFFF"); + return -1; + } + + return ntlm->ContextSizes.cbMaxSignature; +} + +SSIZE_T ntlm_client_query_auth_size(rdpNtlm* ntlm) +{ + SECURITY_STATUS status; + + if (!ntlm || !ntlm->table || !ntlm->table->QueryContextAttributes) + return -1; + + status = ntlm->table->QueryContextAttributes(&ntlm->context, SECPKG_ATTR_SIZES, + &ntlm->ContextSizes); + + if (status != SEC_E_OK) + { + WLog_ERR(TAG, "QueryContextAttributes SECPKG_ATTR_SIZES failure %s [0x%08"PRIX32"]", + GetSecurityStatusString(status), status); + return -1; + } + + return ntlm_client_get_context_max_size(ntlm); +} + +BOOL ntlm_client_encrypt(rdpNtlm* ntlm, size_t foo, SecBufferDesc* Message, size_t sequence) +{ + SECURITY_STATUS encrypt_status; + const ULONG f = cast_from_size(foo); + const ULONG s = cast_from_size(sequence); + + if (!ntlm || !Message) + return FALSE; + + encrypt_status = ntlm->table->EncryptMessage(&ntlm->context, f, Message, s); + + if (encrypt_status != SEC_E_OK) + { + WLog_ERR(TAG, "EncryptMessage status %s [0x%08"PRIX32"]", + GetSecurityStatusString(encrypt_status), encrypt_status); + return FALSE; + } + + return TRUE; +} + +BOOL ntlm_client_set_input_buffer(rdpNtlm* ntlm, BOOL copy, const void* data, size_t size) +{ + if (!ntlm || !data || (size == 0)) + return FALSE; + + ntlm->inputBuffer[0].cbBuffer = cast_from_size(size); + + if (copy) + { + ntlm->inputBuffer[0].pvBuffer = malloc(size); + + if (!ntlm->inputBuffer[0].pvBuffer) + return FALSE; + + memcpy(ntlm->inputBuffer[0].pvBuffer, data, size); + } + else + ntlm->inputBuffer[0].pvBuffer = (void*)data; + + return TRUE; +} + +const SecBuffer* ntlm_client_get_output_buffer(rdpNtlm* ntlm) +{ + if (!ntlm) + return NULL; + + return &ntlm->outputBuffer[0]; +} diff --git a/libfreerdp/core/gateway/ntlm.h b/libfreerdp/core/gateway/ntlm.h index 29bf53b9d..47165cd0f 100644 --- a/libfreerdp/core/gateway/ntlm.h +++ b/libfreerdp/core/gateway/ntlm.h @@ -24,57 +24,30 @@ typedef struct rdp_ntlm rdpNtlm; -#include "../tcp.h" -#include "../transport.h" - -#include "rts.h" -#include "http.h" - -#include -#include -#include -#include +#include #include - #include -#include -#include - -struct rdp_ntlm -{ - BOOL http; - CtxtHandle context; - ULONG cbMaxToken; - ULONG fContextReq; - ULONG pfContextAttr; - TimeStamp expiration; - PSecBuffer pBuffer; - SecBuffer inputBuffer[2]; - SecBuffer outputBuffer[2]; - BOOL haveContext; - BOOL haveInputBuffer; - LPTSTR ServicePrincipalName; - SecBufferDesc inputBufferDesc; - SecBufferDesc outputBufferDesc; - CredHandle credentials; - BOOL confidentiality; - SecPkgInfo* pPackageInfo; - SecurityFunctionTable* table; - SEC_WINNT_AUTH_IDENTITY identity; - SecPkgContext_Sizes ContextSizes; - SecPkgContext_Bindings* Bindings; -}; - -FREERDP_LOCAL BOOL ntlm_authenticate(rdpNtlm* ntlm); - -FREERDP_LOCAL BOOL ntlm_client_init(rdpNtlm* ntlm, BOOL confidentiality, - char* user, - char* domain, char* password, SecPkgContext_Bindings* Bindings); - -FREERDP_LOCAL BOOL ntlm_client_make_spn(rdpNtlm* ntlm, LPCTSTR ServiceClass, - char* hostname); FREERDP_LOCAL rdpNtlm* ntlm_new(void); FREERDP_LOCAL void ntlm_free(rdpNtlm* ntlm); +FREERDP_LOCAL BOOL ntlm_authenticate(rdpNtlm* ntlm); + +FREERDP_LOCAL BOOL ntlm_client_init(rdpNtlm* ntlm, BOOL confidentiality, + LPCTSTR user, LPCTSTR domain, + LPCTSTR password, SecPkgContext_Bindings* Bindings); + +FREERDP_LOCAL BOOL ntlm_client_make_spn(rdpNtlm* ntlm, LPCTSTR ServiceClass, + LPCTSTR hostname); + +FREERDP_LOCAL SSIZE_T ntlm_client_query_auth_size(rdpNtlm* ntlm); +FREERDP_LOCAL SSIZE_T ntlm_client_get_context_max_size(rdpNtlm* ntlm); + +FREERDP_LOCAL BOOL ntlm_client_encrypt(rdpNtlm* ntlm, size_t foo, SecBufferDesc* Message, + size_t sequence); + +FREERDP_LOCAL BOOL ntlm_client_set_input_buffer(rdpNtlm* ntlm, BOOL copy, const void* data, + size_t size); +FREERDP_LOCAL const SecBuffer* ntlm_client_get_output_buffer(rdpNtlm* ntlm); + #endif /* FREERDP_LIB_CORE_GATEWAY_NTLM_H */ diff --git a/libfreerdp/core/gateway/rdg.c b/libfreerdp/core/gateway/rdg.c index a60c1ee2b..624843723 100644 --- a/libfreerdp/core/gateway/rdg.c +++ b/libfreerdp/core/gateway/rdg.c @@ -300,7 +300,7 @@ static BOOL rdg_send_channel_create(rdpRdg* rdg) static BOOL rdg_set_ntlm_auth_header(rdpNtlm* ntlm, HttpRequest* request) { - SecBuffer* ntlmToken = ntlm->outputBuffer; + const SecBuffer* ntlmToken = ntlm_client_get_output_buffer(ntlm); char* base64NtlmToken = NULL; if (ntlmToken) @@ -388,8 +388,8 @@ static BOOL rdg_handle_ntlm_challenge(rdpNtlm* ntlm, HttpResponse* response) if (ntlmTokenData && ntlmTokenLength) { - ntlm->inputBuffer[0].pvBuffer = ntlmTokenData; - ntlm->inputBuffer[0].cbBuffer = ntlmTokenLength; + if (!ntlm_client_set_input_buffer(ntlm, FALSE, ntlmTokenData, ntlmTokenLength)) + return FALSE; } ntlm_authenticate(ntlm); diff --git a/libfreerdp/core/gateway/rpc_bind.c b/libfreerdp/core/gateway/rpc_bind.c index 81516d937..913489e25 100644 --- a/libfreerdp/core/gateway/rpc_bind.c +++ b/libfreerdp/core/gateway/rpc_bind.c @@ -119,6 +119,7 @@ int rpc_send_bind_pdu(rdpRpc* rpc) freerdp* instance = (freerdp*) settings->instance; RpcVirtualConnection* connection = rpc->VirtualConnection; RpcInChannel* inChannel = connection->DefaultInChannel; + const SecBuffer* sbuffer; WLog_DBG(TAG, "Sending Bind PDU"); ntlm_free(rpc->ntlm); rpc->ntlm = ntlm_new(); @@ -172,9 +173,14 @@ int rpc_send_bind_pdu(rdpRpc* rpc) if (!bind_pdu) return -1; + sbuffer = ntlm_client_get_output_buffer(rpc->ntlm); + + if (!sbuffer) + return -1; + rpc_pdu_header_init(rpc, (rpcconn_hdr_t*) bind_pdu); - bind_pdu->auth_length = (UINT16) rpc->ntlm->outputBuffer[0].cbBuffer; - bind_pdu->auth_verifier.auth_value = rpc->ntlm->outputBuffer[0].pvBuffer; + bind_pdu->auth_length = (UINT16) sbuffer->cbBuffer; + bind_pdu->auth_verifier.auth_value = sbuffer->pvBuffer; bind_pdu->ptype = PTYPE_BIND; bind_pdu->pfc_flags = PFC_FIRST_FRAG | PFC_LAST_FRAG | PFC_SUPPORT_HEADER_SIGN | PFC_CONC_MPX; bind_pdu->call_id = 2; @@ -293,16 +299,17 @@ int rpc_recv_bind_ack_pdu(rdpRpc* rpc, BYTE* buffer, UINT32 length) rpcconn_hdr_t* header; header = (rpcconn_hdr_t*) buffer; WLog_DBG(TAG, "Receiving BindAck PDU"); - rpc->max_recv_frag = header->bind_ack.max_xmit_frag; - rpc->max_xmit_frag = header->bind_ack.max_recv_frag; - rpc->ntlm->inputBuffer[0].cbBuffer = header->common.auth_length; - rpc->ntlm->inputBuffer[0].pvBuffer = malloc(header->common.auth_length); - if (!rpc->ntlm->inputBuffer[0].pvBuffer) + if (!rpc || !rpc->ntlm) return -1; + rpc->max_recv_frag = header->bind_ack.max_xmit_frag; + rpc->max_xmit_frag = header->bind_ack.max_recv_frag; auth_data = buffer + (header->common.frag_length - header->common.auth_length); - CopyMemory(rpc->ntlm->inputBuffer[0].pvBuffer, auth_data, header->common.auth_length); + + if (!ntlm_client_set_input_buffer(rpc->ntlm, TRUE, auth_data, header->common.auth_length)) + return -1; + ntlm_authenticate(rpc->ntlm); return (int) length; } @@ -320,6 +327,7 @@ int rpc_send_rpc_auth_3_pdu(rdpRpc* rpc) BYTE* buffer; UINT32 offset; UINT32 length; + const SecBuffer* sbuffer; RpcClientCall* clientCall; rpcconn_rpc_auth_3_hdr_t* auth_3_pdu; RpcVirtualConnection* connection = rpc->VirtualConnection; @@ -330,9 +338,14 @@ int rpc_send_rpc_auth_3_pdu(rdpRpc* rpc) if (!auth_3_pdu) return -1; + sbuffer = ntlm_client_get_output_buffer(rpc->ntlm); + + if (!sbuffer) + return -1; + rpc_pdu_header_init(rpc, (rpcconn_hdr_t*) auth_3_pdu); - auth_3_pdu->auth_length = (UINT16) rpc->ntlm->outputBuffer[0].cbBuffer; - auth_3_pdu->auth_verifier.auth_value = rpc->ntlm->outputBuffer[0].pvBuffer; + auth_3_pdu->auth_length = (UINT16) sbuffer->cbBuffer; + auth_3_pdu->auth_verifier.auth_value = sbuffer->pvBuffer; auth_3_pdu->ptype = PTYPE_RPC_AUTH_3; auth_3_pdu->pfc_flags = PFC_FIRST_FRAG | PFC_LAST_FRAG | PFC_CONC_MPX; auth_3_pdu->call_id = 2; diff --git a/libfreerdp/core/gateway/rpc_client.c b/libfreerdp/core/gateway/rpc_client.c index c8dc8f21b..58ad7ee2b 100644 --- a/libfreerdp/core/gateway/rpc_client.c +++ b/libfreerdp/core/gateway/rpc_client.c @@ -878,11 +878,11 @@ BOOL rpc_client_write_call(rdpRpc* rpc, wStream* s, UINT16 opnum) SecBufferDesc Message; RpcClientCall* clientCall = NULL; rdpNtlm* ntlm; - SECURITY_STATUS encrypt_status; rpcconn_request_hdr_t* request_pdu = NULL; RpcVirtualConnection* connection; RpcInChannel* inChannel; size_t length; + SSIZE_T size; if (!s || !rpc) return FALSE; @@ -890,7 +890,7 @@ BOOL rpc_client_write_call(rdpRpc* rpc, wStream* s, UINT16 opnum) ntlm = rpc->ntlm; connection = rpc->VirtualConnection; - if (!ntlm || !ntlm->table) + if (!ntlm) { WLog_ERR(TAG, "invalid ntlm context"); return FALSE; @@ -906,10 +906,8 @@ BOOL rpc_client_write_call(rdpRpc* rpc, wStream* s, UINT16 opnum) Stream_SealLength(s); length = Stream_Length(s); - status = ntlm->table->QueryContextAttributes(&ntlm->context, SECPKG_ATTR_SIZES, - &ntlm->ContextSizes); - if (status != SEC_E_OK) + if (ntlm_client_query_auth_size(ntlm) < 0) { WLog_ERR(TAG, "QueryContextAttributes SECPKG_ATTR_SIZES failure %s [0x%08"PRIX32"]", GetSecurityStatusString(status), status); @@ -922,10 +920,15 @@ BOOL rpc_client_write_call(rdpRpc* rpc, wStream* s, UINT16 opnum) if (!request_pdu) return FALSE; + size = ntlm_client_get_context_max_size(ntlm); + + if (size < 0) + goto fail; + rpc_pdu_header_init(rpc, (rpcconn_hdr_t*) request_pdu); request_pdu->ptype = PTYPE_REQUEST; request_pdu->pfc_flags = PFC_FIRST_FRAG | PFC_LAST_FRAG; - request_pdu->auth_length = (UINT16) ntlm->ContextSizes.cbMaxSignature; + request_pdu->auth_length = (UINT16) size; request_pdu->call_id = rpc->CallId++; request_pdu->alloc_hint = length; request_pdu->p_cont_id = 0x0000; @@ -969,7 +972,7 @@ BOOL rpc_client_write_call(rdpRpc* rpc, wStream* s, UINT16 opnum) Buffers[1].BufferType = SECBUFFER_TOKEN; /* signature */ Buffers[0].pvBuffer = buffer; Buffers[0].cbBuffer = offset; - Buffers[1].cbBuffer = ntlm->ContextSizes.cbMaxSignature; + Buffers[1].cbBuffer = size; Buffers[1].pvBuffer = calloc(1, Buffers[1].cbBuffer); if (!Buffers[1].pvBuffer) @@ -978,14 +981,9 @@ BOOL rpc_client_write_call(rdpRpc* rpc, wStream* s, UINT16 opnum) Message.cBuffers = 2; Message.ulVersion = SECBUFFER_VERSION; Message.pBuffers = (PSecBuffer) &Buffers; - encrypt_status = ntlm->table->EncryptMessage(&ntlm->context, 0, &Message, rpc->SendSeqNum++); - if (encrypt_status != SEC_E_OK) - { - WLog_ERR(TAG, "EncryptMessage status %s [0x%08"PRIX32"]", - GetSecurityStatusString(encrypt_status), encrypt_status); + if (!ntlm_client_encrypt(ntlm, 0, &Message, rpc->SendSeqNum++)) goto fail; - } CopyMemory(&buffer[offset], Buffers[1].pvBuffer, Buffers[1].cbBuffer); offset += Buffers[1].cbBuffer; From 8c92f3436d35d9a561e8d9f14857704592b14ad2 Mon Sep 17 00:00:00 2001 From: Armin Novak Date: Tue, 9 Oct 2018 14:16:27 +0200 Subject: [PATCH 10/13] Fixed argument name for ntlm_client_encrypt --- libfreerdp/core/gateway/ntlm.c | 5 ++--- libfreerdp/core/gateway/ntlm.h | 2 +- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/libfreerdp/core/gateway/ntlm.c b/libfreerdp/core/gateway/ntlm.c index 0879ad008..4982a5b41 100644 --- a/libfreerdp/core/gateway/ntlm.c +++ b/libfreerdp/core/gateway/ntlm.c @@ -409,16 +409,15 @@ SSIZE_T ntlm_client_query_auth_size(rdpNtlm* ntlm) return ntlm_client_get_context_max_size(ntlm); } -BOOL ntlm_client_encrypt(rdpNtlm* ntlm, size_t foo, SecBufferDesc* Message, size_t sequence) +BOOL ntlm_client_encrypt(rdpNtlm* ntlm, ULONG fQOP, SecBufferDesc* Message, size_t sequence) { SECURITY_STATUS encrypt_status; - const ULONG f = cast_from_size(foo); const ULONG s = cast_from_size(sequence); if (!ntlm || !Message) return FALSE; - encrypt_status = ntlm->table->EncryptMessage(&ntlm->context, f, Message, s); + encrypt_status = ntlm->table->EncryptMessage(&ntlm->context, fQOP, Message, s); if (encrypt_status != SEC_E_OK) { diff --git a/libfreerdp/core/gateway/ntlm.h b/libfreerdp/core/gateway/ntlm.h index 47165cd0f..f413622d8 100644 --- a/libfreerdp/core/gateway/ntlm.h +++ b/libfreerdp/core/gateway/ntlm.h @@ -43,7 +43,7 @@ FREERDP_LOCAL BOOL ntlm_client_make_spn(rdpNtlm* ntlm, LPCTSTR ServiceClass, FREERDP_LOCAL SSIZE_T ntlm_client_query_auth_size(rdpNtlm* ntlm); FREERDP_LOCAL SSIZE_T ntlm_client_get_context_max_size(rdpNtlm* ntlm); -FREERDP_LOCAL BOOL ntlm_client_encrypt(rdpNtlm* ntlm, size_t foo, SecBufferDesc* Message, +FREERDP_LOCAL BOOL ntlm_client_encrypt(rdpNtlm* ntlm, ULONG fQOP, SecBufferDesc* Message, size_t sequence); FREERDP_LOCAL BOOL ntlm_client_set_input_buffer(rdpNtlm* ntlm, BOOL copy, const void* data, From 65bfb67f7c3b651c9c0b660f5db36a05f0a838c9 Mon Sep 17 00:00:00 2001 From: Armin Novak Date: Tue, 9 Oct 2018 14:19:05 +0200 Subject: [PATCH 11/13] Fixed rpc_client_write_call resource cleanup. --- libfreerdp/core/gateway/rpc_client.c | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/libfreerdp/core/gateway/rpc_client.c b/libfreerdp/core/gateway/rpc_client.c index 58ad7ee2b..ce75b64f3 100644 --- a/libfreerdp/core/gateway/rpc_client.c +++ b/libfreerdp/core/gateway/rpc_client.c @@ -874,7 +874,7 @@ BOOL rpc_client_write_call(rdpRpc* rpc, wStream* s, UINT16 opnum) UINT32 offset; BYTE* buffer = NULL; UINT32 stub_data_pad; - SecBuffer Buffers[2]; + SecBuffer Buffers[2] = { 0 }; SecBufferDesc Message; RpcClientCall* clientCall = NULL; rdpNtlm* ntlm; @@ -884,25 +884,28 @@ BOOL rpc_client_write_call(rdpRpc* rpc, wStream* s, UINT16 opnum) size_t length; SSIZE_T size; - if (!s || !rpc) + if (!s) return FALSE; + if (!rpc) + goto fail; + ntlm = rpc->ntlm; connection = rpc->VirtualConnection; if (!ntlm) { WLog_ERR(TAG, "invalid ntlm context"); - return FALSE; + goto fail; } if (!connection) - return FALSE; + goto fail; inChannel = connection->DefaultInChannel; if (!inChannel) - return FALSE; + goto fail; Stream_SealLength(s); length = Stream_Length(s); @@ -911,14 +914,13 @@ BOOL rpc_client_write_call(rdpRpc* rpc, wStream* s, UINT16 opnum) { WLog_ERR(TAG, "QueryContextAttributes SECPKG_ATTR_SIZES failure %s [0x%08"PRIX32"]", GetSecurityStatusString(status), status); - return FALSE; + goto fail; } - ZeroMemory(&Buffers, sizeof(Buffers)); request_pdu = (rpcconn_request_hdr_t*) calloc(1, sizeof(rpcconn_request_hdr_t)); if (!request_pdu) - return FALSE; + goto fail; size = ntlm_client_get_context_max_size(ntlm); From 166bdf018c76caf50e6cecc65015c1869f4c11f0 Mon Sep 17 00:00:00 2001 From: Armin Novak Date: Tue, 9 Oct 2018 14:24:39 +0200 Subject: [PATCH 12/13] Fixed return value of rpc_channel_tls_connect --- libfreerdp/core/gateway/rpc.c | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/libfreerdp/core/gateway/rpc.c b/libfreerdp/core/gateway/rpc.c index 68a32b952..c48526b6d 100644 --- a/libfreerdp/core/gateway/rpc.c +++ b/libfreerdp/core/gateway/rpc.c @@ -644,7 +644,7 @@ static void rpc_virtual_connection_free(RpcVirtualConnection* connection) free(connection); } -static int rpc_channel_tls_connect(RpcChannel* channel, int timeout) +static BOOL rpc_channel_tls_connect(RpcChannel* channel, int timeout) { int sockfd; rdpTls* tls; @@ -685,20 +685,20 @@ static int rpc_channel_tls_connect(RpcChannel* channel, int timeout) bufferedBio = BIO_push(bufferedBio, socketBio); if (!BIO_set_nonblock(bufferedBio, TRUE)) - return -1; + return FALSE; if (channel->client->isProxy) { if (!proxy_connect(settings, bufferedBio, proxyUsername, proxyPassword, settings->GatewayHostname, settings->GatewayPort)) - return -1; + return FALSE; } channel->bio = bufferedBio; tls = channel->tls = tls_new(settings); if (!tls) - return -1; + return FALSE; tls->hostname = settings->GatewayHostname; tls->port = settings->GatewayPort; @@ -718,10 +718,10 @@ static int rpc_channel_tls_connect(RpcChannel* channel, int timeout) freerdp_set_last_error(context, FREERDP_ERROR_CONNECT_CANCELLED); } - return -1; + return FALSE; } - return 1; + return TRUE; } static int rpc_in_channel_connect(RpcInChannel* inChannel, int timeout) @@ -735,7 +735,7 @@ static int rpc_in_channel_connect(RpcInChannel* inChannel, int timeout) /* Connect IN Channel */ - if (rpc_channel_tls_connect(&inChannel->common, timeout) < 0) + if (!rpc_channel_tls_connect(&inChannel->common, timeout)) return -1; rpc_in_channel_transition_to_state(inChannel, CLIENT_IN_CHANNEL_STATE_CONNECTED); @@ -768,7 +768,7 @@ static int rpc_out_channel_connect(RpcOutChannel* outChannel, int timeout) /* Connect OUT Channel */ - if (rpc_channel_tls_connect(&outChannel->common, timeout) < 0) + if (!rpc_channel_tls_connect(&outChannel->common, timeout)) return -1; rpc_out_channel_transition_to_state(outChannel, CLIENT_OUT_CHANNEL_STATE_CONNECTED); @@ -799,7 +799,7 @@ int rpc_out_channel_replacement_connect(RpcOutChannel* outChannel, int timeout) /* Connect OUT Channel */ - if (rpc_channel_tls_connect(&outChannel->common, timeout) < 0) + if (!rpc_channel_tls_connect(&outChannel->common, timeout)) return -1; rpc_out_channel_transition_to_state(outChannel, CLIENT_OUT_CHANNEL_STATE_CONNECTED); From 766a66a7c289e3886e28b934ce94bb71bc6b2274 Mon Sep 17 00:00:00 2001 From: Armin Novak Date: Tue, 9 Oct 2018 14:24:58 +0200 Subject: [PATCH 13/13] Fixed stream get position. --- libfreerdp/core/gateway/tsg.c | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/libfreerdp/core/gateway/tsg.c b/libfreerdp/core/gateway/tsg.c index 5d519819a..45a7c19ee 100644 --- a/libfreerdp/core/gateway/tsg.c +++ b/libfreerdp/core/gateway/tsg.c @@ -560,7 +560,7 @@ static BOOL TsProxyCreateTunnelReadResponse(rdpTsg* tsg, RPC_PDU* pdu, /* 4-byte alignment */ { - UINT32 offset = Stream_Pointer(pdu->s); + UINT32 offset = Stream_GetPosition(pdu->s); if (!Stream_SafeSeek(pdu->s, rpc_offset_align(&offset, 4))) goto fail;