diff --git a/libfreerdp-core/mcs.c b/libfreerdp-core/mcs.c index 58c66644f..28956a978 100644 --- a/libfreerdp-core/mcs.c +++ b/libfreerdp-core/mcs.c @@ -220,7 +220,7 @@ boolean mcs_connect(rdpMcs* mcs) * @return */ -boolean mcs_read_domain_mcspdu_header(STREAM* s, enum DomainMCSPDU* domainMCSPDU, int* length) +boolean mcs_read_domain_mcspdu_header(STREAM* s, enum DomainMCSPDU* domainMCSPDU, uint16* length) { uint8 choice; enum DomainMCSPDU MCSPDU; @@ -247,7 +247,7 @@ boolean mcs_read_domain_mcspdu_header(STREAM* s, enum DomainMCSPDU* domainMCSPDU * @param length TPKT length */ -void mcs_write_domain_mcspdu_header(STREAM* s, enum DomainMCSPDU domainMCSPDU, int length, uint8 options) +void mcs_write_domain_mcspdu_header(STREAM* s, enum DomainMCSPDU domainMCSPDU, uint16 length, uint8 options) { tpkt_write_header(s, length); tpdu_write_data(s); @@ -605,7 +605,7 @@ boolean mcs_send_connect_response(rdpMcs* mcs) boolean mcs_read_erect_domain_request(rdpMcs* mcs, STREAM* s) { - int length; + uint16 length; enum DomainMCSPDU MCSPDU; MCSPDU = DomainMCSPDU_ErectDomainRequest; @@ -624,7 +624,7 @@ boolean mcs_read_erect_domain_request(rdpMcs* mcs, STREAM* s) void mcs_send_erect_domain_request(rdpMcs* mcs) { STREAM* s; - int length = 12; + uint16 length = 12; s = transport_send_stream_init(mcs->transport, length); mcs_write_domain_mcspdu_header(s, DomainMCSPDU_ErectDomainRequest, length, 0); @@ -644,7 +644,7 @@ void mcs_send_erect_domain_request(rdpMcs* mcs) boolean mcs_read_attach_user_request(rdpMcs* mcs, STREAM* s) { - int length; + uint16 length; enum DomainMCSPDU MCSPDU; MCSPDU = DomainMCSPDU_AttachUserRequest; @@ -663,7 +663,7 @@ boolean mcs_read_attach_user_request(rdpMcs* mcs, STREAM* s) void mcs_send_attach_user_request(rdpMcs* mcs) { STREAM* s; - int length = 8; + uint16 length = 8; s = transport_send_stream_init(mcs->transport, length); mcs_write_domain_mcspdu_header(s, DomainMCSPDU_AttachUserRequest, length, 0); @@ -680,7 +680,7 @@ void mcs_send_attach_user_request(rdpMcs* mcs) void mcs_recv_attach_user_confirm(rdpMcs* mcs) { STREAM* s; - int length; + uint16 length; uint8 result; enum DomainMCSPDU MCSPDU; @@ -703,7 +703,7 @@ void mcs_recv_attach_user_confirm(rdpMcs* mcs) boolean mcs_send_attach_user_confirm(rdpMcs* mcs) { STREAM* s; - int length = 11; + uint16 length = 11; s = transport_send_stream_init(mcs->transport, length); @@ -727,7 +727,7 @@ boolean mcs_send_attach_user_confirm(rdpMcs* mcs) boolean mcs_read_channel_join_request(rdpMcs* mcs, STREAM* s, uint16* channel_id) { - int length; + uint16 length; enum DomainMCSPDU MCSPDU; uint16 user_id; @@ -755,7 +755,7 @@ boolean mcs_read_channel_join_request(rdpMcs* mcs, STREAM* s, uint16* channel_id void mcs_send_channel_join_request(rdpMcs* mcs, uint16 channel_id) { STREAM* s; - int length = 12; + uint16 length = 12; s = transport_send_stream_init(mcs->transport, 12); mcs_write_domain_mcspdu_header(s, DomainMCSPDU_ChannelJoinRequest, length, 0); @@ -775,7 +775,7 @@ void mcs_send_channel_join_request(rdpMcs* mcs, uint16 channel_id) void mcs_recv_channel_join_confirm(rdpMcs* mcs) { STREAM* s; - int length; + uint16 length; uint8 result; uint16 initiator; uint16 requested; @@ -803,7 +803,7 @@ void mcs_recv_channel_join_confirm(rdpMcs* mcs) boolean mcs_send_channel_join_confirm(rdpMcs* mcs, uint16 channel_id) { STREAM* s; - int length = 15; + uint16 length = 15; s = transport_send_stream_init(mcs->transport, 15); mcs_write_domain_mcspdu_header(s, DomainMCSPDU_ChannelJoinConfirm, length, 2); diff --git a/libfreerdp-core/mcs.h b/libfreerdp-core/mcs.h index dfd2cb825..744d0b6b0 100644 --- a/libfreerdp-core/mcs.h +++ b/libfreerdp-core/mcs.h @@ -148,8 +148,8 @@ boolean mcs_read_channel_join_request(rdpMcs* mcs, STREAM* s, uint16* channel_id void mcs_send_channel_join_request(rdpMcs* mcs, uint16 channel_id); void mcs_recv_channel_join_confirm(rdpMcs* mcs); boolean mcs_send_channel_join_confirm(rdpMcs* mcs, uint16 channel_id); -boolean mcs_read_domain_mcspdu_header(STREAM* s, enum DomainMCSPDU* domainMCSPDU, int* length); -void mcs_write_domain_mcspdu_header(STREAM* s, enum DomainMCSPDU domainMCSPDU, int length, uint8 options); +boolean mcs_read_domain_mcspdu_header(STREAM* s, enum DomainMCSPDU* domainMCSPDU, uint16* length); +void mcs_write_domain_mcspdu_header(STREAM* s, enum DomainMCSPDU domainMCSPDU, uint16 length, uint8 options); rdpMcs* mcs_new(rdpTransport* transport); void mcs_free(rdpMcs* mcs); diff --git a/libfreerdp-core/peer.c b/libfreerdp-core/peer.c index d225ef634..5268b5d14 100644 --- a/libfreerdp-core/peer.c +++ b/libfreerdp-core/peer.c @@ -23,6 +23,7 @@ static boolean freerdp_peer_initialize(freerdp_peer* client) { rdpPeer* peer = (rdpPeer*)client->peer; + peer->rdp->server_mode = True; peer->rdp->state = CONNECTION_STATE_INITIAL; return True; diff --git a/libfreerdp-core/rdp.c b/libfreerdp-core/rdp.c index 07d3eb5e3..98ccd24e7 100644 --- a/libfreerdp-core/rdp.c +++ b/libfreerdp-core/rdp.c @@ -169,6 +169,32 @@ STREAM* rdp_data_pdu_init(rdpRdp* rdp) return s; } +/** + * Read an RDP packet header.\n + * @param rdp rdp module + * @param s stream + * @param length RDP packet length + * @param channel_id channel id + */ + +boolean rdp_read_header(rdpRdp* rdp, STREAM* s, uint16* length, uint16* channel_id) +{ + uint16 initiator; + enum DomainMCSPDU MCSPDU; + + MCSPDU = (rdp->server_mode ? DomainMCSPDU_SendDataRequest : DomainMCSPDU_SendDataIndication); + mcs_read_domain_mcspdu_header(s, &MCSPDU, length); + + per_read_integer16(s, &initiator, MCS_BASE_CHANNEL_ID); /* initiator (UserId) */ + per_read_integer16(s, channel_id, 0); /* channelId */ + stream_seek(s, 1); /* dataPriority + Segmentation (0x70) */ + per_read_length(s, length); /* userData (OCTET_STRING) */ + if (*length > stream_get_left(s)) + return False; + + return True; +} + /** * Write an RDP packet header.\n * @param rdp rdp module @@ -177,9 +203,13 @@ STREAM* rdp_data_pdu_init(rdpRdp* rdp) * @param channel_id channel id */ -void rdp_write_header(rdpRdp* rdp, STREAM* s, int length, uint16 channel_id) +void rdp_write_header(rdpRdp* rdp, STREAM* s, uint16 length, uint16 channel_id) { - mcs_write_domain_mcspdu_header(s, DomainMCSPDU_SendDataRequest, length, 0); + enum DomainMCSPDU MCSPDU; + + MCSPDU = (rdp->server_mode ? DomainMCSPDU_SendDataIndication : DomainMCSPDU_SendDataRequest); + + mcs_write_domain_mcspdu_header(s, MCSPDU, length, 0); per_write_integer16(s, rdp->mcs->user_id, MCS_BASE_CHANNEL_ID); /* initiator */ per_write_integer16(s, channel_id, 0); /* channelId */ stream_write_uint8(s, 0x70); /* dataPriority + segmentation */ @@ -197,7 +227,7 @@ void rdp_write_header(rdpRdp* rdp, STREAM* s, int length, uint16 channel_id) void rdp_send(rdpRdp* rdp, STREAM* s, uint16 channel_id) { - int length; + uint16 length; length = stream_get_length(s); stream_set_pos(s, 0); @@ -210,7 +240,7 @@ void rdp_send(rdpRdp* rdp, STREAM* s, uint16 channel_id) void rdp_send_pdu(rdpRdp* rdp, STREAM* s, uint16 type, uint16 channel_id) { - int length; + uint16 length; length = stream_get_length(s); stream_set_pos(s, 0); @@ -224,7 +254,7 @@ void rdp_send_pdu(rdpRdp* rdp, STREAM* s, uint16 type, uint16 channel_id) void rdp_send_data_pdu(rdpRdp* rdp, STREAM* s, uint8 type, uint16 channel_id) { - int length; + uint16 length; length = stream_get_length(s); stream_set_pos(s, 0); @@ -351,24 +381,20 @@ void rdp_read_data_pdu(rdpRdp* rdp, STREAM* s) * @param s stream */ -static void rdp_process_tpkt_pdu(rdpRdp* rdp, STREAM* s) +static void rdp_read_tpkt_pdu(rdpRdp* rdp, STREAM* s) { - int length; + uint16 length; uint16 pduType; uint16 pduLength; - uint16 initiator; uint16 channelId; uint16 sec_flags; boolean processed; - enum DomainMCSPDU MCSPDU; - MCSPDU = DomainMCSPDU_SendDataIndication; - mcs_read_domain_mcspdu_header(s, &MCSPDU, &length); - - per_read_integer16(s, &initiator, MCS_BASE_CHANNEL_ID); /* initiator (UserId) */ - per_read_integer16(s, &channelId, 0); /* channelId */ - stream_seek(s, 1); /* dataPriority + Segmentation (0x70) */ - per_read_length(s, &pduLength); /* userData (OCTET_STRING) */ + if (!rdp_read_header(rdp, s, &length, &channelId)) + { + printf("Incorrect RDP header.\n"); + return; + } if (rdp->licensed != True) { @@ -433,7 +459,7 @@ static void rdp_process_tpkt_pdu(rdpRdp* rdp, STREAM* s) } } -static void rdp_process_fastpath_pdu(rdpRdp* rdp, STREAM* s) +static void rdp_read_fastpath_pdu(rdpRdp* rdp, STREAM* s) { uint16 length; @@ -454,12 +480,12 @@ static void rdp_process_fastpath_pdu(rdpRdp* rdp, STREAM* s) fastpath_recv_updates(rdp->fastpath, s); } -static void rdp_process_pdu(rdpRdp* rdp, STREAM* s) +static void rdp_read_pdu(rdpRdp* rdp, STREAM* s) { if (tpkt_verify_header(s)) - rdp_process_tpkt_pdu(rdp, s); + rdp_read_tpkt_pdu(rdp, s); else - rdp_process_fastpath_pdu(rdp, s); + rdp_read_fastpath_pdu(rdp, s); } /** @@ -474,14 +500,14 @@ void rdp_recv(rdpRdp* rdp) s = transport_recv_stream_init(rdp->transport, 4096); transport_read(rdp->transport, s); - rdp_process_pdu(rdp, s); + rdp_read_pdu(rdp, s); } static int rdp_recv_callback(rdpTransport* transport, STREAM* s, void* extra) { rdpRdp* rdp = (rdpRdp*) extra; - rdp_process_pdu(rdp, s); + rdp_read_pdu(rdp, s); return 1; } diff --git a/libfreerdp-core/rdp.h b/libfreerdp-core/rdp.h index eff7e02ed..1d2107108 100644 --- a/libfreerdp-core/rdp.h +++ b/libfreerdp-core/rdp.h @@ -115,6 +115,7 @@ struct rdp_rdp { boolean licensed; boolean activated; + boolean server_mode; int state; struct rdp_mcs* mcs; struct rdp_nego* nego; @@ -137,7 +138,9 @@ void rdp_read_share_data_header(STREAM* s, uint16* length, uint8* type, uint32* void rdp_write_share_data_header(STREAM* s, uint16 length, uint8 type, uint32 share_id); STREAM* rdp_send_stream_init(rdpRdp* rdp); -void rdp_write_header(rdpRdp* rdp, STREAM* s, int length, uint16 channel_id); + +boolean rdp_read_header(rdpRdp* rdp, STREAM* s, uint16* length, uint16* channel_id); +void rdp_write_header(rdpRdp* rdp, STREAM* s, uint16 length, uint16 channel_id); STREAM* rdp_pdu_init(rdpRdp* rdp); void rdp_send_pdu(rdpRdp* rdp, STREAM* s, uint16 type, uint16 channel_id); diff --git a/libfreerdp-core/tpkt.c b/libfreerdp-core/tpkt.c index a07d1cb46..4eec38e3f 100644 --- a/libfreerdp-core/tpkt.c +++ b/libfreerdp-core/tpkt.c @@ -103,7 +103,7 @@ uint16 tpkt_read_header(STREAM* s) * @param length */ -void tpkt_write_header(STREAM* s, int length) +void tpkt_write_header(STREAM* s, uint16 length) { stream_write_uint8(s, 3); /* version */ stream_write_uint8(s, 0); /* reserved */ diff --git a/libfreerdp-core/tpkt.h b/libfreerdp-core/tpkt.h index c754c1ccf..331bafd92 100644 --- a/libfreerdp-core/tpkt.h +++ b/libfreerdp-core/tpkt.h @@ -29,6 +29,6 @@ boolean tpkt_verify_header(STREAM* s); uint16 tpkt_read_header(STREAM* s); -void tpkt_write_header(STREAM* s, int length); +void tpkt_write_header(STREAM* s, uint16 length); #endif /* __TPKT_H */