Implement support for RDS AAD

Have a working implementation of the RDS AAD enhanced security mechanism
for Azure AD logons
This commit is contained in:
fifthdegree 2023-03-07 20:06:47 -05:00 committed by akallabeth
parent 5df4d4c934
commit 4cbfa006f2
19 changed files with 947 additions and 36 deletions

View File

@ -60,6 +60,8 @@
#include <freerdp/log.h> #include <freerdp/log.h>
#define TAG CLIENT_TAG("common") #define TAG CLIENT_TAG("common")
#define OAUTH2_CLIENT_ID "5177bc73-fd99-4c77-a90c-76844c9b6999"
static BOOL freerdp_client_common_new(freerdp* instance, rdpContext* context) static BOOL freerdp_client_common_new(freerdp* instance, rdpContext* context)
{ {
RDP_CLIENT_ENTRY_POINTS* pEntryPoints; RDP_CLIENT_ENTRY_POINTS* pEntryPoints;
@ -74,6 +76,7 @@ static BOOL freerdp_client_common_new(freerdp* instance, rdpContext* context)
instance->VerifyChangedCertificateEx = client_cli_verify_changed_certificate_ex; instance->VerifyChangedCertificateEx = client_cli_verify_changed_certificate_ex;
instance->PresentGatewayMessage = client_cli_present_gateway_message; instance->PresentGatewayMessage = client_cli_present_gateway_message;
instance->LogonErrorInfo = client_cli_logon_error_info; instance->LogonErrorInfo = client_cli_logon_error_info;
instance->GetAadAuthCode = client_cli_get_aad_auth_code;
pEntryPoints = instance->pClientEntryPoints; pEntryPoints = instance->pClientEntryPoints;
WINPR_ASSERT(pEntryPoints); WINPR_ASSERT(pEntryPoints);
@ -934,6 +937,35 @@ BOOL client_cli_present_gateway_message(freerdp* instance, UINT32 type, BOOL isD
return TRUE; return TRUE;
} }
BOOL client_cli_get_aad_auth_code(const char* hostname, char** code)
{
size_t len = 0;
char* p = NULL;
WINPR_ASSERT(hostname);
WINPR_ASSERT(code);
*code = NULL;
printf(
"Browse to: "
"https://login.microsoftonline.com/common/oauth2/v2.0/authorize?client_id=" OAUTH2_CLIENT_ID
"&response_type=code"
"&scope=ms-device-service%%3A%%2F%%2Ftermsrv.wvd.microsoft.com%%2Fname%%2F%s%%2Fuser_"
"impersonation"
"&redirect_uri=ms-appx-web%%3a%%2f%%2fMicrosoft.AAD.BrokerPlugin%%2f5177bc73-fd99-4c77-"
"a90c-76844c9b6999\n",
hostname);
printf("Paste authorization code here: ");
if (GetLine(code, &len, stdin) < 0)
return FALSE;
p = strpbrk(*code, "\r\n");
if (p)
*p = 0;
return TRUE;
}
BOOL client_auto_reconnect(freerdp* instance) BOOL client_auto_reconnect(freerdp* instance)
{ {
return client_auto_reconnect_ex(instance, NULL); return client_auto_reconnect_ex(instance, NULL);

View File

@ -3598,6 +3598,10 @@ int freerdp_client_settings_parse_command_line_arguments(rdpSettings* settings,
} }
CommandLineSwitchCase(arg, "sec") CommandLineSwitchCase(arg, "sec")
{ {
BOOL RdpSecurity = FALSE;
BOOL TlsSecurity = FALSE;
BOOL NlaSecurity = FALSE;
BOOL ExtSecurity = FALSE;
size_t count = 0, x; size_t count = 0, x;
char** ptr = CommandLineParseCommaSeparatedValues(arg->Value, &count); char** ptr = CommandLineParseCommaSeparatedValues(arg->Value, &count);
if (count == 0) if (count == 0)
@ -3628,6 +3632,8 @@ int freerdp_client_settings_parse_command_line_arguments(rdpSettings* settings,
id = FreeRDP_NlaSecurity; id = FreeRDP_NlaSecurity;
else if (option_starts_with("ext", cur)) /* NLA Extended */ else if (option_starts_with("ext", cur)) /* NLA Extended */
id = FreeRDP_ExtSecurity; id = FreeRDP_ExtSecurity;
else if (option_equals("aad", cur)) /* RDSAAD */
id = FreeRDP_AadSecurity;
else else
{ {
WLog_ERR(TAG, "unknown protocol security: %s", arg->Value); WLog_ERR(TAG, "unknown protocol security: %s", arg->Value);
@ -3643,8 +3649,9 @@ int freerdp_client_settings_parse_command_line_arguments(rdpSettings* settings,
if (singleOptionWithoutOnOff != 0) if (singleOptionWithoutOnOff != 0)
{ {
const size_t options[] = { FreeRDP_UseRdpSecurityLayer, FreeRDP_RdpSecurity, const size_t options[] = { FreeRDP_AadSecurity, FreeRDP_UseRdpSecurityLayer,
FreeRDP_NlaSecurity, FreeRDP_TlsSecurity }; FreeRDP_RdpSecurity, FreeRDP_NlaSecurity,
FreeRDP_TlsSecurity };
for (size_t x = 0; x < ARRAYSIZE(options); x++) for (size_t x = 0; x < ARRAYSIZE(options); x++)
{ {

View File

@ -155,6 +155,8 @@ extern "C"
FREERDP_API int client_cli_logon_error_info(freerdp* instance, UINT32 data, UINT32 type); FREERDP_API int client_cli_logon_error_info(freerdp* instance, UINT32 data, UINT32 type);
FREERDP_API BOOL client_cli_get_aad_auth_code(const char* hostname, char** code);
FREERDP_API void FREERDP_API void
freerdp_client_OnChannelConnectedEventHandler(void* context, freerdp_client_OnChannelConnectedEventHandler(void* context,
const ChannelConnectedEventArgs* e); const ChannelConnectedEventArgs* e);

View File

@ -83,6 +83,7 @@ extern "C"
CONNECTION_STATE_INITIAL, CONNECTION_STATE_INITIAL,
CONNECTION_STATE_NEGO, CONNECTION_STATE_NEGO,
CONNECTION_STATE_NLA, CONNECTION_STATE_NLA,
CONNECTION_STATE_AAD,
CONNECTION_STATE_MCS_CREATE_REQUEST, CONNECTION_STATE_MCS_CREATE_REQUEST,
CONNECTION_STATE_MCS_CREATE_RESPONSE, CONNECTION_STATE_MCS_CREATE_RESPONSE,
CONNECTION_STATE_MCS_ERECT_DOMAIN, CONNECTION_STATE_MCS_ERECT_DOMAIN,
@ -138,6 +139,7 @@ extern "C"
char** domain, rdp_auth_reason reason); char** domain, rdp_auth_reason reason);
typedef BOOL (*pChooseSmartcard)(freerdp* instance, SmartcardCertInfo** cert_list, DWORD count, typedef BOOL (*pChooseSmartcard)(freerdp* instance, SmartcardCertInfo** cert_list, DWORD count,
DWORD* choice, BOOL gateway); DWORD* choice, BOOL gateway);
typedef BOOL (*pGetAadAuthCode)(const char* hostname, char** code);
/** @brief Callback used if user interaction is required to accept /** @brief Callback used if user interaction is required to accept
* an unknown certificate. * an unknown certificate.
@ -530,7 +532,10 @@ owned by rdpRdp */
Callback for choosing a smartcard for logon. Callback for choosing a smartcard for logon.
Used when multiple smartcards are available. Returns an index into a list Used when multiple smartcards are available. Returns an index into a list
of SmartcardCertInfo pointers */ of SmartcardCertInfo pointers */
UINT64 paddingE[80 - 71]; /* 71 */ ALIGN64 pGetAadAuthCode GetAadAuthCode; /* (offset 71)
Callback for obtaining an oauth2 authorization
code for RDS AAD authentication */
UINT64 paddingE[80 - 72]; /* 71 */
}; };
struct rdp_channel_handles struct rdp_channel_handles

View File

@ -632,6 +632,7 @@ typedef struct
#define FreeRDP_TlsSecretsFile (1109) #define FreeRDP_TlsSecretsFile (1109)
#define FreeRDP_AuthenticationPackageList (1110) #define FreeRDP_AuthenticationPackageList (1110)
#define FreeRDP_RdstlsSecurity (1111) #define FreeRDP_RdstlsSecurity (1111)
#define FreeRDP_AadSecurity (1112)
#define FreeRDP_MstscCookieMode (1152) #define FreeRDP_MstscCookieMode (1152)
#define FreeRDP_CookieMaxLength (1153) #define FreeRDP_CookieMaxLength (1153)
#define FreeRDP_PreconnectionId (1154) #define FreeRDP_PreconnectionId (1154)
@ -1144,7 +1145,8 @@ struct rdp_settings
ALIGN64 char* TlsSecretsFile; /* 1109 */ ALIGN64 char* TlsSecretsFile; /* 1109 */
ALIGN64 char* AuthenticationPackageList; /* 1110 */ ALIGN64 char* AuthenticationPackageList; /* 1110 */
ALIGN64 BOOL RdstlsSecurity; /* 1111 */ ALIGN64 BOOL RdstlsSecurity; /* 1111 */
UINT64 padding1152[1152 - 1112]; /* 1112 */ ALIGN64 BOOL AadSecurity; /* 1112 */
UINT64 padding1152[1152 - 1113]; /* 1113 */
/* Connection Cookie */ /* Connection Cookie */
ALIGN64 BOOL MstscCookieMode; /* 1152 */ ALIGN64 BOOL MstscCookieMode; /* 1152 */

View File

@ -57,6 +57,9 @@ BOOL freerdp_settings_get_bool(const rdpSettings* settings, size_t id)
switch (id) switch (id)
{ {
case FreeRDP_AadSecurity:
return settings->AadSecurity;
case FreeRDP_AllowCacheWaitingList: case FreeRDP_AllowCacheWaitingList:
return settings->AllowCacheWaitingList; return settings->AllowCacheWaitingList;
@ -614,6 +617,10 @@ BOOL freerdp_settings_set_bool(rdpSettings* settings, size_t id, BOOL val)
switch (id) switch (id)
{ {
case FreeRDP_AadSecurity:
settings->AadSecurity = cnv.c;
break;
case FreeRDP_AllowCacheWaitingList: case FreeRDP_AllowCacheWaitingList:
settings->AllowCacheWaitingList = cnv.c; settings->AllowCacheWaitingList = cnv.c;
break; break;

View File

@ -27,6 +27,7 @@ struct settings_str_entry
const char* str; const char* str;
}; };
static const struct settings_str_entry settings_map[] = { static const struct settings_str_entry settings_map[] = {
{ FreeRDP_AadSecurity, FREERDP_SETTINGS_TYPE_BOOL, "FreeRDP_AadSecurity" },
{ FreeRDP_AllowCacheWaitingList, FREERDP_SETTINGS_TYPE_BOOL, "FreeRDP_AllowCacheWaitingList" }, { FreeRDP_AllowCacheWaitingList, FREERDP_SETTINGS_TYPE_BOOL, "FreeRDP_AllowCacheWaitingList" },
{ FreeRDP_AllowDesktopComposition, FREERDP_SETTINGS_TYPE_BOOL, { FreeRDP_AllowDesktopComposition, FREERDP_SETTINGS_TYPE_BOOL,
"FreeRDP_AllowDesktopComposition" }, "FreeRDP_AllowDesktopComposition" },

View File

@ -134,7 +134,10 @@ set(${MODULE_PREFIX}_SRCS
credssp_auth.c credssp_auth.c
credssp_auth.h credssp_auth.h
rdstls.c rdstls.c
rdstls.h) rdstls.h
aad.c
aad.h
)
set(${MODULE_PREFIX}_SRCS ${${MODULE_PREFIX}_SRCS} ${${MODULE_PREFIX}_GATEWAY_SRCS}) set(${MODULE_PREFIX}_SRCS ${${MODULE_PREFIX}_SRCS} ${${MODULE_PREFIX}_GATEWAY_SRCS})

623
libfreerdp/core/aad.c Normal file
View File

@ -0,0 +1,623 @@
/**
* FreeRDP: A Remote Desktop Protocol Implementation
* Network Level Authentication (NLA)
*
* Copyright 2023 Isaac Klein <fifthdegree@protonmail.com>
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <freerdp/config.h>
#include <stdio.h>
#include <freerdp/crypto/crypto.h>
#include <freerdp/utils/json.h>
#include <winpr/crypto.h>
#include <openssl/ssl.h>
#include <openssl/bio.h>
#include <openssl/evp.h>
#include <openssl/core_names.h>
#include <openssl/err.h>
#include "transport.h"
#include "aad.h"
#define TAG FREERDP_TAG("aad")
#define LOG_ERROR_AND_GOTO(label, ...) \
do \
{ \
WLog_ERR(TAG, __VA_ARGS__); \
goto label; \
} while (0);
#define LOG_ERROR_AND_RETURN(ret, ...) \
do \
{ \
WLog_ERR(TAG, __VA_ARGS__); \
return ret; \
} while (0);
#define XFREE(x) \
do \
{ \
free(x); \
x = NULL; \
} while (0);
#define OAUTH2_CLIENT_ID "5177bc73-fd99-4c77-a90c-76844c9b6999"
static const char* auth_server = "login.microsoftonline.com";
static const char nonce_http_request[] = ""
"POST /common/oauth2/token HTTP/1.1\r\n"
"Host: login.microsoftonline.com\r\n"
"Content-Type: application/x-www-form-urlencoded\r\n"
"Content-Length: 24\r\n"
"\r\n"
"grant_type=srv_challenge"
"\r\n\r\n";
static const char token_http_request_header[] =
""
"POST /common/oauth2/v2.0/token HTTP/1.1\r\n"
"Host: login.microsoftonline.com\r\n"
"Content-Type: application/x-www-form-urlencoded\r\n"
"Content-Length: %lu\r\n"
"\r\n";
static const char token_http_request_body[] =
""
"client_id=" OAUTH2_CLIENT_ID "&grant_type=authorization_code"
"&code=%s"
"&scope=ms-device-service%%3A%%2F%%2Ftermsrv.wvd.microsoft.com%%2Fname%%2F%s%%2Fuser_"
"impersonation"
"&req_cnf=%s"
"&redirect_uri=ms-appx-web%%3a%%2f%%2fMicrosoft.AAD.BrokerPlugin%%2f5177bc73-fd99-4c77-a90c-"
"76844c9b6999"
"\r\n\r\n";
struct rdp_aad
{
enum AAD_STATE state;
rdpContext* rdpcontext;
rdpTransport* transport;
char* access_token;
EVP_PKEY* pop_key;
char* kid;
char* nonce;
char* hostname;
};
static int alloc_sprintf(char** s, const char* template, ...);
static BOOL get_encoded_rsa_params(EVP_PKEY* pkey, char** e, char** n);
static BOOL generate_pop_key(rdpAad* aad);
static BOOL read_http_message(BIO* bio, long* status_code, char** content, size_t* content_length);
static int print_error(const char* str, size_t len, void* u)
{
WLog_ERR(TAG, "%s", str);
return 1;
}
rdpAad* aad_new(rdpContext* context, rdpTransport* transport)
{
WINPR_ASSERT(transport);
WINPR_ASSERT(context);
rdpAad* aad = (rdpAad*)calloc(1, sizeof(rdpAad));
if (!aad)
return NULL;
aad->rdpcontext = context;
aad->transport = transport;
return aad;
}
int aad_client_begin(rdpAad* aad)
{
int ret = -1;
SSL_CTX* ssl_ctx = NULL;
BIO* bio = NULL;
char* auth_code = NULL;
char *buffer = NULL, *req_header = NULL, *req_body = NULL;
size_t length = 0;
const char* hostname = NULL;
char* p = NULL;
const char* token = NULL;
long status_code;
JSON* json = NULL;
JSON* prop = NULL;
WINPR_ASSERT(aad);
WINPR_ASSERT(aad->rdpcontext);
rdpSettings* settings = aad->rdpcontext->settings;
WINPR_ASSERT(settings);
freerdp* instance = aad->rdpcontext->instance;
WINPR_ASSERT(instance);
/* Get the host part of the hostname */
hostname = freerdp_settings_get_string(settings, FreeRDP_ServerHostname);
if (!hostname || !(aad->hostname = _strdup(hostname)))
LOG_ERROR_AND_GOTO(fail, "Unable to get hostname");
if ((p = strchr(aad->hostname, '.')))
*p = '\0';
if (!generate_pop_key(aad))
LOG_ERROR_AND_GOTO(fail, "Unable to generate pop key");
/* Obtain an oauth authorization code */
if (!instance->GetAadAuthCode || !instance->GetAadAuthCode(aad->hostname, &auth_code))
LOG_ERROR_AND_GOTO(fail, "Unable to obtain authorization code");
/* Set up an ssl connection to the authorization server */
if (!(ssl_ctx = SSL_CTX_new(TLS_client_method())))
LOG_ERROR_AND_GOTO(fail, "Error setting up SSL context");
SSL_CTX_set_default_verify_paths(ssl_ctx);
SSL_CTX_set_mode(ssl_ctx, SSL_MODE_AUTO_RETRY);
if (!(bio = BIO_new_ssl_connect(ssl_ctx)))
LOG_ERROR_AND_GOTO(fail, "Error setting up connection");
BIO_set_conn_hostname(bio, auth_server);
BIO_set_conn_port(bio, "https");
/* Construct and send the token request message */
length = alloc_sprintf(&req_body, token_http_request_body, auth_code, aad->hostname, aad->kid);
if (length < 0)
goto fail;
if (alloc_sprintf(&req_header, token_http_request_header, length) < 0)
goto fail;
WLog_DBG(TAG, "HTTP access token request: %s%s", req_header, req_body);
ERR_clear_error();
if (BIO_write(bio, req_header, strlen(req_header)) < 0)
{
ERR_print_errors_cb(print_error, NULL);
goto fail;
}
ERR_clear_error();
if (BIO_write(bio, req_body, strlen(req_body)) < 0)
{
ERR_print_errors_cb(print_error, NULL);
goto fail;
}
/* Read in the response */
if (!read_http_message(bio, &status_code, &buffer, &length))
LOG_ERROR_AND_GOTO(fail, "Unable to read access token HTTP response");
WLog_DBG(TAG, "HTTP access token response: %s", buffer);
if (status_code != 200)
LOG_ERROR_AND_GOTO(fail, "Received status code: %li", status_code);
/* Extract the access token from the JSON response */
if (!(json = json_parse(buffer)))
LOG_ERROR_AND_GOTO(fail, "Failed to parse JSON response");
if (!json_object_get_prop(json, "access_token", &prop) || !json_string_get(prop, &token))
LOG_ERROR_AND_GOTO(fail, "Could not find \"access_token\" property in JSON response");
if (!(aad->access_token = _strdup(token)))
goto fail;
XFREE(buffer);
json_free(json);
json = NULL;
/* Send the nonce request message */
WLog_DBG(TAG, "HTTP nonce request: %s", nonce_http_request);
ERR_clear_error();
if (BIO_write(bio, nonce_http_request, strlen(nonce_http_request)) < 0)
{
ERR_print_errors_cb(print_error, NULL);
goto fail;
}
/* Read in the response */
if (!read_http_message(bio, &status_code, &buffer, &length))
LOG_ERROR_AND_GOTO(fail, "Unable to read HTTP response");
WLog_DBG(TAG, "HTTP nonce response: %s", buffer);
if (status_code != 200)
LOG_ERROR_AND_GOTO(fail, "Received status code: %li", status_code);
/* Extract the nonce from the response */
if (!(json = json_parse(buffer)))
LOG_ERROR_AND_GOTO(fail, "Failed to parse JSON response");
if (!json_object_get_prop(json, "Nonce", &prop) || !json_string_get(prop, &token))
LOG_ERROR_AND_GOTO(fail, "Could not find \"Nonce\" property in JSON response");
if (!(aad->nonce = _strdup(token)))
goto fail;
ret = 1;
fail:
json_free(json);
free(buffer);
free(req_body);
free(req_header);
BIO_free_all(bio);
SSL_CTX_free(ssl_ctx);
free(auth_code);
return ret;
}
static int aad_send_auth_request(rdpAad* aad, const char* ts_nonce)
{
int ret = -1;
char* jws_header = NULL;
char* jws_payload = NULL;
char* jws_signature = NULL;
char* buffer = NULL;
wStream* s = NULL;
time_t ts = time(NULL);
char *e = NULL, *n = NULL;
size_t length = 0;
EVP_MD_CTX* md_ctx = NULL;
WINPR_ASSERT(aad);
WINPR_ASSERT(ts_nonce);
/* Construct the base64url encoded JWS header */
if ((length = alloc_sprintf(&buffer, "{\"alg\":\"RS256\",\"kid\":\"%s\"}", aad->kid)) < 0)
goto fail;
if (!(jws_header = crypto_base64url_encode((BYTE*)buffer, strlen(buffer))))
goto fail;
XFREE(buffer);
if (!get_encoded_rsa_params(aad->pop_key, &e, &n))
LOG_ERROR_AND_GOTO(fail, "Error getting RSA key params");
/* Construct the base64url encoded JWS payload */
length = alloc_sprintf(&buffer,
"{"
"\"ts\":\"%li\","
"\"at\":\"%s\","
"\"u\":\"ms-device-service://termsrv.wvd.microsoft.com/name/%s\","
"\"nonce\":\"%s\","
"\"cnf\":{\"jwk\":{\"kty\":\"RSA\",\"e\":\"%s\",\"n\":\"%s\"}},"
"\"client_claims\":\"{\\\"aad_nonce\\\":\\\"%s\\\"}\""
"}",
ts, aad->access_token, aad->hostname, ts_nonce, e, n, aad->nonce);
if (length < 0)
goto fail;
if (!(jws_payload = crypto_base64url_encode((BYTE*)buffer, strlen(buffer))))
goto fail;
XFREE(buffer);
/* Sign the JWS with the pop key */
if (!(md_ctx = EVP_MD_CTX_new()))
goto fail;
if (!(EVP_DigestSignInit(md_ctx, NULL, EVP_sha256(), NULL, aad->pop_key)))
LOG_ERROR_AND_GOTO(fail, "Error while initializing signature context");
if (!(EVP_DigestSignUpdate(md_ctx, jws_header, strlen(jws_header))))
LOG_ERROR_AND_GOTO(fail, "Error while signing data");
if (!(EVP_DigestSignUpdate(md_ctx, ".", 1)))
LOG_ERROR_AND_GOTO(fail, "Error while signing data");
if (!(EVP_DigestSignUpdate(md_ctx, jws_payload, strlen(jws_payload))))
LOG_ERROR_AND_GOTO(fail, "Error while signing data");
if (!(EVP_DigestSignFinal(md_ctx, NULL, &length)))
LOG_ERROR_AND_GOTO(fail, "Error while signing data");
if (!(buffer = malloc(length)))
goto fail;
if (!(EVP_DigestSignFinal(md_ctx, (BYTE*)buffer, &length)))
LOG_ERROR_AND_GOTO(fail, "Error while signing data");
if (!(jws_signature = crypto_base64url_encode((BYTE*)buffer, length)))
goto fail;
/* Construct the Authentication Request PDU with the JWS as the RDP Assertion */
length = _snprintf(NULL, 0, "{\"rdp_assertion\":\"%s.%s.%s\"}", jws_header, jws_payload,
jws_signature) +
1;
if (length < 0)
goto fail;
if (!(s = Stream_New(NULL, length)))
goto fail;
_snprintf(Stream_PointerAs(s, char), length, "{\"rdp_assertion\":\"%s.%s.%s\"}", jws_header,
jws_payload, jws_signature);
Stream_Seek(s, length);
if (transport_write(aad->transport, s) < 0)
LOG_ERROR_AND_GOTO(fail, "Failed to send Authentication Request PDU");
ret = 1;
aad->state = AAD_STATE_AUTH;
fail:
Stream_Free(s, TRUE);
free(e);
free(n);
free(buffer);
free(jws_header);
free(jws_payload);
free(jws_signature);
EVP_MD_CTX_free(md_ctx);
return ret;
}
int aad_recv(rdpAad* aad, wStream* s)
{
JSON* json;
JSON* prop;
WINPR_ASSERT(aad);
WINPR_ASSERT(s);
if (aad->state == AAD_STATE_INITIAL)
{
const char* ts_nonce = NULL;
int ret = 0;
if (!(json = json_parse(Stream_PointerAs(s, char))))
LOG_ERROR_AND_RETURN(-1, "Failed to parse Server Nonce PDU");
if (!json_object_get_prop(json, "ts_nonce", &prop) || !json_string_get(prop, &ts_nonce))
{
json_free(json);
WLog_ERR(TAG, "Failed to find ts_nonce in PDU");
return -1;
}
Stream_Seek(s, Stream_Length(s));
ret = aad_send_auth_request(aad, ts_nonce);
json_free(json);
return ret;
}
else if (aad->state == AAD_STATE_AUTH)
{
double result = 0;
if (!(json = json_parse(Stream_PointerAs(s, char))))
LOG_ERROR_AND_RETURN(-1, "Failed to parse Authentication Result PDU");
if (!json_object_get_prop(json, "authentication_result", &prop) ||
!json_number_get(prop, &result))
{
json_free(json);
WLog_ERR(TAG, "Failed to find authentication_result in PDU");
return -1;
}
json_free(json);
Stream_Seek(s, Stream_Length(s));
if (result != 0)
LOG_ERROR_AND_RETURN(-1, "Authentication result: %d", (int)result);
aad->state = AAD_STATE_FINAL;
return 1;
}
else
LOG_ERROR_AND_RETURN(-1, "Invalid state");
}
enum AAD_STATE aad_get_state(rdpAad* aad)
{
if (!aad)
return AAD_STATE_FINAL;
return aad->state;
}
void aad_free(rdpAad* aad)
{
if (!aad)
return;
free(aad->hostname);
free(aad->nonce);
free(aad->access_token);
free(aad->kid);
EVP_PKEY_free(aad->pop_key);
free(aad);
}
static BOOL read_http_message(BIO* bio, long* status_code, char** content, size_t* content_length)
{
char buffer[1024] = { 0 };
WINPR_ASSERT(status_code);
WINPR_ASSERT(content);
WINPR_ASSERT(content_length);
if (BIO_get_line(bio, buffer, sizeof(buffer)) <= 0)
LOG_ERROR_AND_RETURN(FALSE, "Error reading HTTP response");
if (sscanf(buffer, "HTTP/%*u.%*u %li %*[^\r\n]\r\n", status_code) < 1)
LOG_ERROR_AND_RETURN(FALSE, "Invalid HTTP response status line");
do
{
char* name = NULL;
char* val = NULL;
if (BIO_get_line(bio, buffer, sizeof(buffer)) <= 0)
LOG_ERROR_AND_RETURN(FALSE, "Error reading HTTP response");
name = strtok_r(buffer, ":", &val);
if (name && _stricmp(name, "content-length") == 0)
*content_length = strtoul(val, NULL, 10);
} while (strcmp(buffer, "\r\n") != 0);
if (*content_length == 0)
return TRUE;
if (!(*content = malloc(*content_length + 1)))
return FALSE;
(*content)[*content_length] = '\0';
if (BIO_read(bio, *content, *content_length) < *content_length)
{
free(*content);
LOG_ERROR_AND_RETURN(FALSE, "Error reading HTTP response body");
}
return TRUE;
}
static BOOL generate_pop_key(rdpAad* aad)
{
EVP_PKEY_CTX* ctx = NULL;
BOOL ret = FALSE;
size_t length = 0;
char* buffer = NULL;
char *e = NULL, *n = NULL;
WINPR_DIGEST_CTX* digest = NULL;
BYTE hash[WINPR_SHA256_DIGEST_LENGTH] = { 0 };
WINPR_ASSERT(aad);
/* Generate a 2048-bit RSA key pair */
if (!(ctx = EVP_PKEY_CTX_new_id(EVP_PKEY_RSA, NULL)))
return FALSE;
if (EVP_PKEY_keygen_init(ctx) <= 0)
LOG_ERROR_AND_GOTO(fail, "Error initializing keygen");
if (EVP_PKEY_CTX_set_rsa_keygen_bits(ctx, 2048) <= 0)
LOG_ERROR_AND_GOTO(fail, "Error setting RSA keygen bits");
if (EVP_PKEY_keygen(ctx, &aad->pop_key) <= 0)
LOG_ERROR_AND_GOTO(fail, "Error generating RSA pop token key");
/* Encode the public key as a JWK */
if (!get_encoded_rsa_params(aad->pop_key, &e, &n))
LOG_ERROR_AND_GOTO(fail, "Error getting RSA key params");
if ((length = alloc_sprintf(&buffer, "{\"e\":\"%s\",\"kty\":\"RSA\",\"n\":\"%s\"}", e, n)) < 0)
goto fail;
/* Hash the encoded public key */
if (!(digest = winpr_Digest_New()))
goto fail;
if (!winpr_Digest_Init(digest, WINPR_MD_SHA256))
LOG_ERROR_AND_GOTO(fail, "Error initializing SHA256 digest");
if (!winpr_Digest_Update(digest, (BYTE*)buffer, length))
LOG_ERROR_AND_GOTO(fail, "Unable to get hash of JWK");
if (!winpr_Digest_Final(digest, hash, WINPR_SHA256_DIGEST_LENGTH))
LOG_ERROR_AND_GOTO(fail, "Unable to get hash of JWK");
XFREE(buffer);
/* Base64url encode the hash */
if (!(buffer = crypto_base64url_encode(hash, WINPR_SHA256_DIGEST_LENGTH)))
goto fail;
/* Encode a JSON object with a single property "kid" whose value is the encoded hash */
{
char* buf2 = NULL;
if ((length = alloc_sprintf(&buf2, "{\"kid\":\"%s\"}", buffer)) < 0)
goto fail;
free(buffer);
buffer = buf2;
}
/* Finally, base64url encode the JSON text to form the kid */
if (!(aad->kid = crypto_base64url_encode((BYTE*)buffer, length)))
LOG_ERROR_AND_GOTO(fail, "Error base64url encoding kid");
ret = TRUE;
fail:
free(buffer);
free(e);
free(n);
winpr_Digest_Free(digest);
EVP_PKEY_CTX_free(ctx);
return ret;
}
static BOOL get_encoded_rsa_params(EVP_PKEY* pkey, char** e, char** n)
{
BIGNUM *bn_e = NULL, *bn_n = NULL;
BYTE buf[2048];
size_t length = 0;
*e = NULL;
*n = NULL;
#if OPENSSL_VERSION_NUMBER >= 0x30000000L
if (!EVP_PKEY_get_bn_param(pkey, OSSL_PKEY_PARAM_RSA_E, &bn_e))
goto fail;
if (!EVP_PKEY_get_bn_param(pkey, OSSL_PKEY_PARAM_RSA_N, &bn_n))
goto fail;
#else
{
const RSA* rsa = NULL;
if (!(rsa = EVP_PKEY_get0_RSA(pkey)))
goto fail;
if (!(bn_e = BN_dup(RSA_get0_e(rsa))))
goto fail;
if (!(bn_n = BN_dup(RSA_get0_n(rsa))))
goto fail;
}
#endif
length = BN_num_bytes(bn_e);
if (length > sizeof(buf))
goto fail;
if (BN_bn2bin(bn_e, buf) < length)
goto fail;
*e = crypto_base64url_encode(buf, length);
length = BN_num_bytes(bn_n);
if (length > sizeof(buf))
goto fail;
if (BN_bn2bin(bn_n, buf) < length)
goto fail;
*n = crypto_base64url_encode(buf, length);
fail:
BN_free(bn_e);
BN_free(bn_n);
if (!(*e) || !(*n))
{
free(*e);
free(*n);
return FALSE;
}
return TRUE;
}
static int alloc_sprintf(char** s, const char* template, ...)
{
int length;
va_list ap;
WINPR_ASSERT(s);
*s = NULL;
va_start(ap, template);
length = vsnprintf(NULL, 0, template, ap);
va_end(ap);
if (!(*s = malloc(length + 1)))
return -1;
va_start(ap, template);
vsprintf(*s, template, ap);
va_end(ap);
return length;
}

43
libfreerdp/core/aad.h Normal file
View File

@ -0,0 +1,43 @@
/**
* FreeRDP: A Remote Desktop Protocol Implementation
* Network Level Authentication (NLA)
*
* Copyright 2023 Isaac Klein <fifthdegree@protonmail.com>
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef FREERDP_LIB_CORE_AAD_H
#define FREERDP_LIB_CORE_AAD_H
typedef struct rdp_aad rdpAad;
enum AAD_STATE
{
AAD_STATE_INITIAL,
AAD_STATE_AUTH,
AAD_STATE_FINAL
};
#include <freerdp/api.h>
#include <freerdp/freerdp.h>
FREERDP_LOCAL int aad_client_begin(rdpAad* aad);
FREERDP_LOCAL int aad_recv(rdpAad* aad, wStream* s);
FREERDP_LOCAL enum AAD_STATE aad_get_state(rdpAad* aad);
FREERDP_LOCAL rdpAad* aad_new(rdpContext* context, rdpTransport* transport);
FREERDP_LOCAL void aad_free(rdpAad* aad);
#endif /* FREERDP_LIB_CORE_AAD_H */

View File

@ -393,6 +393,7 @@ BOOL rdp_client_connect(rdpRdp* rdp)
nego_enable_nla(rdp->nego, settings->NlaSecurity); nego_enable_nla(rdp->nego, settings->NlaSecurity);
nego_enable_ext(rdp->nego, settings->ExtSecurity); nego_enable_ext(rdp->nego, settings->ExtSecurity);
nego_enable_rdstls(rdp->nego, settings->RdstlsSecurity); nego_enable_rdstls(rdp->nego, settings->RdstlsSecurity);
nego_enable_aad(rdp->nego, settings->AadSecurity);
if (settings->MstscCookieMode) if (settings->MstscCookieMode)
settings->CookieMaxLength = MSTSC_COOKIE_MAX_LENGTH; settings->CookieMaxLength = MSTSC_COOKIE_MAX_LENGTH;
@ -1766,6 +1767,8 @@ const char* rdp_state_string(CONNECTION_STATE state)
return "CONNECTION_STATE_NEGO"; return "CONNECTION_STATE_NEGO";
case CONNECTION_STATE_NLA: case CONNECTION_STATE_NLA:
return "CONNECTION_STATE_NLA"; return "CONNECTION_STATE_NLA";
case CONNECTION_STATE_AAD:
return "CONNECTION_STATE_AAD";
case CONNECTION_STATE_MCS_CREATE_REQUEST: case CONNECTION_STATE_MCS_CREATE_REQUEST:
return "CONNECTION_STATE_MCS_CREATE_REQUEST"; return "CONNECTION_STATE_MCS_CREATE_REQUEST";
case CONNECTION_STATE_MCS_CREATE_RESPONSE: case CONNECTION_STATE_MCS_CREATE_RESPONSE:

View File

@ -57,7 +57,7 @@ struct rdp_nego
UINT32 SelectedProtocol; UINT32 SelectedProtocol;
UINT32 RequestedProtocols; UINT32 RequestedProtocols;
BOOL NegotiateSecurityLayer; BOOL NegotiateSecurityLayer;
BOOL EnabledProtocols[16]; BOOL EnabledProtocols[32];
BOOL RestrictedAdminModeRequired; BOOL RestrictedAdminModeRequired;
BOOL GatewayEnabled; BOOL GatewayEnabled;
BOOL GatewayBypassLocal; BOOL GatewayBypassLocal;
@ -68,10 +68,10 @@ struct rdp_nego
static const char* nego_state_string(NEGO_STATE state) static const char* nego_state_string(NEGO_STATE state)
{ {
static const char* const NEGO_STATE_STRINGS[] = { "NEGO_STATE_INITIAL", "NEGO_STATE_RDSTLS", static const char* const NEGO_STATE_STRINGS[] = { "NEGO_STATE_INITIAL", "NEGO_STATE_RDSTLS",
"NEGO_STATE_EXT", "NEGO_STATE_NLA", "NEGO_STATE_AAD", "NEGO_STATE_EXT",
"NEGO_STATE_TLS", "NEGO_STATE_RDP", "NEGO_STATE_NLA", "NEGO_STATE_TLS",
"NEGO_STATE_FAIL", "NEGO_STATE_FINAL", "NEGO_STATE_RDP", "NEGO_STATE_FAIL",
"NEGO_STATE_INVALID" }; "NEGO_STATE_FINAL", "NEGO_STATE_INVALID" };
if (state >= ARRAYSIZE(NEGO_STATE_STRINGS)) if (state >= ARRAYSIZE(NEGO_STATE_STRINGS))
return NEGO_STATE_STRINGS[ARRAYSIZE(NEGO_STATE_STRINGS) - 1]; return NEGO_STATE_STRINGS[ARRAYSIZE(NEGO_STATE_STRINGS) - 1];
return NEGO_STATE_STRINGS[state]; return NEGO_STATE_STRINGS[state];
@ -80,7 +80,9 @@ static const char* nego_state_string(NEGO_STATE state)
static const char* protocol_security_string(UINT32 security) static const char* protocol_security_string(UINT32 security)
{ {
static const char* PROTOCOL_SECURITY_STRINGS[] = { "RDP", "TLS", "NLA", "UNK", "RDSTLS", static const char* PROTOCOL_SECURITY_STRINGS[] = { "RDP", "TLS", "NLA", "UNK", "RDSTLS",
"UNK", "UNK", "UNK", "EXT", "UNK" }; "UNK", "UNK", "UNK", "EXT", "UNK",
"UNK", "UNK", "UNK", "UNK", "UNK",
"UNK", "AAD", "UNK", "UNK", "UNK" };
if (security >= ARRAYSIZE(PROTOCOL_SECURITY_STRINGS)) if (security >= ARRAYSIZE(PROTOCOL_SECURITY_STRINGS))
return PROTOCOL_SECURITY_STRINGS[ARRAYSIZE(PROTOCOL_SECURITY_STRINGS) - 1]; return PROTOCOL_SECURITY_STRINGS[ARRAYSIZE(PROTOCOL_SECURITY_STRINGS) - 1];
return PROTOCOL_SECURITY_STRINGS[security]; return PROTOCOL_SECURITY_STRINGS[security];
@ -116,7 +118,11 @@ BOOL nego_connect(rdpNego* nego)
if (nego_get_state(nego) == NEGO_STATE_INITIAL) if (nego_get_state(nego) == NEGO_STATE_INITIAL)
{ {
if (nego->EnabledProtocols[PROTOCOL_RDSTLS]) if (nego->EnabledProtocols[PROTOCOL_RDSAAD])
{
nego_set_state(nego, NEGO_STATE_AAD);
}
else if (nego->EnabledProtocols[PROTOCOL_RDSTLS])
{ {
nego_set_state(nego, NEGO_STATE_RDSTLS); nego_set_state(nego, NEGO_STATE_RDSTLS);
} }
@ -147,6 +153,7 @@ BOOL nego_connect(rdpNego* nego)
{ {
WLog_DBG(TAG, "Security Layer Negotiation is disabled"); WLog_DBG(TAG, "Security Layer Negotiation is disabled");
/* attempt only the highest enabled protocol (see nego_attempt_*) */ /* attempt only the highest enabled protocol (see nego_attempt_*) */
nego->EnabledProtocols[PROTOCOL_RDSAAD] = FALSE;
nego->EnabledProtocols[PROTOCOL_HYBRID] = FALSE; nego->EnabledProtocols[PROTOCOL_HYBRID] = FALSE;
nego->EnabledProtocols[PROTOCOL_SSL] = FALSE; nego->EnabledProtocols[PROTOCOL_SSL] = FALSE;
nego->EnabledProtocols[PROTOCOL_RDP] = FALSE; nego->EnabledProtocols[PROTOCOL_RDP] = FALSE;
@ -155,6 +162,10 @@ BOOL nego_connect(rdpNego* nego)
switch (nego_get_state(nego)) switch (nego_get_state(nego))
{ {
case NEGO_STATE_AAD:
nego->EnabledProtocols[PROTOCOL_RDSAAD] = TRUE;
nego->SelectedProtocol = PROTOCOL_RDSAAD;
break;
case NEGO_STATE_RDSTLS: case NEGO_STATE_RDSTLS:
nego->EnabledProtocols[PROTOCOL_RDSTLS] = TRUE; nego->EnabledProtocols[PROTOCOL_RDSTLS] = TRUE;
nego->SelectedProtocol = PROTOCOL_RDSTLS; nego->SelectedProtocol = PROTOCOL_RDSTLS;
@ -272,7 +283,12 @@ BOOL nego_security_connect(rdpNego* nego)
} }
else if (!nego->SecurityConnected) else if (!nego->SecurityConnected)
{ {
if (nego->SelectedProtocol == PROTOCOL_RDSTLS) if (nego->SelectedProtocol == PROTOCOL_RDSAAD)
{
WLog_DBG(TAG, "nego_security_connect with PROTOCOL_RDSAAD");
nego->SecurityConnected = transport_connect_aad(nego->transport);
}
else if (nego->SelectedProtocol == PROTOCOL_RDSTLS)
{ {
WLog_DBG(TAG, "nego_security_connect with PROTOCOL_RDSTLS"); WLog_DBG(TAG, "nego_security_connect with PROTOCOL_RDSTLS");
nego->SecurityConnected = transport_connect_rdstls(nego->transport); nego->SecurityConnected = transport_connect_rdstls(nego->transport);
@ -500,6 +516,49 @@ static void nego_attempt_rdstls(rdpNego* nego)
} }
} }
static void nego_attempt_rdsaad(rdpNego* nego)
{
WINPR_ASSERT(nego);
nego->RequestedProtocols = PROTOCOL_RDSAAD;
WLog_DBG(TAG, "Attempting RDS AAD Auth security");
if (!nego_transport_connect(nego))
{
nego_set_state(nego, NEGO_STATE_FAIL);
return;
}
if (!nego_send_negotiation_request(nego))
{
nego_set_state(nego, NEGO_STATE_FAIL);
return;
}
if (!nego_recv_response(nego))
{
nego_set_state(nego, NEGO_STATE_FAIL);
return;
}
WLog_DBG(TAG, "state: %s", nego_state_string(nego_get_state(nego)));
if (nego_get_state(nego) != NEGO_STATE_FINAL)
{
nego_transport_disconnect(nego);
if (nego->EnabledProtocols[PROTOCOL_HYBRID_EX])
nego_set_state(nego, NEGO_STATE_EXT);
else if (nego->EnabledProtocols[PROTOCOL_HYBRID])
nego_set_state(nego, NEGO_STATE_NLA);
else if (nego->EnabledProtocols[PROTOCOL_SSL])
nego_set_state(nego, NEGO_STATE_TLS);
else if (nego->EnabledProtocols[PROTOCOL_RDP])
nego_set_state(nego, NEGO_STATE_RDP);
else
nego_set_state(nego, NEGO_STATE_FAIL);
}
}
static void nego_attempt_ext(rdpNego* nego) static void nego_attempt_ext(rdpNego* nego)
{ {
WINPR_ASSERT(nego); WINPR_ASSERT(nego);
@ -720,6 +779,11 @@ int nego_recv(rdpTransport* transport, wStream* s, void* extra)
if (nego->SelectedProtocol) if (nego->SelectedProtocol)
{ {
if ((nego->SelectedProtocol == PROTOCOL_RDSAAD) &&
(!nego->EnabledProtocols[PROTOCOL_RDSAAD]))
{
nego_set_state(nego, NEGO_STATE_FAIL);
}
if ((nego->SelectedProtocol == PROTOCOL_HYBRID) && if ((nego->SelectedProtocol == PROTOCOL_HYBRID) &&
(!nego->EnabledProtocols[PROTOCOL_HYBRID])) (!nego->EnabledProtocols[PROTOCOL_HYBRID]))
{ {
@ -921,6 +985,9 @@ void nego_send(rdpNego* nego)
switch (nego_get_state(nego)) switch (nego_get_state(nego))
{ {
case NEGO_STATE_AAD:
nego_attempt_rdsaad(nego);
break;
case NEGO_STATE_RDSTLS: case NEGO_STATE_RDSTLS:
nego_attempt_rdstls(nego); nego_attempt_rdstls(nego);
break; break;
@ -1620,6 +1687,19 @@ void nego_enable_ext(rdpNego* nego, BOOL enable_ext)
nego->EnabledProtocols[PROTOCOL_HYBRID_EX] = enable_ext; nego->EnabledProtocols[PROTOCOL_HYBRID_EX] = enable_ext;
} }
/**
* Enable RDS AAD security protocol.
* @param nego A pointer to the NEGO struct pointer to the negotiation structure
* @param enable_ext whether to enable RDS AAD Auth protocol (TRUE for
* enabled, FALSE for disabled)
*/
void nego_enable_aad(rdpNego* nego, BOOL enable_aad)
{
WLog_DBG(TAG, "Enabling RDS AAD security: %s", enable_aad ? "TRUE" : "FALSE");
nego->EnabledProtocols[PROTOCOL_RDSAAD] = enable_aad;
}
/** /**
* Set routing token. * Set routing token.
* @param nego A pointer to the NEGO struct * @param nego A pointer to the NEGO struct

View File

@ -37,6 +37,7 @@
#define PROTOCOL_HYBRID 0x00000002 #define PROTOCOL_HYBRID 0x00000002
#define PROTOCOL_RDSTLS 0x00000004 #define PROTOCOL_RDSTLS 0x00000004
#define PROTOCOL_HYBRID_EX 0x00000008 #define PROTOCOL_HYBRID_EX 0x00000008
#define PROTOCOL_RDSAAD 0x00000010
#define PROTOCOL_FAILED_NEGO 0x80000000 /* only used internally, not on the wire */ #define PROTOCOL_FAILED_NEGO 0x80000000 /* only used internally, not on the wire */
@ -59,6 +60,7 @@ typedef enum
{ {
NEGO_STATE_INITIAL, NEGO_STATE_INITIAL,
NEGO_STATE_RDSTLS, /* RDSTLS (TLS implicit) */ NEGO_STATE_RDSTLS, /* RDSTLS (TLS implicit) */
NEGO_STATE_AAD, /* Azure AD Authentication (TLS implicit) */
NEGO_STATE_EXT, /* Extended NLA (NLA + TLS implicit) */ NEGO_STATE_EXT, /* Extended NLA (NLA + TLS implicit) */
NEGO_STATE_NLA, /* Network Level Authentication (TLS implicit) */ NEGO_STATE_NLA, /* Network Level Authentication (TLS implicit) */
NEGO_STATE_TLS, /* TLS Encryption without NLA */ NEGO_STATE_TLS, /* TLS Encryption without NLA */
@ -122,6 +124,7 @@ FREERDP_LOCAL void nego_enable_rdp(rdpNego* nego, BOOL enable_rdp);
FREERDP_LOCAL void nego_enable_tls(rdpNego* nego, BOOL enable_tls); FREERDP_LOCAL void nego_enable_tls(rdpNego* nego, BOOL enable_tls);
FREERDP_LOCAL void nego_enable_nla(rdpNego* nego, BOOL enable_nla); FREERDP_LOCAL void nego_enable_nla(rdpNego* nego, BOOL enable_nla);
FREERDP_LOCAL void nego_enable_rdstls(rdpNego* nego, BOOL enable_rdstls); FREERDP_LOCAL void nego_enable_rdstls(rdpNego* nego, BOOL enable_rdstls);
FREERDP_LOCAL void nego_enable_aad(rdpNego* nego, BOOL enable_aad);
FREERDP_LOCAL void nego_enable_ext(rdpNego* nego, BOOL enable_ext); FREERDP_LOCAL void nego_enable_ext(rdpNego* nego, BOOL enable_ext);
FREERDP_LOCAL const BYTE* nego_get_routing_token(rdpNego* nego, DWORD* RoutingTokenLength); FREERDP_LOCAL const BYTE* nego_get_routing_token(rdpNego* nego, DWORD* RoutingTokenLength);
FREERDP_LOCAL BOOL nego_set_routing_token(rdpNego* nego, const BYTE* RoutingToken, FREERDP_LOCAL BOOL nego_set_routing_token(rdpNego* nego, const BYTE* RoutingToken,

View File

@ -1783,6 +1783,25 @@ static state_run_t rdp_recv_callback_int(rdpTransport* transport, wStream* s, vo
} }
break; break;
case CONNECTION_STATE_AAD:
if (aad_recv(rdp->aad, s) < 1)
{
WLog_ERR(TAG, "%s - aad_recv() fail", rdp_get_state_string(rdp));
status = STATE_RUN_FAILED;
}
if (state_run_success(status))
{
if (aad_get_state(rdp->aad) == AAD_STATE_FINAL)
{
transport_set_aad_mode(rdp->transport, FALSE);
if (!rdp_client_transition_to_state(rdp, CONNECTION_STATE_MCS_CREATE_REQUEST))
status = STATE_RUN_FAILED;
else
status = STATE_RUN_CONTINUE;
}
}
break;
case CONNECTION_STATE_MCS_CREATE_REQUEST: case CONNECTION_STATE_MCS_CREATE_REQUEST:
if (!mcs_client_begin(rdp->mcs)) if (!mcs_client_begin(rdp->mcs))
{ {
@ -2335,6 +2354,7 @@ void rdp_free(rdpRdp* rdp)
PubSub_Free(rdp->pubSub); PubSub_Free(rdp->pubSub);
if (rdp->abortEvent) if (rdp->abortEvent)
CloseHandle(rdp->abortEvent); CloseHandle(rdp->abortEvent);
aad_free(rdp->aad);
free(rdp); free(rdp);
} }
} }

View File

@ -24,6 +24,7 @@
#include <freerdp/config.h> #include <freerdp/config.h>
#include "nla.h" #include "nla.h"
#include "aad.h"
#include "mcs.h" #include "mcs.h"
#include "tpkt.h" #include "tpkt.h"
#include "../codec/bulk.h" #include "../codec/bulk.h"
@ -146,6 +147,7 @@ struct rdp_rdp
CONNECTION_STATE state; CONNECTION_STATE state;
rdpContext* context; rdpContext* context;
rdpNla* nla; rdpNla* nla;
rdpAad* aad;
rdpMcs* mcs; rdpMcs* mcs;
rdpNego* nego; rdpNego* nego;
rdpBulk* bulk; rdpBulk* bulk;

View File

@ -384,6 +384,7 @@ rdpSettings* freerdp_settings_new(DWORD flags)
!freerdp_settings_set_bool(settings, FreeRDP_Decorations, TRUE) || !freerdp_settings_set_bool(settings, FreeRDP_Decorations, TRUE) ||
!freerdp_settings_set_uint32(settings, FreeRDP_RdpVersion, RDP_VERSION_10_11) || !freerdp_settings_set_uint32(settings, FreeRDP_RdpVersion, RDP_VERSION_10_11) ||
!freerdp_settings_set_uint32(settings, FreeRDP_ColorDepth, 16) || !freerdp_settings_set_uint32(settings, FreeRDP_ColorDepth, 16) ||
!freerdp_settings_set_bool(settings, FreeRDP_AadSecurity, FALSE) ||
!freerdp_settings_set_bool(settings, FreeRDP_ExtSecurity, FALSE) || !freerdp_settings_set_bool(settings, FreeRDP_ExtSecurity, FALSE) ||
!freerdp_settings_set_bool(settings, FreeRDP_NlaSecurity, TRUE) || !freerdp_settings_set_bool(settings, FreeRDP_NlaSecurity, TRUE) ||
!freerdp_settings_set_bool(settings, FreeRDP_TlsSecurity, TRUE) || !freerdp_settings_set_bool(settings, FreeRDP_TlsSecurity, TRUE) ||

View File

@ -3,6 +3,7 @@
#define have_bool_list_indices #define have_bool_list_indices
static const size_t bool_list_indices[] = { static const size_t bool_list_indices[] = {
FreeRDP_AadSecurity,
FreeRDP_AllowCacheWaitingList, FreeRDP_AllowCacheWaitingList,
FreeRDP_AllowDesktopComposition, FreeRDP_AllowDesktopComposition,
FreeRDP_AllowFontSmoothing, FreeRDP_AllowFontSmoothing,
@ -117,6 +118,7 @@ static const size_t bool_list_indices[] = {
FreeRDP_PrintReconnectCookie, FreeRDP_PrintReconnectCookie,
FreeRDP_PromptForCredentials, FreeRDP_PromptForCredentials,
FreeRDP_RdpSecurity, FreeRDP_RdpSecurity,
FreeRDP_RdstlsSecurity,
FreeRDP_RedirectClipboard, FreeRDP_RedirectClipboard,
FreeRDP_RedirectDrives, FreeRDP_RedirectDrives,
FreeRDP_RedirectHomeDrive, FreeRDP_RedirectHomeDrive,

View File

@ -74,6 +74,7 @@ struct rdp_transport
HANDLE connectedEvent; HANDLE connectedEvent;
BOOL NlaMode; BOOL NlaMode;
BOOL RdstlsMode; BOOL RdstlsMode;
BOOL AadMode;
BOOL blocking; BOOL blocking;
BOOL GatewayEnabled; BOOL GatewayEnabled;
CRITICAL_SECTION ReadLock; CRITICAL_SECTION ReadLock;
@ -402,6 +403,50 @@ fail:
return rc; return rc;
} }
BOOL transport_connect_aad(rdpTransport* transport)
{
rdpContext* context = NULL;
rdpSettings* settings = NULL;
rdpRdp* rdp = NULL;
if (!transport)
return FALSE;
context = transport_get_context(transport);
WINPR_ASSERT(context);
settings = context->settings;
WINPR_ASSERT(settings);
rdp = context->rdp;
WINPR_ASSERT(rdp);
if (!transport_connect_tls(transport))
return FALSE;
if (!settings->Authentication)
return TRUE;
aad_free(rdp->aad);
rdp->aad = aad_new(context, transport);
if (!rdp->aad)
return FALSE;
transport_set_aad_mode(transport, TRUE);
if (aad_client_begin(rdp->aad) < 0)
{
WLog_Print(transport->log, WLOG_ERROR, "AAD begin failed");
freerdp_set_last_error_if_not(context, FREERDP_ERROR_AUTHENTICATION_FAILED);
transport_set_aad_mode(transport, FALSE);
return FALSE;
}
return rdp_client_transition_to_state(rdp, CONNECTION_STATE_AAD);
}
BOOL transport_connect(rdpTransport* transport, const char* hostname, UINT16 port, DWORD timeout) BOOL transport_connect(rdpTransport* transport, const char* hostname, UINT16 port, DWORD timeout)
{ {
int sockfd; int sockfd;
@ -938,39 +983,61 @@ static int transport_default_read_pdu(rdpTransport* transport, wStream* s)
WINPR_ASSERT(transport); WINPR_ASSERT(transport);
WINPR_ASSERT(s); WINPR_ASSERT(s);
/* Read in pdu length */ /* RDS AAD Auth PDUs have no length indicator. We need to determine the end of the PDU by
status = transport_parse_pdu(transport, s, &incomplete); * reading in one byte at a time until we encounter the terminating null byte */
while ((status == 0) && incomplete) if (transport->AadMode)
{ {
BYTE c;
int rc; int rc;
if (!Stream_EnsureRemainingCapacity(s, 1)) while (1)
return -1; {
rc = transport_read_layer_bytes(transport, s, 1); rc = transport_read_layer(transport, &c, 1);
if (rc != 1) if (rc != 1)
return rc; return rc;
status = transport_parse_pdu(transport, s, &incomplete); if (!Stream_EnsureRemainingCapacity(s, 1))
return -1;
Stream_Write(s, &c, 1);
if (c == 0)
break;
}
} }
else
{
/* Read in pdu length */
status = transport_parse_pdu(transport, s, &incomplete);
while ((status == 0) && incomplete)
{
int rc;
if (!Stream_EnsureRemainingCapacity(s, 1))
return -1;
rc = transport_read_layer_bytes(transport, s, 1);
if (rc != 1)
return rc;
status = transport_parse_pdu(transport, s, &incomplete);
}
if (status < 0) if (status < 0)
return -1; return -1;
pduLength = (size_t)status; pduLength = (size_t)status;
/* Read in rest of the PDU */ /* Read in rest of the PDU */
if (!Stream_EnsureCapacity(s, pduLength)) if (!Stream_EnsureCapacity(s, pduLength))
return -1; return -1;
position = Stream_GetPosition(s); position = Stream_GetPosition(s);
if (position > pduLength) if (position > pduLength)
return -1; return -1;
status = transport_read_layer_bytes(transport, s, pduLength - Stream_GetPosition(s)); status = transport_read_layer_bytes(transport, s, pduLength - Stream_GetPosition(s));
if (status != 1) if (status != 1)
return status; return status;
if (Stream_GetPosition(s) >= pduLength) if (Stream_GetPosition(s) >= pduLength)
WLog_Packet(transport->log, WLOG_TRACE, Stream_Buffer(s), pduLength, WLOG_PACKET_INBOUND); WLog_Packet(transport->log, WLOG_TRACE, Stream_Buffer(s), pduLength,
WLOG_PACKET_INBOUND);
}
Stream_SealLength(s); Stream_SealLength(s);
Stream_SetPosition(s, 0); Stream_SetPosition(s, 0);
@ -1323,6 +1390,12 @@ void transport_set_rdstls_mode(rdpTransport* transport, BOOL RdstlsMode)
transport->RdstlsMode = RdstlsMode; transport->RdstlsMode = RdstlsMode;
} }
void transport_set_aad_mode(rdpTransport* transport, BOOL AadMode)
{
WINPR_ASSERT(transport);
transport->AadMode = AadMode;
}
BOOL transport_disconnect(rdpTransport* transport) BOOL transport_disconnect(rdpTransport* transport)
{ {
if (!transport) if (!transport)

View File

@ -63,6 +63,7 @@ FREERDP_LOCAL BOOL transport_connect_rdp(rdpTransport* transport);
FREERDP_LOCAL BOOL transport_connect_tls(rdpTransport* transport); FREERDP_LOCAL BOOL transport_connect_tls(rdpTransport* transport);
FREERDP_LOCAL BOOL transport_connect_nla(rdpTransport* transport); FREERDP_LOCAL BOOL transport_connect_nla(rdpTransport* transport);
FREERDP_LOCAL BOOL transport_connect_rdstls(rdpTransport* transport); FREERDP_LOCAL BOOL transport_connect_rdstls(rdpTransport* transport);
FREERDP_LOCAL BOOL transport_connect_aad(rdpTransport* transport);
FREERDP_LOCAL BOOL transport_accept_rdp(rdpTransport* transport); FREERDP_LOCAL BOOL transport_accept_rdp(rdpTransport* transport);
FREERDP_LOCAL BOOL transport_accept_tls(rdpTransport* transport); FREERDP_LOCAL BOOL transport_accept_tls(rdpTransport* transport);
FREERDP_LOCAL BOOL transport_accept_nla(rdpTransport* transport); FREERDP_LOCAL BOOL transport_accept_nla(rdpTransport* transport);
@ -84,6 +85,7 @@ FREERDP_LOCAL BOOL transport_set_blocking_mode(rdpTransport* transport, BOOL blo
FREERDP_LOCAL void transport_set_gateway_enabled(rdpTransport* transport, BOOL GatewayEnabled); FREERDP_LOCAL void transport_set_gateway_enabled(rdpTransport* transport, BOOL GatewayEnabled);
FREERDP_LOCAL void transport_set_nla_mode(rdpTransport* transport, BOOL NlaMode); FREERDP_LOCAL void transport_set_nla_mode(rdpTransport* transport, BOOL NlaMode);
FREERDP_LOCAL void transport_set_rdstls_mode(rdpTransport* transport, BOOL RdstlsMode); FREERDP_LOCAL void transport_set_rdstls_mode(rdpTransport* transport, BOOL RdstlsMode);
FREERDP_LOCAL void transport_set_aad_mode(rdpTransport* transport, BOOL AadMode);
FREERDP_LOCAL BOOL transport_is_write_blocked(rdpTransport* transport); FREERDP_LOCAL BOOL transport_is_write_blocked(rdpTransport* transport);
FREERDP_LOCAL int transport_drain_output_buffer(rdpTransport* transport); FREERDP_LOCAL int transport_drain_output_buffer(rdpTransport* transport);