/** * FreeRDP: A Remote Desktop Protocol Implementation * Network Level Authentication (NLA) * * Copyright 2023 Isaac Klein * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #include #include #include #include #include #include "../crypto/privatekey.h" #include #include #include #include "transport.h" #include "aad.h" struct rdp_aad { AAD_STATE state; rdpContext* rdpcontext; rdpTransport* transport; char* access_token; rdpPrivateKey* key; char* kid; char* nonce; char* hostname; char* scope; wLog* log; }; #ifdef WITH_AAD static BOOL get_encoded_rsa_params(wLog* wlog, rdpPrivateKey* key, char** e, char** n); static BOOL generate_pop_key(rdpAad* aad); WINPR_ATTR_FORMAT_ARG(2, 3) static SSIZE_T stream_sprintf(wStream* s, WINPR_FORMAT_ARG const char* fmt, ...) { va_list ap; va_start(ap, fmt); const int rc = vsnprintf(NULL, 0, fmt, ap); va_end(ap); if (rc < 0) return rc; if (!Stream_EnsureRemainingCapacity(s, (size_t)rc + 1)) return -1; char* ptr = Stream_PointerAs(s, char); va_start(ap, fmt); const int rc2 = vsnprintf(ptr, rc + 1, fmt, ap); va_end(ap); if (rc != rc2) return -23; if (!Stream_SafeSeek(s, (size_t)rc2)) return -3; return rc2; } static BOOL json_get_object(wLog* wlog, cJSON* json, const char* key, cJSON** obj) { WINPR_ASSERT(json); WINPR_ASSERT(key); if (!cJSON_HasObjectItem(json, key)) { WLog_Print(wlog, WLOG_ERROR, "[json] does not contain a key '%s'", key); return FALSE; } cJSON* prop = cJSON_GetObjectItem(json, key); if (!prop) { WLog_Print(wlog, WLOG_ERROR, "[json] object for key '%s' is NULL", key); return FALSE; } *obj = prop; return TRUE; } #if defined(USE_CJSON_COMPAT) static double cJSON_GetNumberValue(const cJSON* const prop) { #ifndef NAN #ifdef _WIN32 #define NAN sqrt(-1.0) #define COMPAT_NAN_UNDEF #else #define NAN 0.0 / 0.0 #define COMPAT_NAN_UNDEF #endif #endif if (!cJSON_IsNumber(prop)) return NAN; char* val = cJSON_GetStringValue(prop); if (!val) return NAN; errno = 0; char* endptr = NULL; double dval = strtod(val, &endptr); if (val == endptr) return NAN; if (endptr != NULL) return NAN; if (errno != 0) return NAN; return dval; #ifdef COMPAT_NAN_UNDEF #undef NAN #endif } #endif static BOOL json_get_number(wLog* wlog, cJSON* json, const char* key, double* result) { BOOL rc = FALSE; cJSON* prop = NULL; if (!json_get_object(wlog, json, key, &prop)) return FALSE; if (!cJSON_IsNumber(prop)) { WLog_Print(wlog, WLOG_ERROR, "[json] object for key '%s' is NOT a NUMBER", key); goto fail; } *result = cJSON_GetNumberValue(prop); rc = TRUE; fail: return rc; } static BOOL json_get_const_string(wLog* wlog, cJSON* json, const char* key, const char** result) { BOOL rc = FALSE; WINPR_ASSERT(result); *result = NULL; cJSON* prop = NULL; if (!json_get_object(wlog, json, key, &prop)) return FALSE; if (!cJSON_IsString(prop)) { WLog_Print(wlog, WLOG_ERROR, "[json] object for key '%s' is NOT a STRING", key); goto fail; } const char* str = cJSON_GetStringValue(prop); if (!str) WLog_Print(wlog, WLOG_ERROR, "[json] object for key '%s' is NULL", key); *result = str; rc = str != NULL; fail: return rc; } static BOOL json_get_string_alloc(wLog* wlog, cJSON* json, const char* key, char** result) { const char* str = NULL; if (!json_get_const_string(wlog, json, key, &str)) return FALSE; free(*result); *result = _strdup(str); if (!*result) WLog_Print(wlog, WLOG_ERROR, "[json] object for key '%s' strdup is NULL", key); return *result != NULL; } #if defined(USE_CJSON_COMPAT) cJSON* cJSON_ParseWithLength(const char* value, size_t buffer_length) { // Check for string '\0' termination. const size_t slen = strnlen(value, buffer_length); if (slen >= buffer_length) { if (value[buffer_length] != '\0') return NULL; } return cJSON_Parse(value); } #endif static INLINE const char* aad_auth_result_to_string(DWORD code) { #define ERROR_CASE(cd, x) \ if (cd == (DWORD)(x)) \ return #x; ERROR_CASE(code, S_OK) ERROR_CASE(code, SEC_E_INVALID_TOKEN) ERROR_CASE(code, E_ACCESSDENIED) ERROR_CASE(code, STATUS_LOGON_FAILURE) ERROR_CASE(code, STATUS_NO_LOGON_SERVERS) ERROR_CASE(code, STATUS_INVALID_LOGON_HOURS) ERROR_CASE(code, STATUS_INVALID_WORKSTATION) ERROR_CASE(code, STATUS_PASSWORD_EXPIRED) ERROR_CASE(code, STATUS_ACCOUNT_DISABLED) return "Unknown error"; } static BOOL aad_get_nonce(rdpAad* aad) { BOOL ret = FALSE; BYTE* response = NULL; long resp_code = 0; size_t response_length = 0; cJSON* json = NULL; if (!freerdp_http_request("https://login.microsoftonline.com/common/oauth2/v2.0/token", "grant_type=srv_challenge", &resp_code, &response, &response_length)) { WLog_Print(aad->log, WLOG_ERROR, "nonce request failed"); goto fail; } if (resp_code != HTTP_STATUS_OK) { WLog_Print(aad->log, WLOG_ERROR, "Server unwilling to provide nonce; returned status code %li", resp_code); if (response_length > 0) WLog_Print(aad->log, WLOG_ERROR, "[status message] %s", response); goto fail; } json = cJSON_ParseWithLength((const char*)response, response_length); if (!json) { WLog_Print(aad->log, WLOG_ERROR, "Failed to parse nonce response"); goto fail; } if (!json_get_string_alloc(aad->log, json, "Nonce", &aad->nonce)) goto fail; ret = TRUE; fail: free(response); cJSON_Delete(json); return ret; } int aad_client_begin(rdpAad* aad) { size_t size = 0; WINPR_ASSERT(aad); WINPR_ASSERT(aad->rdpcontext); rdpSettings* settings = aad->rdpcontext->settings; WINPR_ASSERT(settings); freerdp* instance = aad->rdpcontext->instance; WINPR_ASSERT(instance); /* Get the host part of the hostname */ const char* hostname = freerdp_settings_get_string(settings, FreeRDP_ServerHostname); if (!hostname) { WLog_Print(aad->log, WLOG_ERROR, "FreeRDP_ServerHostname == NULL"); return -1; } aad->hostname = _strdup(hostname); if (!aad->hostname) { WLog_Print(aad->log, WLOG_ERROR, "_strdup(FreeRDP_ServerHostname) == NULL"); return -1; } char* p = strchr(aad->hostname, '.'); if (p) *p = '\0'; if (winpr_asprintf(&aad->scope, &size, "ms-device-service%%3A%%2F%%2Ftermsrv.wvd.microsoft.com%%2Fname%%2F%s%%" "2Fuser_impersonation", aad->hostname) <= 0) return -1; if (!generate_pop_key(aad)) return -1; /* Obtain an oauth authorization code */ if (!instance->GetAccessToken) { WLog_Print(aad->log, WLOG_ERROR, "instance->GetAccessToken == NULL"); return -1; } const BOOL arc = instance->GetAccessToken(instance, ACCESS_TOKEN_TYPE_AAD, &aad->access_token, 2, aad->scope, aad->kid); if (!arc) { WLog_Print(aad->log, WLOG_ERROR, "Unable to obtain access token"); return -1; } /* Send the nonce request message */ if (!aad_get_nonce(aad)) { WLog_Print(aad->log, WLOG_ERROR, "Unable to obtain nonce"); return -1; } return 1; } static char* aad_create_jws_header(rdpAad* aad) { WINPR_ASSERT(aad); /* Construct the base64url encoded JWS header */ char* buffer = NULL; size_t bufferlen = 0; const int length = winpr_asprintf(&buffer, &bufferlen, "{\"alg\":\"RS256\",\"kid\":\"%s\"}", aad->kid); if (length < 0) return NULL; char* jws_header = crypto_base64url_encode((const BYTE*)buffer, bufferlen); free(buffer); return jws_header; } static char* aad_create_jws_payload(rdpAad* aad, const char* ts_nonce) { const time_t ts = time(NULL); WINPR_ASSERT(aad); char* e = NULL; char* n = NULL; if (!get_encoded_rsa_params(aad->log, aad->key, &e, &n)) return NULL; /* Construct the base64url encoded JWS payload */ char* buffer = NULL; size_t bufferlen = 0; const int length = winpr_asprintf(&buffer, &bufferlen, "{" "\"ts\":\"%li\"," "\"at\":\"%s\"," "\"u\":\"ms-device-service://termsrv.wvd.microsoft.com/name/%s\"," "\"nonce\":\"%s\"," "\"cnf\":{\"jwk\":{\"kty\":\"RSA\",\"e\":\"%s\",\"n\":\"%s\"}}," "\"client_claims\":\"{\\\"aad_nonce\\\":\\\"%s\\\"}\"" "}", ts, aad->access_token, aad->hostname, ts_nonce, e, n, aad->nonce); free(e); free(n); if (length < 0) return NULL; char* jws_payload = crypto_base64url_encode((BYTE*)buffer, bufferlen); free(buffer); return jws_payload; } static BOOL aad_update_digest(rdpAad* aad, WINPR_DIGEST_CTX* ctx, const char* what) { WINPR_ASSERT(aad); WINPR_ASSERT(ctx); WINPR_ASSERT(what); const BOOL dsu1 = winpr_DigestSign_Update(ctx, what, strlen(what)); if (!dsu1) { WLog_Print(aad->log, WLOG_ERROR, "winpr_DigestSign_Update [%s] failed", what); return FALSE; } return TRUE; } static char* aad_final_digest(rdpAad* aad, WINPR_DIGEST_CTX* ctx) { char* jws_signature = NULL; WINPR_ASSERT(aad); WINPR_ASSERT(ctx); size_t siglen = 0; const int dsf = winpr_DigestSign_Final(ctx, NULL, &siglen); if (dsf <= 0) { WLog_Print(aad->log, WLOG_ERROR, "winpr_DigestSign_Final failed with %d", dsf); return FALSE; } char* buffer = calloc(siglen + 1, sizeof(char)); if (!buffer) { WLog_Print(aad->log, WLOG_ERROR, "calloc %" PRIuz " bytes failed", siglen + 1); goto fail; } size_t fsiglen = siglen; const int dsf2 = winpr_DigestSign_Final(ctx, (BYTE*)buffer, &fsiglen); if (dsf2 <= 0) { WLog_Print(aad->log, WLOG_ERROR, "winpr_DigestSign_Final failed with %d", dsf2); goto fail; } if (siglen != fsiglen) { WLog_Print(aad->log, WLOG_ERROR, "winpr_DigestSignFinal returned different sizes, first %" PRIuz " then %" PRIuz, siglen, fsiglen); goto fail; } jws_signature = crypto_base64url_encode((const BYTE*)buffer, fsiglen); fail: free(buffer); return jws_signature; } static char* aad_create_jws_signature(rdpAad* aad, const char* jws_header, const char* jws_payload) { char* jws_signature = NULL; WINPR_ASSERT(aad); WINPR_DIGEST_CTX* md_ctx = freerdp_key_digest_sign(aad->key, WINPR_MD_SHA256); if (!md_ctx) { WLog_Print(aad->log, WLOG_ERROR, "winpr_Digest_New failed"); goto fail; } if (!aad_update_digest(aad, md_ctx, jws_header)) goto fail; if (!aad_update_digest(aad, md_ctx, ".")) goto fail; if (!aad_update_digest(aad, md_ctx, jws_payload)) goto fail; jws_signature = aad_final_digest(aad, md_ctx); fail: winpr_Digest_Free(md_ctx); return jws_signature; } static int aad_send_auth_request(rdpAad* aad, const char* ts_nonce) { int ret = -1; char* jws_header = NULL; char* jws_payload = NULL; char* jws_signature = NULL; WINPR_ASSERT(aad); WINPR_ASSERT(ts_nonce); wStream* s = Stream_New(NULL, 1024); if (!s) goto fail; /* Construct the base64url encoded JWS header */ jws_header = aad_create_jws_header(aad); if (!jws_header) goto fail; /* Construct the base64url encoded JWS payload */ jws_payload = aad_create_jws_payload(aad, ts_nonce); if (!jws_payload) goto fail; /* Sign the JWS with the pop key */ jws_signature = aad_create_jws_signature(aad, jws_header, jws_payload); if (!jws_signature) goto fail; /* Construct the Authentication Request PDU with the JWS as the RDP Assertion */ if (stream_sprintf(s, "{\"rdp_assertion\":\"%s.%s.%s\"}", jws_header, jws_payload, jws_signature) < 0) goto fail; Stream_SealLength(s); if (transport_write(aad->transport, s) < 0) { WLog_Print(aad->log, WLOG_ERROR, "transport_write [%" PRIdz " bytes] failed", Stream_Length(s)); } else { ret = 1; aad->state = AAD_STATE_AUTH; } fail: Stream_Free(s, TRUE); free(jws_header); free(jws_payload); free(jws_signature); return ret; } static int aad_parse_state_initial(rdpAad* aad, wStream* s) { const char* jstr = Stream_PointerAs(s, char); const size_t jlen = Stream_GetRemainingLength(s); const char* ts_nonce = NULL; int ret = -1; cJSON* json = NULL; if (!Stream_SafeSeek(s, jlen)) goto fail; json = cJSON_ParseWithLength(jstr, jlen); if (!json) goto fail; if (!json_get_const_string(aad->log, json, "ts_nonce", &ts_nonce)) goto fail; ret = aad_send_auth_request(aad, ts_nonce); fail: cJSON_Delete(json); return ret; } static int aad_parse_state_auth(rdpAad* aad, wStream* s) { int rc = -1; double result = 0; DWORD error_code = 0; cJSON* json = NULL; const char* jstr = Stream_PointerAs(s, char); const size_t jlength = Stream_GetRemainingLength(s); if (!Stream_SafeSeek(s, jlength)) goto fail; json = cJSON_ParseWithLength(jstr, jlength); if (!json) goto fail; if (!json_get_number(aad->log, json, "authentication_result", &result)) goto fail; error_code = (DWORD)result; if (error_code != S_OK) { WLog_Print(aad->log, WLOG_ERROR, "Authentication result: %s (0x%08" PRIx32 ")", aad_auth_result_to_string(error_code), error_code); goto fail; } aad->state = AAD_STATE_FINAL; rc = 1; fail: cJSON_Delete(json); return rc; } int aad_recv(rdpAad* aad, wStream* s) { WINPR_ASSERT(aad); WINPR_ASSERT(s); switch (aad->state) { case AAD_STATE_INITIAL: return aad_parse_state_initial(aad, s); case AAD_STATE_AUTH: return aad_parse_state_auth(aad, s); case AAD_STATE_FINAL: default: WLog_Print(aad->log, WLOG_ERROR, "Invalid AAD_STATE %d", aad->state); return -1; } } static BOOL generate_rsa_2048(rdpAad* aad) { WINPR_ASSERT(aad); return freerdp_key_generate(aad->key, 2048); } static char* generate_rsa_digest_base64_str(rdpAad* aad, const char* input, size_t ilen) { char* b64 = NULL; WINPR_DIGEST_CTX* digest = winpr_Digest_New(); if (!digest) { WLog_Print(aad->log, WLOG_ERROR, "winpr_Digest_New failed"); goto fail; } if (!winpr_Digest_Init(digest, WINPR_MD_SHA256)) { WLog_Print(aad->log, WLOG_ERROR, "winpr_Digest_Init(WINPR_MD_SHA256) failed"); goto fail; } if (!winpr_Digest_Update(digest, (const BYTE*)input, ilen)) { WLog_Print(aad->log, WLOG_ERROR, "winpr_Digest_Update(%" PRIuz ") failed", ilen); goto fail; } BYTE hash[WINPR_SHA256_DIGEST_LENGTH] = { 0 }; if (!winpr_Digest_Final(digest, hash, sizeof(hash))) { WLog_Print(aad->log, WLOG_ERROR, "winpr_Digest_Final(%" PRIuz ") failed", sizeof(hash)); goto fail; } /* Base64url encode the hash */ b64 = crypto_base64url_encode(hash, sizeof(hash)); fail: winpr_Digest_Free(digest); return b64; } static BOOL generate_json_base64_str(rdpAad* aad, const char* b64_hash) { WINPR_ASSERT(aad); char* buffer = NULL; size_t blen = 0; const int length = winpr_asprintf(&buffer, &blen, "{\"kid\":\"%s\"}", b64_hash); if (length < 0) return FALSE; /* Finally, base64url encode the JSON text to form the kid */ free(aad->kid); aad->kid = crypto_base64url_encode((const BYTE*)buffer, (size_t)length); free(buffer); if (!aad->kid) { return FALSE; } return TRUE; } BOOL generate_pop_key(rdpAad* aad) { BOOL ret = FALSE; char* buffer = NULL; char* b64_hash = NULL; char *e = NULL, *n = NULL; WINPR_ASSERT(aad); /* Generate a 2048-bit RSA key pair */ if (!generate_rsa_2048(aad)) goto fail; /* Encode the public key as a JWK */ if (!get_encoded_rsa_params(aad->log, aad->key, &e, &n)) goto fail; size_t blen = 0; const int alen = winpr_asprintf(&buffer, &blen, "{\"e\":\"%s\",\"kty\":\"RSA\",\"n\":\"%s\"}", e, n); if (alen < 0) goto fail; /* Hash the encoded public key */ b64_hash = generate_rsa_digest_base64_str(aad, buffer, blen); if (!b64_hash) goto fail; /* Encode a JSON object with a single property "kid" whose value is the encoded hash */ ret = generate_json_base64_str(aad, b64_hash); fail: free(b64_hash); free(buffer); free(e); free(n); return ret; } static char* bn_to_base64_url(wLog* wlog, rdpPrivateKey* key, enum FREERDP_KEY_PARAM param) { WINPR_ASSERT(wlog); WINPR_ASSERT(key); size_t len = 0; char* bn = freerdp_key_get_param(key, param, &len); if (!bn) return NULL; char* b64 = (char*)crypto_base64url_encode(bn, len); free(bn); if (!b64) WLog_Print(wlog, WLOG_ERROR, "failed base64 url encode BIGNUM"); return b64; } BOOL get_encoded_rsa_params(wLog* wlog, rdpPrivateKey* key, char** pe, char** pn) { BOOL rc = FALSE; char* e = NULL; char* n = NULL; WINPR_ASSERT(wlog); WINPR_ASSERT(key); WINPR_ASSERT(pe); WINPR_ASSERT(pn); *pe = NULL; *pn = NULL; e = bn_to_base64_url(wlog, key, FREERDP_KEY_PARAM_RSA_E); if (!e) { WLog_Print(wlog, WLOG_ERROR, "failed base64 url encode RSA E"); goto fail; } n = bn_to_base64_url(wlog, key, FREERDP_KEY_PARAM_RSA_N); if (!n) { WLog_Print(wlog, WLOG_ERROR, "failed base64 url encode RSA N"); goto fail; } rc = TRUE; fail: if (!rc) { free(e); free(n); } else { *pe = e; *pn = n; } return rc; } #else int aad_client_begin(rdpAad* aad) { WINPR_ASSERT(aad); WLog_Print(aad->log, WLOG_ERROR, "AAD security not compiled in, aborting!"); return -1; } int aad_recv(rdpAad* aad, wStream* s) { WINPR_ASSERT(aad); WLog_Print(aad->log, WLOG_ERROR, "AAD security not compiled in, aborting!"); return -1; } #endif rdpAad* aad_new(rdpContext* context, rdpTransport* transport) { WINPR_ASSERT(transport); WINPR_ASSERT(context); rdpAad* aad = (rdpAad*)calloc(1, sizeof(rdpAad)); if (!aad) return NULL; aad->log = WLog_Get(FREERDP_TAG("aad")); aad->key = freerdp_key_new(); if (!aad->key) goto fail; aad->rdpcontext = context; aad->transport = transport; return aad; fail: aad_free(aad); return NULL; } void aad_free(rdpAad* aad) { if (!aad) return; free(aad->hostname); free(aad->scope); free(aad->nonce); free(aad->access_token); free(aad->kid); freerdp_key_free(aad->key); free(aad); } AAD_STATE aad_get_state(rdpAad* aad) { WINPR_ASSERT(aad); return aad->state; } BOOL aad_is_supported(void) { #ifdef WITH_AAD return TRUE; #else return FALSE; #endif } #ifdef WITH_AAD char* freerdp_utils_aad_get_access_token(wLog* log, const char* data, size_t length) { char* token = NULL; cJSON* access_token_prop = NULL; const char* access_token_str = NULL; cJSON* json = cJSON_ParseWithLength(data, length); if (!json) { WLog_Print(log, WLOG_ERROR, "Failed to parse access token response [got %" PRIuz " bytes", length); goto cleanup; } access_token_prop = cJSON_GetObjectItem(json, "access_token"); if (!access_token_prop) { WLog_Print(log, WLOG_ERROR, "Response has no \"access_token\" property"); goto cleanup; } access_token_str = cJSON_GetStringValue(access_token_prop); if (!access_token_str) { WLog_Print(log, WLOG_ERROR, "Invalid value for \"access_token\""); goto cleanup; } token = _strdup(access_token_str); cleanup: cJSON_Delete(json); return token; } #endif