diff --git a/libfreerdp/core/activation.c b/libfreerdp/core/activation.c index 89d8008db..13896df60 100644 --- a/libfreerdp/core/activation.c +++ b/libfreerdp/core/activation.c @@ -468,7 +468,8 @@ BOOL rdp_recv_deactivate_all(rdpRdp* rdp, wStream* s) rdp_client_transition_to_state(rdp, CONNECTION_STATE_CAPABILITIES_EXCHANGE); - for (timeout = 0; timeout < rdp->settings->TcpAckTimeout; timeout += 100) + for (timeout = 0; timeout < freerdp_settings_get_uint32(rdp->settings, FreeRDP_TcpAckTimeout); + timeout += 100) { if (rdp_check_fds(rdp) < 0) return FALSE; @@ -575,7 +576,7 @@ BOOL rdp_server_accept_client_font_list_pdu(rdpRdp* rdp, wStream* s) if (!rdp_send_server_font_map_pdu(rdp)) return FALSE; - if (rdp_server_transition_to_state(rdp, CONNECTION_STATE_ACTIVE) < 0) + if (!rdp_server_transition_to_state(rdp, CONNECTION_STATE_ACTIVE)) return FALSE; return TRUE; diff --git a/libfreerdp/core/client.c b/libfreerdp/core/client.c index 11575b094..3faa302db 100644 --- a/libfreerdp/core/client.c +++ b/libfreerdp/core/client.c @@ -1247,7 +1247,7 @@ int freerdp_channels_client_load(rdpChannels* channels, rdpSettings* settings, PVIRTUALCHANNELENTRY entry, void* data) { int status; - CHANNEL_ENTRY_POINTS_FREERDP EntryPoints; + CHANNEL_ENTRY_POINTS_FREERDP EntryPoints = { 0 }; CHANNEL_CLIENT_DATA* pChannelClientData; if (channels->clientDataCount + 1 > CHANNEL_MAX_COUNT) @@ -1264,7 +1264,7 @@ int freerdp_channels_client_load(rdpChannels* channels, rdpSettings* settings, pChannelClientData = &channels->clientDataList[channels->clientDataCount]; pChannelClientData->entry = entry; - ZeroMemory(&EntryPoints, sizeof(CHANNEL_ENTRY_POINTS_FREERDP)); + EntryPoints.cbSize = sizeof(EntryPoints); EntryPoints.protocolVersion = VIRTUAL_CHANNEL_VERSION_WIN2000; EntryPoints.pVirtualChannelInit = FreeRDP_VirtualChannelInit; diff --git a/libfreerdp/core/connection.c b/libfreerdp/core/connection.c index 088578595..8fe0493a7 100644 --- a/libfreerdp/core/connection.c +++ b/libfreerdp/core/connection.c @@ -361,15 +361,16 @@ BOOL rdp_client_connect(rdpRdp* rdp) if ((SelectedProtocol & PROTOCOL_SSL) || (SelectedProtocol == PROTOCOL_RDP)) { if ((settings->Username != NULL) && - ((settings->Password != NULL) || (settings->RedirectionPassword != NULL && - settings->RedirectionPasswordLength > 0))) + ((freerdp_settings_get_string(settings, FreeRDP_Password) != NULL) || + (settings->RedirectionPassword != NULL && + settings->RedirectionPasswordLength > 0))) settings->AutoLogonEnabled = TRUE; } + transport_set_blocking_mode(rdp->transport, FALSE); } /* everything beyond this point is event-driven and non blocking */ transport_set_recv_callbacks(rdp->transport, rdp_recv_callback, rdp); - transport_set_blocking_mode(rdp->transport, FALSE); if (rdp_get_state(rdp) != CONNECTION_STATE_NLA) { @@ -377,7 +378,8 @@ BOOL rdp_client_connect(rdpRdp* rdp) return FALSE; } - for (timeout = 0; timeout < settings->TcpAckTimeout; timeout += 100) + for (timeout = 0; timeout < freerdp_settings_get_uint32(settings, FreeRDP_TcpAckTimeout); + timeout += 100) { if (rdp_check_fds(rdp) < 0) { @@ -772,10 +774,7 @@ BOOL rdp_server_establish_keys(rdpRdp* rdp, wStream* s) } if (!rdp_read_header(rdp, s, &length, &channel_id)) - { - WLog_ERR(TAG, "invalid RDP header"); return FALSE; - } if (!rdp_read_security_header(s, &sec_flags, NULL)) { @@ -1431,9 +1430,9 @@ BOOL rdp_server_reactivate(rdpRdp* rdp) return TRUE; } -int rdp_server_transition_to_state(rdpRdp* rdp, CONNECTION_STATE state) +BOOL rdp_server_transition_to_state(rdpRdp* rdp, CONNECTION_STATE state) { - int status = 0; + BOOL status = FALSE; freerdp_peer* client = NULL; const CONNECTION_STATE cstate = rdp_get_state(rdp); @@ -1447,7 +1446,8 @@ int rdp_server_transition_to_state(rdpRdp* rdp, CONNECTION_STATE state) } WLog_DBG(TAG, "%s %s --> %s", __FUNCTION__, rdp_get_state_string(rdp), rdp_state_string(state)); - rdp_set_state(rdp, state); + if (!rdp_set_state(rdp, state)) + goto fail; switch (state) { case CONNECTION_STATE_CAPABILITIES_EXCHANGE: @@ -1472,7 +1472,7 @@ int rdp_server_transition_to_state(rdpRdp* rdp, CONNECTION_STATE state) IFCALLRET(client->PostConnect, client->connected, client); if (!client->connected) - return -1; + goto fail; } if (rdp_get_state(rdp) >= CONNECTION_STATE_ACTIVE) @@ -1480,7 +1480,7 @@ int rdp_server_transition_to_state(rdpRdp* rdp, CONNECTION_STATE state) IFCALLRET(client->Activate, client->activated, client); if (!client->activated) - return -1; + goto fail; } } @@ -1490,6 +1490,8 @@ int rdp_server_transition_to_state(rdpRdp* rdp, CONNECTION_STATE state) break; } + status = TRUE; +fail: return status; } diff --git a/libfreerdp/core/connection.h b/libfreerdp/core/connection.h index 260292e37..9bf1ba560 100644 --- a/libfreerdp/core/connection.h +++ b/libfreerdp/core/connection.h @@ -59,7 +59,7 @@ FREERDP_LOCAL BOOL rdp_server_accept_mcs_channel_join_request(rdpRdp* rdp, wStre FREERDP_LOCAL BOOL rdp_server_accept_confirm_active(rdpRdp* rdp, wStream* s, UINT16 pduLength); FREERDP_LOCAL BOOL rdp_server_establish_keys(rdpRdp* rdp, wStream* s); FREERDP_LOCAL BOOL rdp_server_reactivate(rdpRdp* rdp); -FREERDP_LOCAL int rdp_server_transition_to_state(rdpRdp* rdp, CONNECTION_STATE state); +FREERDP_LOCAL BOOL rdp_server_transition_to_state(rdpRdp* rdp, CONNECTION_STATE state); FREERDP_LOCAL const char* rdp_get_state_string(rdpRdp* rdp); FREERDP_LOCAL const char* rdp_client_connection_state_string(int state); diff --git a/libfreerdp/core/license.c b/libfreerdp/core/license.c index 52244ace6..865cb96ac 100644 --- a/libfreerdp/core/license.c +++ b/libfreerdp/core/license.c @@ -467,10 +467,7 @@ int license_recv(rdpLicense* license, wStream* s) UINT16 securityFlags = 0; if (!rdp_read_header(license->rdp, s, &length, &channelId)) - { - WLog_ERR(TAG, "Incorrect RDP header."); return -1; - } if (!rdp_read_security_header(s, &securityFlags, &length)) return -1; diff --git a/libfreerdp/core/mcs.c b/libfreerdp/core/mcs.c index e303142f9..9cb9a5bdd 100644 --- a/libfreerdp/core/mcs.c +++ b/libfreerdp/core/mcs.c @@ -190,6 +190,102 @@ static const char* const mcs_result_enumerated[] = }; */ +const char* mcs_domain_pdu_string(DomainMCSPDU pdu) +{ + switch (pdu) + { + case DomainMCSPDU_PlumbDomainIndication: + return "DomainMCSPDU_PlumbDomainIndication"; + case DomainMCSPDU_ErectDomainRequest: + return "DomainMCSPDU_ErectDomainRequest"; + case DomainMCSPDU_MergeChannelsRequest: + return "DomainMCSPDU_MergeChannelsRequest"; + case DomainMCSPDU_MergeChannelsConfirm: + return "DomainMCSPDU_MergeChannelsConfirm"; + case DomainMCSPDU_PurgeChannelsIndication: + return "DomainMCSPDU_PurgeChannelsIndication"; + case DomainMCSPDU_MergeTokensRequest: + return "DomainMCSPDU_MergeTokensRequest"; + case DomainMCSPDU_MergeTokensConfirm: + return "DomainMCSPDU_MergeTokensConfirm"; + case DomainMCSPDU_PurgeTokensIndication: + return "DomainMCSPDU_PurgeTokensIndication"; + case DomainMCSPDU_DisconnectProviderUltimatum: + return "DomainMCSPDU_DisconnectProviderUltimatum"; + case DomainMCSPDU_RejectMCSPDUUltimatum: + return "DomainMCSPDU_RejectMCSPDUUltimatum"; + case DomainMCSPDU_AttachUserRequest: + return "DomainMCSPDU_AttachUserRequest"; + case DomainMCSPDU_AttachUserConfirm: + return "DomainMCSPDU_AttachUserConfirm"; + case DomainMCSPDU_DetachUserRequest: + return "DomainMCSPDU_DetachUserRequest"; + case DomainMCSPDU_DetachUserIndication: + return "DomainMCSPDU_DetachUserIndication"; + case DomainMCSPDU_ChannelJoinRequest: + return "DomainMCSPDU_ChannelJoinRequest"; + case DomainMCSPDU_ChannelJoinConfirm: + return "DomainMCSPDU_ChannelJoinConfirm"; + case DomainMCSPDU_ChannelLeaveRequest: + return "DomainMCSPDU_ChannelLeaveRequest"; + case DomainMCSPDU_ChannelConveneRequest: + return "DomainMCSPDU_ChannelConveneRequest"; + case DomainMCSPDU_ChannelConveneConfirm: + return "DomainMCSPDU_ChannelConveneConfirm"; + case DomainMCSPDU_ChannelDisbandRequest: + return "DomainMCSPDU_ChannelDisbandRequest"; + case DomainMCSPDU_ChannelDisbandIndication: + return "DomainMCSPDU_ChannelDisbandIndication"; + case DomainMCSPDU_ChannelAdmitRequest: + return "DomainMCSPDU_ChannelAdmitRequest"; + case DomainMCSPDU_ChannelAdmitIndication: + return "DomainMCSPDU_ChannelAdmitIndication"; + case DomainMCSPDU_ChannelExpelRequest: + return "DomainMCSPDU_ChannelExpelRequest"; + case DomainMCSPDU_ChannelExpelIndication: + return "DomainMCSPDU_ChannelExpelIndication"; + case DomainMCSPDU_SendDataRequest: + return "DomainMCSPDU_SendDataRequest"; + case DomainMCSPDU_SendDataIndication: + return "DomainMCSPDU_SendDataIndication"; + case DomainMCSPDU_UniformSendDataRequest: + return "DomainMCSPDU_UniformSendDataRequest"; + case DomainMCSPDU_UniformSendDataIndication: + return "DomainMCSPDU_UniformSendDataIndication"; + case DomainMCSPDU_TokenGrabRequest: + return "DomainMCSPDU_TokenGrabRequest"; + case DomainMCSPDU_TokenGrabConfirm: + return "DomainMCSPDU_TokenGrabConfirm"; + case DomainMCSPDU_TokenInhibitRequest: + return "DomainMCSPDU_TokenInhibitRequest"; + case DomainMCSPDU_TokenInhibitConfirm: + return "DomainMCSPDU_TokenInhibitConfirm"; + case DomainMCSPDU_TokenGiveRequest: + return "DomainMCSPDU_TokenGiveRequest"; + case DomainMCSPDU_TokenGiveIndication: + return "DomainMCSPDU_TokenGiveIndication"; + case DomainMCSPDU_TokenGiveResponse: + return "DomainMCSPDU_TokenGiveResponse"; + case DomainMCSPDU_TokenGiveConfirm: + return "DomainMCSPDU_TokenGiveConfirm"; + case DomainMCSPDU_TokenPleaseRequest: + return "DomainMCSPDU_TokenPleaseRequest"; + case DomainMCSPDU_TokenPleaseConfirm: + return "DomainMCSPDU_TokenPleaseConfirm"; + case DomainMCSPDU_TokenReleaseRequest: + return "DomainMCSPDU_TokenReleaseRequest"; + case DomainMCSPDU_TokenReleaseConfirm: + return "DomainMCSPDU_TokenReleaseConfirm"; + case DomainMCSPDU_TokenTestRequest: + return "DomainMCSPDU_TokenTestRequest"; + case DomainMCSPDU_TokenTestConfirm: + return "DomainMCSPDU_TokenTestConfirm"; + case DomainMCSPDU_enum_length: + return "DomainMCSPDU_enum_length"; + default: + return "DomainMCSPDU_UNKNOWN"; + }; +} static BOOL mcs_merge_domain_parameters(DomainParameters* targetParameters, DomainParameters* minimumParameters, DomainParameters* maximumParameters, @@ -197,8 +293,8 @@ static BOOL mcs_merge_domain_parameters(DomainParameters* targetParameters, static BOOL mcs_write_connect_initial(wStream* s, rdpMcs* mcs, wStream* userData); static BOOL mcs_write_connect_response(wStream* s, rdpMcs* mcs, wStream* userData); -static BOOL mcs_read_domain_mcspdu_header(wStream* s, enum DomainMCSPDU* domainMCSPDU, - UINT16* length); +static BOOL mcs_read_domain_mcspdu_header(wStream* s, DomainMCSPDU domainMCSPDU, UINT16* length, + DomainMCSPDU* actual); static int mcs_initialize_client_channels(rdpMcs* mcs, const rdpSettings* settings) { @@ -235,11 +331,15 @@ static int mcs_initialize_client_channels(rdpMcs* mcs, const rdpSettings* settin * @return */ -BOOL mcs_read_domain_mcspdu_header(wStream* s, enum DomainMCSPDU* domainMCSPDU, UINT16* length) +BOOL mcs_read_domain_mcspdu_header(wStream* s, DomainMCSPDU domainMCSPDU, UINT16* length, + DomainMCSPDU* actual) { UINT16 li; BYTE choice; - enum DomainMCSPDU MCSPDU; + DomainMCSPDU MCSPDU; + + if (actual) + *actual = DomainMCSPDU_invalid; if (!s || !domainMCSPDU || !length) return FALSE; @@ -250,15 +350,19 @@ BOOL mcs_read_domain_mcspdu_header(wStream* s, enum DomainMCSPDU* domainMCSPDU, if (!tpdu_read_data(s, &li, *length)) return FALSE; - MCSPDU = *domainMCSPDU; - if (!per_read_choice(s, &choice)) return FALSE; - *domainMCSPDU = (choice >> 2); + MCSPDU = (choice >> 2); + if (actual) + *actual = MCSPDU; - if (*domainMCSPDU != MCSPDU) + if (domainMCSPDU != MCSPDU) + { + WLog_ERR(TAG, "Expected MCS %s, got %s", mcs_domain_pdu_string(domainMCSPDU), + mcs_domain_pdu_string(MCSPDU)); return FALSE; + } return TRUE; } @@ -270,7 +374,7 @@ BOOL mcs_read_domain_mcspdu_header(wStream* s, enum DomainMCSPDU* domainMCSPDU, * @param length TPKT length */ -void mcs_write_domain_mcspdu_header(wStream* s, enum DomainMCSPDU domainMCSPDU, UINT16 length, +void mcs_write_domain_mcspdu_header(wStream* s, DomainMCSPDU domainMCSPDU, UINT16 length, BYTE options) { tpkt_write_header(s, length); @@ -682,7 +786,7 @@ out: * @param mcs mcs module */ -BOOL mcs_send_connect_initial(rdpMcs* mcs) +static BOOL mcs_send_connect_initial(rdpMcs* mcs) { int status = -1; size_t length; @@ -880,14 +984,11 @@ BOOL mcs_recv_erect_domain_request(rdpMcs* mcs, wStream* s) UINT16 length; UINT32 subHeight; UINT32 subInterval; - enum DomainMCSPDU MCSPDU; if (!mcs || !s) return FALSE; - MCSPDU = DomainMCSPDU_ErectDomainRequest; - - if (!mcs_read_domain_mcspdu_header(s, &MCSPDU, &length)) + if (!mcs_read_domain_mcspdu_header(s, DomainMCSPDU_ErectDomainRequest, &length, NULL)) return FALSE; if (!per_read_integer(s, &subHeight)) /* subHeight (INTEGER) */ @@ -941,13 +1042,11 @@ BOOL mcs_send_erect_domain_request(rdpMcs* mcs) BOOL mcs_recv_attach_user_request(rdpMcs* mcs, wStream* s) { UINT16 length; - enum DomainMCSPDU MCSPDU; if (!mcs || !s) return FALSE; - MCSPDU = DomainMCSPDU_AttachUserRequest; - if (!mcs_read_domain_mcspdu_header(s, &MCSPDU, &length)) + if (!mcs_read_domain_mcspdu_header(s, DomainMCSPDU_AttachUserRequest, &length, NULL)) return FALSE; return tpkt_ensure_stream_consumed(s, length); } @@ -992,13 +1091,11 @@ BOOL mcs_recv_attach_user_confirm(rdpMcs* mcs, wStream* s) { BYTE result; UINT16 length; - enum DomainMCSPDU MCSPDU; if (!mcs || !s) return FALSE; - MCSPDU = DomainMCSPDU_AttachUserConfirm; - if (!mcs_read_domain_mcspdu_header(s, &MCSPDU, &length)) + if (!mcs_read_domain_mcspdu_header(s, DomainMCSPDU_AttachUserConfirm, &length, NULL)) return FALSE; if (!per_read_enumerated(s, &result, MCS_Result_enum_length)) /* result */ return FALSE; @@ -1051,13 +1148,11 @@ BOOL mcs_recv_channel_join_request(rdpMcs* mcs, wStream* s, UINT16* channelId) { UINT16 length; UINT16 userId; - enum DomainMCSPDU MCSPDU; if (!mcs || !s || !channelId) return FALSE; - MCSPDU = DomainMCSPDU_ChannelJoinRequest; - if (!mcs_read_domain_mcspdu_header(s, &MCSPDU, &length)) + if (!mcs_read_domain_mcspdu_header(s, DomainMCSPDU_ChannelJoinRequest, &length, NULL)) return FALSE; if (!per_read_integer16(s, &userId, MCS_BASE_CHANNEL_ID) && (userId == mcs->userId)) @@ -1113,13 +1208,11 @@ BOOL mcs_recv_channel_join_confirm(rdpMcs* mcs, wStream* s, UINT16* channelId) BYTE result; UINT16 initiator; UINT16 requested; - enum DomainMCSPDU MCSPDU; if (!mcs || !s || !channelId) return FALSE; - MCSPDU = DomainMCSPDU_ChannelJoinConfirm; - if (!mcs_read_domain_mcspdu_header(s, &MCSPDU, &length)) + if (!mcs_read_domain_mcspdu_header(s, DomainMCSPDU_ChannelJoinConfirm, &length, NULL)) return FALSE; if (!per_read_enumerated(s, &result, MCS_Result_enum_length)) /* result */ @@ -1257,6 +1350,8 @@ BOOL mcs_client_begin(rdpMcs* mcs) if (!context) return FALSE; + /* First transition state, we need this to trigger session recording */ + rdp_client_transition_to_state(context->rdp, CONNECTION_STATE_MCS_CONNECT); if (!mcs_send_connect_initial(mcs)) { freerdp_set_last_error_if_not(context, FREERDP_ERROR_MCS_CONNECT_INITIAL_ERROR); @@ -1265,7 +1360,6 @@ BOOL mcs_client_begin(rdpMcs* mcs) return FALSE; } - rdp_client_transition_to_state(context->rdp, CONNECTION_STATE_MCS_CONNECT); return TRUE; } diff --git a/libfreerdp/core/mcs.h b/libfreerdp/core/mcs.h index 5856daa29..e0df8a958 100644 --- a/libfreerdp/core/mcs.h +++ b/libfreerdp/core/mcs.h @@ -57,8 +57,9 @@ enum MCS_Result MCS_Result_enum_length = 16 }; -enum DomainMCSPDU +typedef enum { + DomainMCSPDU_invalid = -1, DomainMCSPDU_PlumbDomainIndication = 0, DomainMCSPDU_ErectDomainRequest = 1, DomainMCSPDU_MergeChannelsRequest = 2, @@ -103,7 +104,7 @@ enum DomainMCSPDU DomainMCSPDU_TokenTestRequest = 41, DomainMCSPDU_TokenTestConfirm = 42, DomainMCSPDU_enum_length = 43 -}; +} DomainMCSPDU; typedef struct { @@ -154,8 +155,8 @@ struct rdp_mcs #define MCS_TYPE_CONNECT_INITIAL 0x65 #define MCS_TYPE_CONNECT_RESPONSE 0x66 +const char* mcs_domain_pdu_string(DomainMCSPDU pdu); FREERDP_LOCAL BOOL mcs_recv_connect_initial(rdpMcs* mcs, wStream* s); -FREERDP_LOCAL BOOL mcs_send_connect_initial(rdpMcs* mcs); FREERDP_LOCAL BOOL mcs_recv_connect_response(rdpMcs* mcs, wStream* s); FREERDP_LOCAL BOOL mcs_send_connect_response(rdpMcs* mcs); FREERDP_LOCAL BOOL mcs_recv_erect_domain_request(rdpMcs* mcs, wStream* s); @@ -171,7 +172,7 @@ FREERDP_LOCAL BOOL mcs_send_channel_join_confirm(rdpMcs* mcs, UINT16 channelId); FREERDP_LOCAL BOOL mcs_recv_disconnect_provider_ultimatum(rdpMcs* mcs, wStream* s, int* reason); FREERDP_LOCAL BOOL mcs_send_disconnect_provider_ultimatum(rdpMcs* mcs); -FREERDP_LOCAL void mcs_write_domain_mcspdu_header(wStream* s, enum DomainMCSPDU domainMCSPDU, +FREERDP_LOCAL void mcs_write_domain_mcspdu_header(wStream* s, DomainMCSPDU domainMCSPDU, UINT16 length, BYTE options); FREERDP_LOCAL BOOL mcs_client_begin(rdpMcs* mcs); diff --git a/libfreerdp/core/peer.c b/libfreerdp/core/peer.c index c6d3d4585..763a41481 100644 --- a/libfreerdp/core/peer.c +++ b/libfreerdp/core/peer.c @@ -220,7 +220,7 @@ static BOOL freerdp_peer_set_state(freerdp_peer* client, CONNECTION_STATE state) { WINPR_ASSERT(client); WINPR_ASSERT(client->context); - return rdp_server_transition_to_state(client->context->rdp, state) >= 0; + return rdp_server_transition_to_state(client->context->rdp, state); } static BOOL freerdp_peer_initialize(freerdp_peer* client) @@ -431,10 +431,7 @@ static int peer_recv_tpkt_pdu(freerdp_peer* client, wStream* s) WINPR_ASSERT(settings); if (!rdp_read_header(rdp, s, &length, &channelId)) - { - WLog_ERR(TAG, "Incorrect RDP header."); return -1; - } rdp->inPackets++; if (freerdp_shall_disconnect(rdp->instance)) diff --git a/libfreerdp/core/rdp.c b/libfreerdp/core/rdp.c index c6e19c2fe..85ae8399d 100644 --- a/libfreerdp/core/rdp.c +++ b/libfreerdp/core/rdp.c @@ -26,6 +26,7 @@ #include "info.h" #include "utils.h" +#include "mcs.h" #include "redirection.h" #include @@ -405,8 +406,8 @@ BOOL rdp_read_header(rdpRdp* rdp, wStream* s, UINT16* length, UINT16* channelId) BYTE code; BYTE choice; UINT16 initiator; - enum DomainMCSPDU MCSPDU; - enum DomainMCSPDU domainMCSPDU; + DomainMCSPDU MCSPDU; + DomainMCSPDU domainMCSPDU; MCSPDU = (rdp->settings->ServerMode) ? DomainMCSPDU_SendDataRequest : DomainMCSPDU_SendDataIndication; @@ -425,27 +426,40 @@ BOOL rdp_read_header(rdpRdp* rdp, wStream* s, UINT16* length, UINT16* channelId) return TRUE; } + WLog_WARN(TAG, "Unexpected X224 TPDU type %s [%08" PRIx32 "] instead of %s", + tpdu_type_to_string(code), code, tpdu_type_to_string(X224_TPDU_DATA)); return FALSE; } if (!per_read_choice(s, &choice)) return FALSE; - domainMCSPDU = (enum DomainMCSPDU)(choice >> 2); + domainMCSPDU = (DomainMCSPDU)(choice >> 2); if (domainMCSPDU != MCSPDU) { if (domainMCSPDU != DomainMCSPDU_DisconnectProviderUltimatum) + { + WLog_WARN(TAG, "Received %s instead of %s", mcs_domain_pdu_string(domainMCSPDU), + mcs_domain_pdu_string(MCSPDU)); return FALSE; + } } MCSPDU = domainMCSPDU; if (*length < 8U) + { + WLog_WARN(TAG, "TPDU invalid length, got %" PRIu16 ", expected at least 8", *length); return FALSE; + } if ((*length - 8U) > Stream_GetRemainingLength(s)) + { + WLog_WARN(TAG, "TPDU invalid length, got %" PRIuz ", expected %" PRIu16, + Stream_GetRemainingLength(s), *length - 8); return FALSE; + } if (MCSPDU == DomainMCSPDU_DisconnectProviderUltimatum) { @@ -486,7 +500,11 @@ BOOL rdp_read_header(rdpRdp* rdp, wStream* s, UINT16* length, UINT16* channelId) } if (Stream_GetRemainingLength(s) < 5) + { + WLog_WARN(TAG, "TPDU packet short length, got %" PRIuz ", expected at least 5", + Stream_GetRemainingLength(s)); return FALSE; + } if (!per_read_integer16(s, &initiator, MCS_BASE_CHANNEL_ID)) /* initiator (UserId) */ return FALSE; @@ -500,7 +518,11 @@ BOOL rdp_read_header(rdpRdp* rdp, wStream* s, UINT16* length, UINT16* channelId) return FALSE; if (*length > Stream_GetRemainingLength(s)) + { + WLog_WARN(TAG, "TPDU invalid length, got %" PRIuz ", expected %" PRIu16, + Stream_GetRemainingLength(s), *length); return FALSE; + } return TRUE; } @@ -516,7 +538,7 @@ BOOL rdp_read_header(rdpRdp* rdp, wStream* s, UINT16* length, UINT16* channelId) void rdp_write_header(rdpRdp* rdp, wStream* s, UINT16 length, UINT16 channelId) { int body_length; - enum DomainMCSPDU MCSPDU; + DomainMCSPDU MCSPDU; MCSPDU = (rdp->settings->ServerMode) ? DomainMCSPDU_SendDataIndication : DomainMCSPDU_SendDataRequest; @@ -1323,10 +1345,7 @@ static int rdp_recv_tpkt_pdu(rdpRdp* rdp, wStream* s) UINT16 securityFlags = 0; if (!rdp_read_header(rdp, s, &length, &channelId)) - { - WLog_ERR(TAG, "Incorrect RDP header."); return -1; - } if (freerdp_shall_disconnect(rdp->instance)) return 0; @@ -1809,6 +1828,7 @@ rdpRdp* rdp_new(rdpContext* context) rdp->settings = context->settings; rdp->settings->instance = context->instance; + context->settings = rdp->settings; if (context->instance) context->settings->instance = context->instance; else if (context->peer) @@ -2048,6 +2068,7 @@ BOOL rdp_set_io_callback_context(rdpRdp* rdp, void* usercontext) { WINPR_ASSERT(rdp); rdp->ioContext = usercontext; + return TRUE; } void* rdp_get_io_callback_context(rdpRdp* rdp) diff --git a/libfreerdp/core/tpdu.c b/libfreerdp/core/tpdu.c index 1386e0e69..7bc3518ff 100644 --- a/libfreerdp/core/tpdu.c +++ b/libfreerdp/core/tpdu.c @@ -75,7 +75,11 @@ static void tpdu_write_header(wStream* s, UINT16 length, BYTE code); BOOL tpdu_read_header(wStream* s, BYTE* code, BYTE* li, UINT16 tpktlength) { if (Stream_GetRemainingLength(s) < 3) + { + WLog_WARN(TAG, "tpdu invalid data, got %" PRIuz ", require at least 3 more", + Stream_GetRemainingLength(s)); return FALSE; + } Stream_Read_UINT8(s, *li); /* LI */ Stream_Read_UINT8(s, *code); /* Code */ @@ -246,3 +250,22 @@ BOOL tpdu_read_data(wStream* s, UINT16* LI, UINT16 tpktlength) return TRUE; } + +const char* tpdu_type_to_string(int type) +{ + switch (type) + { + case X224_TPDU_CONNECTION_REQUEST: + return "X224_TPDU_CONNECTION_REQUEST"; + case X224_TPDU_CONNECTION_CONFIRM: + return "X224_TPDU_CONNECTION_CONFIRM"; + case X224_TPDU_DISCONNECT_REQUEST: + return "X224_TPDU_DISCONNECT_REQUEST"; + case X224_TPDU_DATA: + return "X224_TPDU_DATA"; + case X224_TPDU_ERROR: + return "X224_TPDU_ERROR"; + default: + return "X224_TPDU_UNKNOWN"; + } +} diff --git a/libfreerdp/core/tpdu.h b/libfreerdp/core/tpdu.h index 8e28bbea7..ea629e758 100644 --- a/libfreerdp/core/tpdu.h +++ b/libfreerdp/core/tpdu.h @@ -42,6 +42,7 @@ enum X224_TPDU_TYPE #define TPDU_CONNECTION_CONFIRM_LENGTH (TPKT_HEADER_LENGTH + TPDU_CONNECTION_CONFIRM_HEADER_LENGTH) #define TPDU_DISCONNECT_REQUEST_LENGTH (TPKT_HEADER_LENGTH + TPDU_DISCONNECT_REQUEST_HEADER_LENGTH) +const char* tpdu_type_to_string(int type); FREERDP_LOCAL BOOL tpdu_read_header(wStream* s, BYTE* code, BYTE* li, UINT16 tpktlength); FREERDP_LOCAL BOOL tpdu_read_connection_request(wStream* s, BYTE* li, UINT16 tpktlength); FREERDP_LOCAL void tpdu_write_connection_request(wStream* s, UINT16 length); diff --git a/libfreerdp/core/tpkt.c b/libfreerdp/core/tpkt.c index f0eb775d8..409791239 100644 --- a/libfreerdp/core/tpkt.c +++ b/libfreerdp/core/tpkt.c @@ -89,7 +89,11 @@ BOOL tpkt_read_header(wStream* s, UINT16* length) BYTE version; if (Stream_GetRemainingLength(s) < 1) + { + WLog_WARN(TAG, "tpkt invalid data, got %" PRIuz ", require at least 1 more", + Stream_GetRemainingLength(s)); return FALSE; + } Stream_Peek_UINT8(s, version); @@ -98,7 +102,11 @@ BOOL tpkt_read_header(wStream* s, UINT16* length) size_t slen; UINT16 len; if (Stream_GetRemainingLength(s) < 4) + { + WLog_WARN(TAG, "tpkt invalid data, got %" PRIuz ", require at least 4 more", + Stream_GetRemainingLength(s)); return FALSE; + } Stream_Seek(s, 2); Stream_Read_UINT16_BE(s, len); diff --git a/libfreerdp/crypto/per.c b/libfreerdp/crypto/per.c index e00377d13..3275a57b6 100644 --- a/libfreerdp/crypto/per.c +++ b/libfreerdp/crypto/per.c @@ -21,6 +21,8 @@ #include +#include +#define TAG FREERDP_TAG("crypto.per") /** * Read PER length. * @param s stream @@ -33,14 +35,22 @@ BOOL per_read_length(wStream* s, UINT16* length) BYTE byte; if (Stream_GetRemainingLength(s) < 1) + { + WLog_WARN(TAG, "PER length invalid data, got %" PRIuz ", require at least 1 more", + Stream_GetRemainingLength(s)); return FALSE; + } Stream_Read_UINT8(s, byte); if (byte & 0x80) { if (Stream_GetRemainingLength(s) < 1) + { + WLog_WARN(TAG, "PER length invalid data, got %" PRIuz ", require at least 1 more", + Stream_GetRemainingLength(s)); return FALSE; + } byte &= ~(0x80); *length = (byte << 8); @@ -88,7 +98,11 @@ BOOL per_write_length(wStream* s, UINT16 length) BOOL per_read_choice(wStream* s, BYTE* choice) { if (Stream_GetRemainingLength(s) < 1) + { + WLog_WARN(TAG, "PER choice invalid data, got %" PRIuz ", require at least 1 more", + Stream_GetRemainingLength(s)); return FALSE; + } Stream_Read_UINT8(s, *choice); return TRUE; @@ -272,12 +286,20 @@ BOOL per_write_integer(wStream* s, UINT32 integer) BOOL per_read_integer16(wStream* s, UINT16* integer, UINT16 min) { if (Stream_GetRemainingLength(s) < 2) + { + WLog_WARN(TAG, "PER uint16 invalid data, got %" PRIuz ", require at least 2", + Stream_GetRemainingLength(s)); return FALSE; + } Stream_Read_UINT16_BE(s, *integer); - if (*integer + min > 0xFFFF) + if (*integer > UINT16_MAX - min) + { + WLog_WARN(TAG, "PER uint16 invalid value %" PRIu16 " > %" PRIu16, *integer, + UINT16_MAX - min); return FALSE; + } *integer += min;