From 0ea1dc43ec6fed48cc0148b1b3ec3e2e0512293c Mon Sep 17 00:00:00 2001 From: Hardening Date: Tue, 20 May 2014 22:39:21 +0200 Subject: [PATCH 01/11] Add a ringbuffer implementation targetting byte sending This adds a ringbuffer implementation that targets bytes sending. The ringbuffer can grow when there's not enough room, that's why it's not thread-safe (locking must be done externally). It will be shrinked to its initial size as soon as the used bytes are the half of the initial size. --- include/freerdp/utils/ringbuffer.h | 122 ++++++++++++ libfreerdp/utils/CMakeLists.txt | 7 + libfreerdp/utils/ringbuffer.c | 250 +++++++++++++++++++++++++ libfreerdp/utils/test/CMakeLists.txt | 34 ++++ libfreerdp/utils/test/TestRingBuffer.c | 228 ++++++++++++++++++++++ 5 files changed, 641 insertions(+) create mode 100644 include/freerdp/utils/ringbuffer.h create mode 100644 libfreerdp/utils/ringbuffer.c create mode 100644 libfreerdp/utils/test/CMakeLists.txt create mode 100644 libfreerdp/utils/test/TestRingBuffer.c diff --git a/include/freerdp/utils/ringbuffer.h b/include/freerdp/utils/ringbuffer.h new file mode 100644 index 000000000..099ba8ba1 --- /dev/null +++ b/include/freerdp/utils/ringbuffer.h @@ -0,0 +1,122 @@ +/** + * Copyright © 2014 Thincast Technologies GmbH + * Copyright © 2014 Hardening + * + * Permission to use, copy, modify, distribute, and sell this software and + * its documentation for any purpose is hereby granted without fee, provided + * that the above copyright notice appear in all copies and that both that + * copyright notice and this permission notice appear in supporting + * documentation, and that the name of the copyright holders not be used in + * advertising or publicity pertaining to distribution of the software + * without specific, written prior permission. The copyright holders make + * no representations about the suitability of this software for any + * purpose. It is provided "as is" without express or implied warranty. + * + * THE COPYRIGHT HOLDERS DISCLAIM ALL WARRANTIES WITH REGARD TO THIS + * SOFTWARE, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS, IN NO EVENT SHALL THE COPYRIGHT HOLDERS BE LIABLE FOR ANY + * SPECIAL, INDIRECT OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER + * RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF + * CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN + * CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + */ + +#ifndef __RINGBUFFER_H___ +#define __RINGBUFFER_H___ + +#include + +/** @brief ring buffer meta data */ +struct _RingBuffer { + size_t initialSize; + size_t freeSize; + size_t size; + size_t readPtr; + size_t writePtr; + BYTE *buffer; +}; +typedef struct _RingBuffer RingBuffer; + + +/** @brief a piece of data in the ring buffer, exactly like a glibc iovec */ +struct _DataChunk { + size_t size; + const BYTE *data; +}; +typedef struct _DataChunk DataChunk; + +/** initialise a ringbuffer + * @param initialSize the initial capacity of the ringBuffer + * @return if the initialisation was successful + */ +BOOL ringbuffer_init(RingBuffer *rb, size_t initialSize); + +/** destroys internal data used by this ringbuffer + * @param ringbuffer + */ +void ringbuffer_destroy(RingBuffer *ringbuffer); + +/** computes the space used in this ringbuffer + * @param ringbuffer + * @return the number of bytes stored in that ringbuffer + */ +size_t ringbuffer_used(const RingBuffer *ringbuffer); + +/** returns the capacity of the ring buffer + * @param ringbuffer + * @return the capacity of this ring buffer + */ +size_t ringbuffer_capacity(const RingBuffer *ringbuffer); + +/** writes some bytes in the ringbuffer, if the data doesn't fit, the ringbuffer + * is resized automatically + * + * @param rb the ringbuffer + * @param ptr a pointer on the data to add + * @param sz the size of the data to add + * @return if the operation was successful, it could fail in case of OOM during realloc() + */ +BOOL ringbuffer_write(RingBuffer *rb, const void *ptr, size_t sz); + + +/** ensures that we have sz bytes available at the write head, and return a pointer + * on the write head + * + * @param rb the ring buffer + * @param sz the size to ensure + * @return a pointer on the write head, or NULL in case of OOM + */ +BYTE *ringbuffer_ensure_linear_write(RingBuffer *rb, size_t sz); + +/** move ahead the write head in case some byte were written directly by using + * a pointer retrieved via ringbuffer_ensure_linear_write(). This function is + * used to commit the written bytes. The provided size should not exceed the + * size ensured by ringbuffer_ensure_linear_write() + * + * @param rb the ring buffer + * @param sz the number of bytes that have been written + * @return if the operation was successful, FALSE is sz is too big + */ +BOOL ringbuffer_commit_written_bytes(RingBuffer *rb, size_t sz); + + +/** peeks the buffer chunks for sz bytes and returns how many chunks are filled. + * Note that the sum of the resulting chunks may be smaller than sz. + * + * @param rb the ringbuffer + * @param chunks an array of data chunks that will contain data / size of chunks + * @param sz the requested size + * @return the number of chunks used for reading sz bytes + */ +int ringbuffer_peek(const RingBuffer *rb, DataChunk chunks[2], size_t sz); + +/** move ahead the read head in case some byte were read using ringbuffer_peek() + * This function is used to commit the bytes that were effectively consumed. + * + * @param rb the ring buffer + * @param sz the + */ +void ringbuffer_commit_read_bytes(RingBuffer *rb, size_t sz); + + +#endif /* __RINGBUFFER_H___ */ diff --git a/libfreerdp/utils/CMakeLists.txt b/libfreerdp/utils/CMakeLists.txt index 716e96384..6e5858672 100644 --- a/libfreerdp/utils/CMakeLists.txt +++ b/libfreerdp/utils/CMakeLists.txt @@ -25,6 +25,7 @@ set(${MODULE_PREFIX}_SRCS pcap.c profiler.c rail.c + ringbuffer.c signal.c stopwatch.c svc_plugin.c @@ -68,3 +69,9 @@ else() endif() set_property(TARGET ${MODULE_NAME} PROPERTY FOLDER "FreeRDP/libfreerdp") + + +if(BUILD_TESTING) + add_subdirectory(test) +endif() + diff --git a/libfreerdp/utils/ringbuffer.c b/libfreerdp/utils/ringbuffer.c new file mode 100644 index 000000000..04dcffef1 --- /dev/null +++ b/libfreerdp/utils/ringbuffer.c @@ -0,0 +1,250 @@ +/** + * Copyright © 2014 Thincast Technologies GmbH + * Copyright © 2014 Hardening + * + * Permission to use, copy, modify, distribute, and sell this software and + * its documentation for any purpose is hereby granted without fee, provided + * that the above copyright notice appear in all copies and that both that + * copyright notice and this permission notice appear in supporting + * documentation, and that the name of the copyright holders not be used in + * advertising or publicity pertaining to distribution of the software + * without specific, written prior permission. The copyright holders make + * no representations about the suitability of this software for any + * purpose. It is provided "as is" without express or implied warranty. + * + * THE COPYRIGHT HOLDERS DISCLAIM ALL WARRANTIES WITH REGARD TO THIS + * SOFTWARE, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS, IN NO EVENT SHALL THE COPYRIGHT HOLDERS BE LIABLE FOR ANY + * SPECIAL, INDIRECT OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER + * RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF + * CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN + * CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + */ + +#include + +#include +#include +#include + + +BOOL ringbuffer_init(RingBuffer *rb, size_t initialSize) +{ + rb->buffer = malloc(initialSize); + if (!rb->buffer) + return FALSE; + + rb->readPtr = rb->writePtr = 0; + rb->initialSize = rb->size = rb->freeSize = initialSize; + return TRUE; +} + + +size_t ringbuffer_used(const RingBuffer *ringbuffer) +{ + return ringbuffer->size - ringbuffer->freeSize; +} + +size_t ringbuffer_capacity(const RingBuffer *ringbuffer) +{ + return ringbuffer->size; +} + +void ringbuffer_destroy(RingBuffer *ringbuffer) +{ + free(ringbuffer->buffer); + ringbuffer->buffer = 0; +} + +static BOOL ringbuffer_realloc(RingBuffer *rb, size_t targetSize) +{ + BYTE *newData; + + if (rb->writePtr == rb->readPtr) + { + /* when no size is used we can realloc() and set the heads at the + * beginning of the buffer + */ + newData = (BYTE *)realloc(rb->buffer, targetSize); + if (!newData) + return FALSE; + rb->readPtr = rb->writePtr = 0; + } + else if ((rb->writePtr >= rb->readPtr) && (rb->writePtr < targetSize)) + { + /* we reallocate only if we're in that case, realloc don't touch read + * and write heads + * + * readPtr writePtr + * | | + * v v + * [............|XXXXXXXXXXXXXX|..........] + */ + newData = (BYTE *)realloc(rb->buffer, targetSize); + if (!newData) + return FALSE; + + rb->buffer = newData; + } + else + { + /* in case of malloc the read head is moved at the beginning of the new buffer + * and the write head is set accordingly + */ + newData = (BYTE *)malloc(targetSize); + if (!newData) + return FALSE; + if (rb->readPtr < rb->writePtr) + { + /* readPtr writePtr + * | | + * v v + * [............|XXXXXXXXXXXXXX|..........] + */ + memcpy(newData, rb->buffer + rb->readPtr, ringbuffer_used(rb)); + } + else + { + /* writePtr readPtr + * | | + * v v + * [XXXXXXXXXXXX|..............|XXXXXXXXXX] + */ + BYTE *dst = newData; + memcpy(dst, rb->buffer + rb->readPtr, rb->size - rb->readPtr); + dst += (rb->size - rb->readPtr); + if (rb->writePtr) + memcpy(dst, rb->buffer, rb->writePtr); + } + rb->writePtr = rb->size - rb->freeSize; + rb->readPtr = 0; + rb->buffer = newData; + } + + rb->freeSize += (targetSize - rb->size); + rb->size = targetSize; + return TRUE; +} + +/** + * + * @param rb + * @param ptr + * @param sz + * @return + */ +BOOL ringbuffer_write(RingBuffer *rb, const void *ptr, size_t sz) +{ + if ((rb->freeSize <= sz) && !ringbuffer_realloc(rb, rb->size + sz)) + return FALSE; + + /* the write could be split in two + * readHead writeHead + * | | + * v v + * [ ################ ] + */ + size_t toWrite = sz; + size_t remaining = sz; + if (rb->size - rb->writePtr < sz) + toWrite = rb->size - rb->writePtr; + + if (toWrite) + { + memcpy(rb->buffer + rb->writePtr, ptr, toWrite); + remaining -= toWrite; + ptr += toWrite; + } + + if (remaining) + memcpy(rb->buffer, ptr, remaining); + + rb->writePtr = (rb->writePtr + sz) % rb->size; + + rb->freeSize -= sz; + return TRUE; +} + + +BYTE *ringbuffer_ensure_linear_write(RingBuffer *rb, size_t sz) +{ + if (rb->freeSize < sz) + { + if (!ringbuffer_realloc(rb, rb->size + sz - rb->freeSize + 32)) + return NULL; + } + + if (rb->writePtr == rb->readPtr) + { + rb->writePtr = rb->readPtr = 0; + } + + if (rb->writePtr + sz < rb->size) + return rb->buffer + rb->writePtr; + + /* + * to add: ....... + * [ XXXXXXXXX ] + * + * result: + * [XXXXXXXXX....... ] + */ + memmove(rb->buffer, rb->buffer + rb->readPtr, rb->writePtr - rb->readPtr); + rb->readPtr = 0; + rb->writePtr = rb->size - rb->freeSize; + return rb->buffer + rb->writePtr; +} + +BOOL ringbuffer_commit_written_bytes(RingBuffer *rb, size_t sz) +{ + if (rb->writePtr + sz > rb->size) + return FALSE; + rb->writePtr = (rb->writePtr + sz) % rb->size; + rb->freeSize -= sz; + return TRUE; +} + +int ringbuffer_peek(const RingBuffer *rb, DataChunk chunks[2], size_t sz) +{ + size_t remaining = sz; + size_t toRead; + int chunkIndex = 0; + int ret = 0; + + if (rb->size - rb->freeSize < sz) + remaining = rb->size - rb->freeSize; + + toRead = remaining; + + if (rb->readPtr + remaining > rb->size) + toRead = rb->size - rb->readPtr; + + if (toRead) + { + chunks[0].data = rb->buffer + rb->readPtr; + chunks[0].size = toRead; + remaining -= toRead; + chunkIndex++; + ret++; + } + + if (remaining) + { + chunks[chunkIndex].data = rb->buffer; + chunks[chunkIndex].size = remaining; + ret++; + } + return ret; +} + +void ringbuffer_commit_read_bytes(RingBuffer *rb, size_t sz) +{ + assert(rb->size - rb->freeSize >= sz); + + rb->readPtr = (rb->readPtr + sz) % rb->size; + rb->freeSize += sz; + + /* when we reach a reasonable free size, we can go back to the original size */ + if ((rb->size != rb->initialSize) && (ringbuffer_used(rb) < rb->initialSize / 2)) + ringbuffer_realloc(rb, rb->initialSize); +} diff --git a/libfreerdp/utils/test/CMakeLists.txt b/libfreerdp/utils/test/CMakeLists.txt new file mode 100644 index 000000000..2e8dbf153 --- /dev/null +++ b/libfreerdp/utils/test/CMakeLists.txt @@ -0,0 +1,34 @@ + +set(MODULE_NAME "TestFreeRDPutils") +set(MODULE_PREFIX "TEST_FREERDP_UTILS") + +set(${MODULE_PREFIX}_DRIVER ${MODULE_NAME}.c) + +set(${MODULE_PREFIX}_TESTS + TestRingBuffer.c +) + +create_test_sourcelist(${MODULE_PREFIX}_SRCS + ${${MODULE_PREFIX}_DRIVER} + ${${MODULE_PREFIX}_TESTS} +) + +add_executable(${MODULE_NAME} ${${MODULE_PREFIX}_SRCS}) + +set_complex_link_libraries(VARIABLE ${MODULE_PREFIX}_LIBS + MONOLITHIC ${MONOLITHIC_BUILD} + MODULE winpr + MODULES winpr-thread winpr-synch winpr-file winpr-utils winpr-crt freerdp-utils +) + +target_link_libraries(${MODULE_NAME} ${${MODULE_PREFIX}_LIBS}) + +set_target_properties(${MODULE_NAME} PROPERTIES RUNTIME_OUTPUT_DIRECTORY "${TESTING_OUTPUT_DIRECTORY}") + +foreach(test ${${MODULE_PREFIX}_TESTS}) + get_filename_component(TestName ${test} NAME_WE) + add_test(${TestName} ${TESTING_OUTPUT_DIRECTORY}/${MODULE_NAME} ${TestName}) +endforeach() + +set_property(TARGET ${MODULE_NAME} PROPERTY FOLDER "FreeRDP/Test") + diff --git a/libfreerdp/utils/test/TestRingBuffer.c b/libfreerdp/utils/test/TestRingBuffer.c new file mode 100644 index 000000000..36cbaa559 --- /dev/null +++ b/libfreerdp/utils/test/TestRingBuffer.c @@ -0,0 +1,228 @@ +/** + * Copyright © 2014 Thincast Technologies GmbH + * Copyright © 2014 Hardening + * + * Permission to use, copy, modify, distribute, and sell this software and + * its documentation for any purpose is hereby granted without fee, provided + * that the above copyright notice appear in all copies and that both that + * copyright notice and this permission notice appear in supporting + * documentation, and that the name of the copyright holders not be used in + * advertising or publicity pertaining to distribution of the software + * without specific, written prior permission. The copyright holders make + * no representations about the suitability of this software for any + * purpose. It is provided "as is" without express or implied warranty. + * + * THE COPYRIGHT HOLDERS DISCLAIM ALL WARRANTIES WITH REGARD TO THIS + * SOFTWARE, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS, IN NO EVENT SHALL THE COPYRIGHT HOLDERS BE LIABLE FOR ANY + * SPECIAL, INDIRECT OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER + * RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF + * CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN + * CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + */ + +#include +#include + +#include + +BOOL test_overlaps(void) +{ + RingBuffer rb; + DataChunk chunks[2]; + BYTE bytes[200]; + int nchunks, i, j, k, counter = 0; + + for (i = 0; i < sizeof(bytes); i++) + bytes[i] = (BYTE)i; + + ringbuffer_init(&rb, 5); + if (!ringbuffer_write(&rb, bytes, 4)) /* [0123.] */ + goto error; + counter += 4; + ringbuffer_commit_read_bytes(&rb, 2); /* [..23.] */ + + if (!ringbuffer_write(&rb, &bytes[counter], 2)) /* [5.234] */ + goto error; + counter += 2; + + nchunks = ringbuffer_peek(&rb, chunks, 4); + if (nchunks != 2 || chunks[0].size != 3 || chunks[1].size != 1) + goto error; + + for (i = 0, j = 2; i < nchunks; i++) + { + for (k = 0; k < chunks[i].size; k++, j++) + { + if (chunks[i].data[k] != (BYTE)j) + goto error; + } + } + + ringbuffer_commit_read_bytes(&rb, 3); /* [5....] */ + if (ringbuffer_used(&rb) != 1) + goto error; + + if (!ringbuffer_write(&rb, &bytes[counter], 6)) /* [56789ab....] */ + goto error; + counter += 6; + + ringbuffer_commit_read_bytes(&rb, 6); /* [......b....] */ + nchunks = ringbuffer_peek(&rb, chunks, 10); + if (nchunks != 1 || chunks[0].size != 1 || (*chunks[0].data != 0xb)) + goto error; + + if (ringbuffer_capacity(&rb) != 5) + goto error; + + return TRUE; +error: + ringbuffer_destroy(&rb); + return FALSE; +} + + +int TestRingBuffer(int argc, char* argv[]) +{ + RingBuffer ringBuffer; + int testNo = 0; + BYTE *tmpBuf; + BYTE *rb_ptr; + int i/*, chunkNb, counter*/; + DataChunk chunks[2]; + + if (!ringbuffer_init(&ringBuffer, 10)) + { + fprintf(stderr, "unable to initialize ringbuffer\n"); + return -1; + } + + tmpBuf = (BYTE *)malloc(50); + if (!tmpBuf) + return -1; + + for (i = 0; i < 50; i++) + tmpBuf[i] = (char)i; + + fprintf(stderr, "%d: basic tests...", ++testNo); + if (!ringbuffer_write(&ringBuffer, tmpBuf, 5) || !ringbuffer_write(&ringBuffer, tmpBuf, 5) || + !ringbuffer_write(&ringBuffer, tmpBuf, 5)) + { + fprintf(stderr, "error when writing bytes\n"); + return -1; + } + + if (ringbuffer_used(&ringBuffer) != 15) + { + fprintf(stderr, "invalid used size got %d when i would expect 15\n", ringbuffer_used(&ringBuffer)); + return -1; + } + + if (ringbuffer_peek(&ringBuffer, chunks, 10) != 1 || chunks[0].size != 10) + { + fprintf(stderr, "error when reading bytes\n"); + return -1; + } + ringbuffer_commit_read_bytes(&ringBuffer, chunks[0].size); + + /* check retrieved bytes */ + for (i = 0; i < chunks[0].size; i++) + { + if (chunks[0].data[i] != i % 5) + { + fprintf(stderr, "invalid byte at %d, got %d instead of %d\n", i, chunks[0].data[i], i % 5); + return -1; + } + } + + if (ringbuffer_used(&ringBuffer) != 5) + { + fprintf(stderr, "invalid used size after read got %d when i would expect 5\n", ringbuffer_used(&ringBuffer)); + return -1; + } + + /* write some more bytes to have writePtr < readPtr and data splitted in 2 chunks */ + if (!ringbuffer_write(&ringBuffer, tmpBuf, 6) || + ringbuffer_peek(&ringBuffer, chunks, 11) != 2 || + chunks[0].size != 10 || + chunks[1].size != 1) + { + fprintf(stderr, "invalid read of splitted data\n"); + return -1; + } + + ringbuffer_commit_read_bytes(&ringBuffer, 11); + fprintf(stderr, "ok\n"); + + fprintf(stderr, "%d: peek with nothing to read...", ++testNo); + if (ringbuffer_peek(&ringBuffer, chunks, 10)) + { + fprintf(stderr, "peek returns some chunks\n"); + return -1; + } + fprintf(stderr, "ok\n"); + + fprintf(stderr, "%d: ensure_linear_write / read() shouldn't grow...", ++testNo); + for (i = 0; i < 1000; i++) + { + rb_ptr = ringbuffer_ensure_linear_write(&ringBuffer, 50); + if (!rb_ptr) + { + fprintf(stderr, "ringbuffer_ensure_linear_write() error\n"); + return -1; + } + + memcpy(rb_ptr, tmpBuf, 50); + + if (!ringbuffer_commit_written_bytes(&ringBuffer, 50)) + { + fprintf(stderr, "ringbuffer_commit_written_bytes() error, i=%d\n", i); + return -1; + } + + //ringbuffer_commit_read_bytes(&ringBuffer, 25); + } + + for (i = 0; i < 1000; i++) + ringbuffer_commit_read_bytes(&ringBuffer, 25); + + for (i = 0; i < 1000; i++) + ringbuffer_commit_read_bytes(&ringBuffer, 25); + + + if (ringbuffer_capacity(&ringBuffer) != 10) + { + fprintf(stderr, "not the expected capacity, have %d and expects 10\n", ringbuffer_capacity(&ringBuffer)); + return -1; + } + fprintf(stderr, "ok\n"); + + + fprintf(stderr, "%d: free size is correctly computed...", ++testNo); + for (i = 0; i < 1000; i++) + { + ringbuffer_ensure_linear_write(&ringBuffer, 50); + if (!ringbuffer_commit_written_bytes(&ringBuffer, 50)) + { + fprintf(stderr, "ringbuffer_commit_written_bytes() error, i=%d\n", i); + return -1; + } + } + ringbuffer_commit_read_bytes(&ringBuffer, 50 * 1000); + fprintf(stderr, "ok\n"); + + ringbuffer_destroy(&ringBuffer); + + fprintf(stderr, "%d: specific overlaps test...", ++testNo); + if (!test_overlaps()) + { + fprintf(stderr, "ko\n", i); + return -1; + } + fprintf(stderr, "ok\n"); + return 0; +} + + + + From 9c18ae5bee4c8a508c22484defa81419ab47560c Mon Sep 17 00:00:00 2001 From: Hardening Date: Wed, 21 May 2014 10:19:13 +0200 Subject: [PATCH 02/11] Print function name when emiting an error --- libfreerdp/core/license.c | 8 ++++---- libfreerdp/core/update.c | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/libfreerdp/core/license.c b/libfreerdp/core/license.c index dc96014a9..4d3a53a0f 100644 --- a/libfreerdp/core/license.c +++ b/libfreerdp/core/license.c @@ -241,7 +241,7 @@ int license_recv(rdpLicense* license, wStream* s) if (!rdp_read_header(license->rdp, s, &length, &channelId)) { - fprintf(stderr, "Incorrect RDP header.\n"); + fprintf(stderr, "%s: Incorrect RDP header.\n", __FUNCTION__); return -1; } @@ -252,7 +252,7 @@ int license_recv(rdpLicense* license, wStream* s) { if (!rdp_decrypt(license->rdp, s, length - 4, securityFlags)) { - fprintf(stderr, "rdp_decrypt failed\n"); + fprintf(stderr, "%s: rdp_decrypt failed\n", __FUNCTION__); return -1; } } @@ -268,7 +268,7 @@ int license_recv(rdpLicense* license, wStream* s) if (status < 0) { - fprintf(stderr, "Unexpected license packet.\n"); + fprintf(stderr, "%s: unexpected license packet.\n", __FUNCTION__); return status; } @@ -308,7 +308,7 @@ int license_recv(rdpLicense* license, wStream* s) break; default: - fprintf(stderr, "invalid bMsgType:%d\n", bMsgType); + fprintf(stderr, "%s: invalid bMsgType:%d\n", __FUNCTION__, bMsgType); return FALSE; } diff --git a/libfreerdp/core/update.c b/libfreerdp/core/update.c index b322fe753..15c5b9cf5 100644 --- a/libfreerdp/core/update.c +++ b/libfreerdp/core/update.c @@ -544,7 +544,7 @@ static void update_end_paint(rdpContext* context) if (update->numberOrders > 0) { - printf("Sending %d orders\n", update->numberOrders); + fprintf(stderr, "%s: sending %d orders\n", __FUNCTION__, update->numberOrders); fastpath_send_update_pdu(context->rdp->fastpath, FASTPATH_UPDATETYPE_ORDERS, s); } From 0376dcd065a5afc87ac9d4b2ef702b12f6fe09cd Mon Sep 17 00:00:00 2001 From: Hardening Date: Wed, 21 May 2014 15:54:25 +0200 Subject: [PATCH 03/11] Fix OOM situation --- libfreerdp/core/mcs.c | 33 ++++++++++++++++++--------------- 1 file changed, 18 insertions(+), 15 deletions(-) diff --git a/libfreerdp/core/mcs.c b/libfreerdp/core/mcs.c index 31620b8c8..8202d40b2 100644 --- a/libfreerdp/core/mcs.c +++ b/libfreerdp/core/mcs.c @@ -1056,26 +1056,29 @@ rdpMcs* mcs_new(rdpTransport* transport) { rdpMcs* mcs; - mcs = (rdpMcs*) malloc(sizeof(rdpMcs)); + mcs = (rdpMcs *)calloc(1, sizeof(rdpMcs)); + if (!mcs) + return NULL; - if (mcs) - { - ZeroMemory(mcs, sizeof(rdpMcs)); + mcs->transport = transport; + mcs->settings = transport->settings; - mcs->transport = transport; - mcs->settings = transport->settings; + mcs_init_domain_parameters(&mcs->targetParameters, 34, 2, 0, 0xFFFF); + mcs_init_domain_parameters(&mcs->minimumParameters, 1, 1, 1, 0x420); + mcs_init_domain_parameters(&mcs->maximumParameters, 0xFFFF, 0xFC17, 0xFFFF, 0xFFFF); + mcs_init_domain_parameters(&mcs->domainParameters, 0, 0, 0, 0xFFFF); - mcs_init_domain_parameters(&mcs->targetParameters, 34, 2, 0, 0xFFFF); - mcs_init_domain_parameters(&mcs->minimumParameters, 1, 1, 1, 0x420); - mcs_init_domain_parameters(&mcs->maximumParameters, 0xFFFF, 0xFC17, 0xFFFF, 0xFFFF); - mcs_init_domain_parameters(&mcs->domainParameters, 0, 0, 0, 0xFFFF); - - mcs->channelCount = 0; - mcs->channelMaxCount = CHANNEL_MAX_COUNT; - mcs->channels = (rdpMcsChannel*) calloc(mcs->channelMaxCount, sizeof(rdpMcsChannel)); - } + mcs->channelCount = 0; + mcs->channelMaxCount = CHANNEL_MAX_COUNT; + mcs->channels = (rdpMcsChannel *)calloc(mcs->channelMaxCount, sizeof(rdpMcsChannel)); + if (!mcs->channels) + goto out_free; return mcs; + +out_free: + free(mcs); + return NULL; } /** From dd6d82955087e1b53c8d9f1a7a5b252c8a545210 Mon Sep 17 00:00:00 2001 From: Hardening Date: Wed, 21 May 2014 17:32:14 +0200 Subject: [PATCH 04/11] Allow transport_write calls to be non-blocking This big patch allows to have non-blocking writes. To achieve this, it slightly changes the way transport is handled. The misc transport layers are handled with OpenSSL BIOs. In the chain we insert a bufferedBIO that will bufferize write calls that couldn't be honored. For an access with Tls security the BIO chain would look like this: FreeRdp Code ===> SSL bio ===> buffered BIO ===> socket BIO The buffered BIO will store bytes that couldn't be send because of blocking write calls. This patch also rework TSG so that it would look like this in the case of SSL security with TSG: (TSG in) > SSL BIO => buffered BIO ==> socket BIO / FreeRdp => SSL BIO => TSG BIO \ > SSL BIO => buffered BIO ==> socket BIO (TSG out) So from the FreeRDP point of view sending something is only BIO_writing on the frontBio (last BIO on the left). --- include/freerdp/crypto/tls.h | 13 +- include/freerdp/peer.h | 7 + include/freerdp/settings.h | 3 +- libfreerdp/core/gateway/http.c | 35 +- libfreerdp/core/gateway/http.h | 3 +- libfreerdp/core/gateway/ncacn_http.c | 33 +- libfreerdp/core/gateway/rpc.c | 175 +++--- libfreerdp/core/gateway/rpc.h | 4 +- libfreerdp/core/gateway/rpc_bind.c | 28 +- libfreerdp/core/gateway/rpc_client.c | 381 ++++++------ libfreerdp/core/gateway/rts.c | 84 ++- libfreerdp/core/gateway/rts_signature.c | 30 +- libfreerdp/core/gateway/tsg.c | 135 +++-- libfreerdp/core/peer.c | 37 +- libfreerdp/core/settings.c | 2 + libfreerdp/core/tcp.c | 308 ++++++++-- libfreerdp/core/tcp.h | 11 + libfreerdp/core/transport.c | 548 ++++++++++------- libfreerdp/core/transport.h | 4 + libfreerdp/crypto/tls.c | 750 ++++++++++-------------- 20 files changed, 1478 insertions(+), 1113 deletions(-) diff --git a/include/freerdp/crypto/tls.h b/include/freerdp/crypto/tls.h index bf5521300..180007e5e 100644 --- a/include/freerdp/crypto/tls.h +++ b/include/freerdp/crypto/tls.h @@ -70,7 +70,6 @@ struct rdp_tls SSL* ssl; BIO* bio; void* tsg; - int sockfd; SSL_CTX* ctx; BYTE* PublicKey; BIO_METHOD* methods; @@ -84,17 +83,11 @@ struct rdp_tls int alertDescription; }; -FREERDP_API int tls_connect(rdpTls* tls); -FREERDP_API BOOL tls_accept(rdpTls* tls, const char* cert_file, const char* privatekey_file); +FREERDP_API int tls_connect(rdpTls* tls, BIO *underlying); +FREERDP_API BOOL tls_accept(rdpTls* tls, BIO *underlying, const char* cert_file, const char* privatekey_file); FREERDP_API BOOL tls_disconnect(rdpTls* tls); -FREERDP_API int tls_read(rdpTls* tls, BYTE* data, int length); -FREERDP_API int tls_write(rdpTls* tls, BYTE* data, int length); - -FREERDP_API int tls_write_all(rdpTls* tls, BYTE* data, int length); - -FREERDP_API int tls_wait_read(rdpTls* tls); -FREERDP_API int tls_wait_write(rdpTls* tls); +FREERDP_API int tls_write_all(rdpTls* tls, const BYTE* data, int length); FREERDP_API int tls_set_alert_code(rdpTls* tls, int level, int description); diff --git a/include/freerdp/peer.h b/include/freerdp/peer.h index c89d37a07..4fbe75bfc 100644 --- a/include/freerdp/peer.h +++ b/include/freerdp/peer.h @@ -34,7 +34,10 @@ typedef void (*psPeerContextFree)(freerdp_peer* client, rdpContext* context); typedef BOOL (*psPeerInitialize)(freerdp_peer* client); typedef BOOL (*psPeerGetFileDescriptor)(freerdp_peer* client, void** rfds, int* rcount); typedef HANDLE (*psPeerGetEventHandle)(freerdp_peer* client); +typedef HANDLE (*psPeerGetReceiveEventHandle)(freerdp_peer* client); typedef BOOL (*psPeerCheckFileDescriptor)(freerdp_peer* client); +typedef BOOL (*psPeerIsWriteBlocked)(freerdp_peer* client); +typedef int (*psPeerDrainOutputBuffer)(freerdp_peer* client); typedef BOOL (*psPeerClose)(freerdp_peer* client); typedef void (*psPeerDisconnect)(freerdp_peer* client); typedef BOOL (*psPeerCapabilities)(freerdp_peer* client); @@ -62,6 +65,7 @@ struct rdp_freerdp_peer psPeerInitialize Initialize; psPeerGetFileDescriptor GetFileDescriptor; psPeerGetEventHandle GetEventHandle; + psPeerGetReceiveEventHandle GetReceiveEventHandle; psPeerCheckFileDescriptor CheckFileDescriptor; psPeerClose Close; psPeerDisconnect Disconnect; @@ -81,6 +85,9 @@ struct rdp_freerdp_peer BOOL activated; BOOL authenticated; SEC_WINNT_AUTH_IDENTITY identity; + + psPeerIsWriteBlocked IsWriteBlocked; + psPeerDrainOutputBuffer DrainOutputBuffer; }; #ifdef __cplusplus diff --git a/include/freerdp/settings.h b/include/freerdp/settings.h index 6e921eb21..dab787581 100644 --- a/include/freerdp/settings.h +++ b/include/freerdp/settings.h @@ -798,7 +798,8 @@ struct rdp_settings ALIGN64 char* Password; /* 22 */ ALIGN64 char* Domain; /* 23 */ ALIGN64 char* PasswordHash; /* 24 */ - UINT64 padding0064[64 - 25]; /* 25 */ + ALIGN64 BOOL WaitForOutputBufferFlush; /* 25 */ + UINT64 padding0064[64 - 26]; /* 26 */ UINT64 padding0128[128 - 64]; /* 64 */ /** diff --git a/libfreerdp/core/gateway/http.c b/libfreerdp/core/gateway/http.c index c9f33f01a..610b23091 100644 --- a/libfreerdp/core/gateway/http.c +++ b/libfreerdp/core/gateway/http.c @@ -26,6 +26,10 @@ #include #include +#ifdef HAVE_VALGRIND_MEMCHECK_H +#include +#endif + #include "http.h" HttpContext* http_context_new() @@ -472,7 +476,7 @@ HttpResponse* http_response_recv(rdpTls* tls) nbytes = 0; length = 10000; content = NULL; - buffer = malloc(length); + buffer = calloc(length, 1); if (!buffer) return NULL; @@ -487,14 +491,20 @@ HttpResponse* http_response_recv(rdpTls* tls) { while (nbytes < 5) { - status = tls_read(tls, p, length - nbytes); + status = BIO_read(tls->bio, p, length - nbytes); - if (status < 0) - goto out_error; + if (status <= 0) + { + if (!BIO_should_retry(tls->bio)) + goto out_error; - if (!status) + USleep(100); continue; + } +#ifdef HAVE_VALGRIND_MEMCHECK_H + VALGRIND_MAKE_MEM_DEFINED(p, status); +#endif nbytes += status; p = (BYTE*) &buffer[nbytes]; } @@ -503,7 +513,7 @@ HttpResponse* http_response_recv(rdpTls* tls) if (!header_end) { - fprintf(stderr, "http_response_recv: invalid response:\n"); + fprintf(stderr, "%s: invalid response:\n", __FUNCTION__); winpr_HexDump(buffer, status); goto out_error; } @@ -517,7 +527,7 @@ HttpResponse* http_response_recv(rdpTls* tls) header_end[0] = '\0'; header_end[1] = '\0'; - content = &header_end[2]; + content = header_end + 2; count = 0; line = (char*) buffer; @@ -552,11 +562,14 @@ HttpResponse* http_response_recv(rdpTls* tls) if (!http_response_parse_header(http_response)) goto out_error; - if (http_response->ContentLength > 0) + http_response->bodyLen = nbytes - (content - (char *)buffer); + if (http_response->bodyLen > 0) { - http_response->Content = _strdup(content); - if (!http_response->Content) + http_response->BodyContent = (BYTE *)malloc(http_response->bodyLen); + if (!http_response->BodyContent) goto out_error; + + CopyMemory(http_response->BodyContent, content, http_response->bodyLen); } break; @@ -627,7 +640,7 @@ void http_response_free(HttpResponse* http_response) ListDictionary_Free(http_response->Authenticates); if (http_response->ContentLength > 0) - free(http_response->Content); + free(http_response->BodyContent); free(http_response); } diff --git a/libfreerdp/core/gateway/http.h b/libfreerdp/core/gateway/http.h index 748b45a36..ded9ba214 100644 --- a/libfreerdp/core/gateway/http.h +++ b/libfreerdp/core/gateway/http.h @@ -84,7 +84,8 @@ struct _http_response wListDictionary *Authenticates; int ContentLength; - char* Content; + BYTE *BodyContent; + int bodyLen; }; void http_response_print(HttpResponse* http_response); diff --git a/libfreerdp/core/gateway/ncacn_http.c b/libfreerdp/core/gateway/ncacn_http.c index 270dafbcf..b5beff4b2 100644 --- a/libfreerdp/core/gateway/ncacn_http.c +++ b/libfreerdp/core/gateway/ncacn_http.c @@ -98,6 +98,8 @@ int rpc_ncacn_http_recv_in_channel_response(rdpRpc* rpc) rdpNtlm* ntlm = rpc->NtlmHttpIn->ntlm; http_response = http_response_recv(rpc->TlsIn); + if (!http_response) + return -1; if (ListDictionary_Contains(http_response->Authenticates, "NTLM")) { @@ -105,14 +107,12 @@ int rpc_ncacn_http_recv_in_channel_response(rdpRpc* rpc) if (!token64) goto out; - ntlm_token_data = NULL; crypto_base64_decode(token64, strlen(token64), &ntlm_token_data, &ntlm_token_length); } +out: ntlm->inputBuffer[0].pvBuffer = ntlm_token_data; ntlm->inputBuffer[0].cbBuffer = ntlm_token_length; - -out: http_response_free(http_response); return 0; @@ -123,25 +123,19 @@ int rpc_ncacn_http_ntlm_init(rdpRpc* rpc, TSG_CHANNEL channel) rdpNtlm* ntlm = NULL; rdpSettings* settings = rpc->settings; freerdp* instance = (freerdp*) rpc->settings->instance; - BOOL promptPassword = FALSE; if (channel == TSG_CHANNEL_IN) ntlm = rpc->NtlmHttpIn->ntlm; else if (channel == TSG_CHANNEL_OUT) ntlm = rpc->NtlmHttpOut->ntlm; - if ((!settings->GatewayPassword) || (!settings->GatewayUsername) - || (!strlen(settings->GatewayPassword)) || (!strlen(settings->GatewayUsername))) - { - promptPassword = TRUE; - } - - if (promptPassword) + if (!settings->GatewayPassword || !settings->GatewayUsername || + !strlen(settings->GatewayPassword) || !strlen(settings->GatewayUsername)) { if (instance->GatewayAuthenticate) { - BOOL proceed = instance->GatewayAuthenticate(instance, - &settings->GatewayUsername, &settings->GatewayPassword, &settings->GatewayDomain); + BOOL proceed = instance->GatewayAuthenticate(instance, &settings->GatewayUsername, + &settings->GatewayPassword, &settings->GatewayDomain); if (!proceed) { @@ -240,12 +234,10 @@ int rpc_ncacn_http_recv_out_channel_response(rdpRpc* rpc) char *token64 = ListDictionary_GetItemValue(http_response->Authenticates, "NTLM"); crypto_base64_decode(token64, strlen(token64), &ntlm_token_data, &ntlm_token_length); } - ntlm->inputBuffer[0].pvBuffer = ntlm_token_data; ntlm->inputBuffer[0].cbBuffer = ntlm_token_length; - + http_response_free(http_response); - return 0; } @@ -259,15 +251,12 @@ BOOL rpc_ntlm_http_out_connect(rdpRpc* rpc) success = TRUE; /* Send OUT Channel Request */ - rpc_ncacn_http_send_out_channel_request(rpc); /* Receive OUT Channel Response */ - rpc_ncacn_http_recv_out_channel_response(rpc); /* Send OUT Channel Request */ - rpc_ncacn_http_send_out_channel_request(rpc); ntlm_client_uninit(ntlm); @@ -296,13 +285,11 @@ void rpc_ntlm_http_init_channel(rdpRpc* rpc, rdpNtlmHttp* ntlm_http, TSG_CHANNEL if (channel == TSG_CHANNEL_IN) { - http_context_set_pragma(ntlm_http->context, - "ResourceTypeUuid=44e265dd-7daf-42cd-8560-3cdb6e7a2729"); + http_context_set_pragma(ntlm_http->context, "ResourceTypeUuid=44e265dd-7daf-42cd-8560-3cdb6e7a2729"); } else if (channel == TSG_CHANNEL_OUT) { - http_context_set_pragma(ntlm_http->context, - "ResourceTypeUuid=44e265dd-7daf-42cd-8560-3cdb6e7a2729" ", " + http_context_set_pragma(ntlm_http->context, "ResourceTypeUuid=44e265dd-7daf-42cd-8560-3cdb6e7a2729, " "SessionId=fbd9c34f-397d-471d-a109-1b08cc554624"); } } diff --git a/libfreerdp/core/gateway/rpc.c b/libfreerdp/core/gateway/rpc.c index c91a71071..2432ab06c 100644 --- a/libfreerdp/core/gateway/rpc.c +++ b/libfreerdp/core/gateway/rpc.c @@ -33,6 +33,11 @@ #include #include +#include + +#ifdef HAVE_VALGRIND_MEMCHECK_H +#include +#endif #include "http.h" #include "ntlm.h" @@ -235,80 +240,77 @@ BOOL rpc_get_stub_data_info(rdpRpc* rpc, BYTE* buffer, UINT32* offset, UINT32* l { UINT32 alloc_hint = 0; rpcconn_hdr_t* header; + UINT32 frag_length; + UINT32 auth_length; + UINT32 auth_pad_length; + UINT32 sec_trailer_offset; + rpc_sec_trailer* sec_trailer; *offset = RPC_COMMON_FIELDS_LENGTH; header = ((rpcconn_hdr_t*) buffer); - if (header->common.ptype == PTYPE_RESPONSE) + switch (header->common.ptype) { - *offset += 8; - rpc_offset_align(offset, 8); - alloc_hint = header->response.alloc_hint; - } - else if (header->common.ptype == PTYPE_REQUEST) - { - *offset += 4; - rpc_offset_align(offset, 8); - alloc_hint = header->request.alloc_hint; - } - else if (header->common.ptype == PTYPE_RTS) - { - *offset += 4; - } - else - { - return FALSE; + case PTYPE_RESPONSE: + *offset += 8; + rpc_offset_align(offset, 8); + alloc_hint = header->response.alloc_hint; + break; + case PTYPE_REQUEST: + *offset += 4; + rpc_offset_align(offset, 8); + alloc_hint = header->request.alloc_hint; + break; + case PTYPE_RTS: + *offset += 4; + break; + default: + fprintf(stderr, "%s: unknown ptype=0x%x\n", __FUNCTION__, header->common.ptype); + return FALSE; } - if (length) + if (!length) + return TRUE; + + if (header->common.ptype == PTYPE_REQUEST) { - if (header->common.ptype == PTYPE_REQUEST) - { - UINT32 sec_trailer_offset; + UINT32 sec_trailer_offset; - sec_trailer_offset = header->common.frag_length - header->common.auth_length - 8; - *length = sec_trailer_offset - *offset; - } - else - { - UINT32 frag_length; - UINT32 auth_length; - UINT32 auth_pad_length; - UINT32 sec_trailer_offset; - rpc_sec_trailer* sec_trailer; + sec_trailer_offset = header->common.frag_length - header->common.auth_length - 8; + *length = sec_trailer_offset - *offset; + return TRUE; + } - frag_length = header->common.frag_length; - auth_length = header->common.auth_length; - sec_trailer_offset = frag_length - auth_length - 8; - sec_trailer = (rpc_sec_trailer*) &buffer[sec_trailer_offset]; - auth_pad_length = sec_trailer->auth_pad_length; + frag_length = header->common.frag_length; + auth_length = header->common.auth_length; + + sec_trailer_offset = frag_length - auth_length - 8; + sec_trailer = (rpc_sec_trailer*) &buffer[sec_trailer_offset]; + auth_pad_length = sec_trailer->auth_pad_length; #if 0 - fprintf(stderr, "sec_trailer: type: %d level: %d pad_length: %d reserved: %d context_id: %d\n", - sec_trailer->auth_type, - sec_trailer->auth_level, - sec_trailer->auth_pad_length, - sec_trailer->auth_reserved, - sec_trailer->auth_context_id); + fprintf(stderr, "sec_trailer: type: %d level: %d pad_length: %d reserved: %d context_id: %d\n", + sec_trailer->auth_type, + sec_trailer->auth_level, + sec_trailer->auth_pad_length, + sec_trailer->auth_reserved, + sec_trailer->auth_context_id); #endif - /** - * According to [MS-RPCE], auth_pad_length is the number of padding - * octets used to 4-byte align the security trailer, but in practice - * we get values up to 15, which indicates 16-byte alignment. - */ + /** + * According to [MS-RPCE], auth_pad_length is the number of padding + * octets used to 4-byte align the security trailer, but in practice + * we get values up to 15, which indicates 16-byte alignment. + */ - if ((frag_length - (sec_trailer_offset + 8)) != auth_length) - { - fprintf(stderr, "invalid auth_length: actual: %d, expected: %d\n", auth_length, - (frag_length - (sec_trailer_offset + 8))); - } - - *length = frag_length - auth_length - 24 - 8 - auth_pad_length; - } + if ((frag_length - (sec_trailer_offset + 8)) != auth_length) + { + fprintf(stderr, "invalid auth_length: actual: %d, expected: %d\n", auth_length, + (frag_length - (sec_trailer_offset + 8))); } + *length = frag_length - auth_length - 24 - 8 - auth_pad_length; return TRUE; } @@ -316,12 +318,23 @@ int rpc_out_read(rdpRpc* rpc, BYTE* data, int length) { int status; - status = tls_read(rpc->TlsOut, data, length); + status = BIO_read(rpc->TlsOut->bio, data, length); + /* fprintf(stderr, "%s: length=%d => status=%d shouldRetry=%d\n", __FUNCTION__, length, + * status, BIO_should_retry(rpc->TlsOut->bio)); */ + if (status > 0) { +#ifdef HAVE_VALGRIND_MEMCHECK_H + VALGRIND_MAKE_MEM_DEFINED(data, status); +#endif + return status; + } - return status; + if (BIO_should_retry(rpc->TlsOut->bio)) + return 0; + + return -1; } -int rpc_out_write(rdpRpc* rpc, BYTE* data, int length) +int rpc_out_write(rdpRpc* rpc, const BYTE* data, int length) { int status; @@ -330,7 +343,7 @@ int rpc_out_write(rdpRpc* rpc, BYTE* data, int length) return status; } -int rpc_in_write(rdpRpc* rpc, BYTE* data, int length) +int rpc_in_write(rdpRpc* rpc, const BYTE* data, int length) { int status; @@ -360,20 +373,21 @@ int rpc_write(rdpRpc* rpc, BYTE* data, int length, UINT16 opnum) ntlm = rpc->ntlm; - if ((!ntlm) || (!ntlm->table)) + if (!ntlm || !ntlm->table) { - fprintf(stderr, "rpc_write: invalid ntlm context\n"); + fprintf(stderr, "%s: invalid ntlm context\n", __FUNCTION__); return -1; } if (ntlm->table->QueryContextAttributes(&ntlm->context, SECPKG_ATTR_SIZES, &ntlm->ContextSizes) != SEC_E_OK) { - fprintf(stderr, "QueryContextAttributes SECPKG_ATTR_SIZES failure\n"); + fprintf(stderr, "%s: QueryContextAttributes SECPKG_ATTR_SIZES failure\n", __FUNCTION__); return -1; } - request_pdu = (rpcconn_request_hdr_t*) malloc(sizeof(rpcconn_request_hdr_t)); - ZeroMemory(request_pdu, sizeof(rpcconn_request_hdr_t)); + request_pdu = (rpcconn_request_hdr_t*) calloc(1, sizeof(rpcconn_request_hdr_t)); + if (!request_pdu) + return -1; rpc_pdu_header_init(rpc, (rpcconn_hdr_t*) request_pdu); @@ -386,7 +400,11 @@ int rpc_write(rdpRpc* rpc, BYTE* data, int length, UINT16 opnum) request_pdu->opnum = opnum; clientCall = rpc_client_call_new(request_pdu->call_id, request_pdu->opnum); - ArrayList_Add(rpc->client->ClientCallList, clientCall); + if (!clientCall) + goto out_free_pdu; + + if (ArrayList_Add(rpc->client->ClientCallList, clientCall) < 0) + goto out_free_clientCall; if (request_pdu->opnum == TsProxySetupReceivePipeOpnum) rpc->PipeCallId = request_pdu->call_id; @@ -407,8 +425,9 @@ int rpc_write(rdpRpc* rpc, BYTE* data, int length, UINT16 opnum) request_pdu->frag_length = offset; - buffer = (BYTE*) malloc(request_pdu->frag_length); - + buffer = (BYTE*) calloc(1, request_pdu->frag_length); + if (!buffer) + goto out_free_pdu; CopyMemory(buffer, request_pdu, 24); offset = 24; @@ -427,15 +446,15 @@ int rpc_write(rdpRpc* rpc, BYTE* data, int length, UINT16 opnum) Buffers[0].cbBuffer = offset; Buffers[1].cbBuffer = ntlm->ContextSizes.cbMaxSignature; - Buffers[1].pvBuffer = malloc(Buffers[1].cbBuffer); - ZeroMemory(Buffers[1].pvBuffer, Buffers[1].cbBuffer); + Buffers[1].pvBuffer = calloc(1, Buffers[1].cbBuffer); + if (!Buffers[1].pvBuffer) + return -1; Message.cBuffers = 2; Message.ulVersion = SECBUFFER_VERSION; Message.pBuffers = (PSecBuffer) &Buffers; encrypt_status = ntlm->table->EncryptMessage(&ntlm->context, 0, &Message, rpc->SendSeqNum++); - if (encrypt_status != SEC_E_OK) { fprintf(stderr, "EncryptMessage status: 0x%08X\n", encrypt_status); @@ -447,12 +466,18 @@ int rpc_write(rdpRpc* rpc, BYTE* data, int length, UINT16 opnum) offset += Buffers[1].cbBuffer; free(Buffers[1].pvBuffer); - if (rpc_send_enqueue_pdu(rpc, buffer, request_pdu->frag_length) != 0) + if (rpc_send_enqueue_pdu(rpc, buffer, request_pdu->frag_length) < 0) length = -1; free(request_pdu); return length; + +out_free_clientCall: + rpc_client_call_free(clientCall); +out_free_pdu: + free(request_pdu); + return -1; } BOOL rpc_connect(rdpRpc* rpc) @@ -592,13 +617,17 @@ rdpRpc* rpc_new(rdpTransport* transport) rpc->CallId = 2; - rpc_client_new(rpc); + if (rpc_client_new(rpc) < 0) + goto out_free_virtualConnectionCookieTable; rpc->client->SynchronousSend = TRUE; rpc->client->SynchronousReceive = TRUE; return rpc; +out_free_virtualConnectionCookieTable: + rpc_client_free(rpc); + ArrayList_Free(rpc->VirtualConnectionCookieTable); out_free_virtual_connection: rpc_client_virtual_connection_free(rpc->VirtualConnection); out_free_ntlm_http_out: diff --git a/libfreerdp/core/gateway/rpc.h b/libfreerdp/core/gateway/rpc.h index d10d665c7..c86a8618f 100644 --- a/libfreerdp/core/gateway/rpc.h +++ b/libfreerdp/core/gateway/rpc.h @@ -772,8 +772,8 @@ UINT32 rpc_offset_pad(UINT32* offset, UINT32 pad); int rpc_out_read(rdpRpc* rpc, BYTE* data, int length); -int rpc_out_write(rdpRpc* rpc, BYTE* data, int length); -int rpc_in_write(rdpRpc* rpc, BYTE* data, int length); +int rpc_out_write(rdpRpc* rpc, const BYTE* data, int length); +int rpc_in_write(rdpRpc* rpc, const BYTE* data, int length); BOOL rpc_get_stub_data_info(rdpRpc* rpc, BYTE* header, UINT32* offset, UINT32* length); diff --git a/libfreerdp/core/gateway/rpc_bind.c b/libfreerdp/core/gateway/rpc_bind.c index cf02a802a..ceae95159 100644 --- a/libfreerdp/core/gateway/rpc_bind.c +++ b/libfreerdp/core/gateway/rpc_bind.c @@ -103,6 +103,8 @@ int rpc_send_bind_pdu(rdpRpc* rpc) DEBUG_RPC("Sending bind PDU"); rpc->ntlm = ntlm_new(); + if (!rpc->ntlm) + return -1; if ((!settings->GatewayPassword) || (!settings->GatewayUsername) || (!strlen(settings->GatewayPassword)) || (!strlen(settings->GatewayUsername))) @@ -129,17 +131,22 @@ int rpc_send_bind_pdu(rdpRpc* rpc) settings->Username = _strdup(settings->GatewayUsername); settings->Domain = _strdup(settings->GatewayDomain); settings->Password = _strdup(settings->GatewayPassword); + + if (!settings->Username || !settings->Domain || settings->Password) + return -1; } } } - ntlm_client_init(rpc->ntlm, FALSE, settings->GatewayUsername, settings->GatewayDomain, settings->GatewayPassword, NULL); - ntlm_client_make_spn(rpc->ntlm, NULL, settings->GatewayHostname); + if (!ntlm_client_init(rpc->ntlm, FALSE, settings->GatewayUsername, settings->GatewayDomain, settings->GatewayPassword, NULL) || + !ntlm_client_make_spn(rpc->ntlm, NULL, settings->GatewayHostname) || + !ntlm_authenticate(rpc->ntlm) + ) + return -1; - ntlm_authenticate(rpc->ntlm); - - bind_pdu = (rpcconn_bind_hdr_t*) malloc(sizeof(rpcconn_bind_hdr_t)); - ZeroMemory(bind_pdu, sizeof(rpcconn_bind_hdr_t)); + bind_pdu = (rpcconn_bind_hdr_t*) calloc(1, sizeof(rpcconn_bind_hdr_t)); + if (!bind_pdu) + return -1; rpc_pdu_header_init(rpc, (rpcconn_hdr_t*) bind_pdu); @@ -159,6 +166,8 @@ int rpc_send_bind_pdu(rdpRpc* rpc) bind_pdu->p_context_elem.reserved2 = 0; bind_pdu->p_context_elem.p_cont_elem = malloc(sizeof(p_cont_elem_t) * bind_pdu->p_context_elem.n_context_elem); + if (!bind_pdu->p_context_elem.p_cont_elem) + return -1; p_cont_elem = &bind_pdu->p_context_elem.p_cont_elem[0]; @@ -196,6 +205,8 @@ int rpc_send_bind_pdu(rdpRpc* rpc) bind_pdu->frag_length = offset; buffer = (BYTE*) malloc(bind_pdu->frag_length); + if (!buffer) + return -1; CopyMemory(buffer, bind_pdu, 24); CopyMemory(&buffer[24], &bind_pdu->p_context_elem, 4); @@ -214,7 +225,10 @@ int rpc_send_bind_pdu(rdpRpc* rpc) length = bind_pdu->frag_length; clientCall = rpc_client_call_new(bind_pdu->call_id, 0); - ArrayList_Add(rpc->client->ClientCallList, clientCall); + if (!clientCall) + return -1; + if (ArrayList_Add(rpc->client->ClientCallList, clientCall) < 0) + return -1; if (rpc_send_enqueue_pdu(rpc, buffer, length) != 0) length = -1; diff --git a/libfreerdp/core/gateway/rpc_client.c b/libfreerdp/core/gateway/rpc_client.c index dff88b3e5..c3613f6be 100644 --- a/libfreerdp/core/gateway/rpc_client.c +++ b/libfreerdp/core/gateway/rpc_client.c @@ -34,9 +34,7 @@ #include #include "rpc_fault.h" - #include "rpc_client.h" - #include "../rdp.h" #define SYNCHRONOUS_TIMEOUT 5000 @@ -69,8 +67,15 @@ RPC_PDU* rpc_client_receive_pool_take(rdpRpc* rpc) if (!pdu) { - pdu = (RPC_PDU*) malloc(sizeof(RPC_PDU)); + pdu = (RPC_PDU *)malloc(sizeof(RPC_PDU)); + if (!pdu) + return NULL; pdu->s = Stream_New(NULL, rpc->max_recv_frag); + if (!pdu->s) + { + free(pdu); + return NULL; + } } pdu->CallId = 0; @@ -84,8 +89,7 @@ RPC_PDU* rpc_client_receive_pool_take(rdpRpc* rpc) int rpc_client_receive_pool_return(rdpRpc* rpc, RPC_PDU* pdu) { - Queue_Enqueue(rpc->client->ReceivePool, pdu); - return 0; + return Queue_Enqueue(rpc->client->ReceivePool, pdu) == TRUE ? 0 : -1; } int rpc_client_on_fragment_received_event(rdpRpc* rpc) @@ -97,7 +101,7 @@ int rpc_client_on_fragment_received_event(rdpRpc* rpc) rpcconn_hdr_t* header; freerdp* instance; - instance = (freerdp*) rpc->transport->settings->instance; + instance = (freerdp *)rpc->transport->settings->instance; if (!rpc->client->pdu) rpc->client->pdu = rpc_client_receive_pool_take(rpc); @@ -125,34 +129,29 @@ int rpc_client_on_fragment_received_event(rdpRpc* rpc) return 0; } - if (header->common.ptype == PTYPE_RTS) + switch (header->common.ptype) { - if (rpc->VirtualConnection->State >= VIRTUAL_CONNECTION_STATE_OPENED) - { - //fprintf(stderr, "Receiving Out-of-Sequence RTS PDU\n"); + case PTYPE_RTS: + if (rpc->VirtualConnection->State < VIRTUAL_CONNECTION_STATE_OPENED) + { + fprintf(stderr, "%s: warning: unhandled RTS PDU\n", __FUNCTION__); + return 0; + } + fprintf(stderr, "%s: Receiving Out-of-Sequence RTS PDU\n", __FUNCTION__); rts_recv_out_of_sequence_pdu(rpc, buffer, header->common.frag_length); - rpc_client_fragment_pool_return(rpc, fragment); - } - else - { - fprintf(stderr, "warning: unhandled RTS PDU\n"); - } + return 0; - return 0; - } - else if (header->common.ptype == PTYPE_FAULT) - { - rpc_recv_fault_pdu(header); - Queue_Enqueue(rpc->client->ReceiveQueue, NULL); - return -1; - } - - if (header->common.ptype != PTYPE_RESPONSE) - { - fprintf(stderr, "Unexpected RPC PDU type: %d\n", header->common.ptype); - Queue_Enqueue(rpc->client->ReceiveQueue, NULL); - return -1; + case PTYPE_FAULT: + rpc_recv_fault_pdu(header); + Queue_Enqueue(rpc->client->ReceiveQueue, NULL); + return -1; + case PTYPE_RESPONSE: + break; + default: + fprintf(stderr, "%s: unexpected RPC PDU type %d\n", __FUNCTION__, header->common.ptype); + Queue_Enqueue(rpc->client->ReceiveQueue, NULL); + return -1; } rpc->VirtualConnection->DefaultOutChannel->BytesReceived += header->common.frag_length; @@ -160,7 +159,7 @@ int rpc_client_on_fragment_received_event(rdpRpc* rpc) if (!rpc_get_stub_data_info(rpc, buffer, &StubOffset, &StubLength)) { - fprintf(stderr, "rpc_recv_pdu_fragment: expected stub\n"); + fprintf(stderr, "%s: expected stub\n", __FUNCTION__); Queue_Enqueue(rpc->client->ReceiveQueue, NULL); return -1; } @@ -196,7 +195,7 @@ int rpc_client_on_fragment_received_event(rdpRpc* rpc) if (rpc->StubCallId != header->common.call_id) { - fprintf(stderr, "invalid call_id: actual: %d, expected: %d, frag_count: %d\n", + fprintf(stderr, "%s: invalid call_id: actual: %d, expected: %d, frag_count: %d\n", __FUNCTION__, rpc->StubCallId, header->common.call_id, rpc->StubFragCount); } @@ -243,27 +242,34 @@ int rpc_client_on_read_event(rdpRpc* rpc) int status = -1; rpcconn_common_hdr_t* header; - if (!rpc->client->RecvFrag) - rpc->client->RecvFrag = rpc_client_fragment_pool_take(rpc); - - position = Stream_GetPosition(rpc->client->RecvFrag); - - if (Stream_GetPosition(rpc->client->RecvFrag) < RPC_COMMON_FIELDS_LENGTH) + while (1) { - status = rpc_out_read(rpc, Stream_Pointer(rpc->client->RecvFrag), - RPC_COMMON_FIELDS_LENGTH - Stream_GetPosition(rpc->client->RecvFrag)); + if (!rpc->client->RecvFrag) + rpc->client->RecvFrag = rpc_client_fragment_pool_take(rpc); - if (status < 0) + position = Stream_GetPosition(rpc->client->RecvFrag); + + while (Stream_GetPosition(rpc->client->RecvFrag) < RPC_COMMON_FIELDS_LENGTH) { - fprintf(stderr, "rpc_client_frag_read: error reading header\n"); - return -1; + status = rpc_out_read(rpc, Stream_Pointer(rpc->client->RecvFrag), + RPC_COMMON_FIELDS_LENGTH - Stream_GetPosition(rpc->client->RecvFrag)); + + if (status < 0) + { + fprintf(stderr, "rpc_client_frag_read: error reading header\n"); + return -1; + } + + if (!status) + return 0; + + Stream_Seek(rpc->client->RecvFrag, status); } - Stream_Seek(rpc->client->RecvFrag, status); - } + if (Stream_GetPosition(rpc->client->RecvFrag) < RPC_COMMON_FIELDS_LENGTH) + return status; + - if (Stream_GetPosition(rpc->client->RecvFrag) >= RPC_COMMON_FIELDS_LENGTH) - { header = (rpcconn_common_hdr_t*) Stream_Buffer(rpc->client->RecvFrag); if (header->frag_length > rpc->max_recv_frag) @@ -274,45 +280,44 @@ int rpc_client_on_read_event(rdpRpc* rpc) return -1; } - if (Stream_GetPosition(rpc->client->RecvFrag) < header->frag_length) + while (Stream_GetPosition(rpc->client->RecvFrag) < header->frag_length) { status = rpc_out_read(rpc, Stream_Pointer(rpc->client->RecvFrag), header->frag_length - Stream_GetPosition(rpc->client->RecvFrag)); if (status < 0) { - fprintf(stderr, "rpc_client_frag_read: error reading fragment body\n"); + fprintf(stderr, "%s: error reading fragment body\n", __FUNCTION__); return -1; } + if (!status) + return 0; + Stream_Seek(rpc->client->RecvFrag, status); } - } - else - { - return status; - } - if (status < 0) - return -1; - - status = Stream_GetPosition(rpc->client->RecvFrag) - position; - - if (Stream_GetPosition(rpc->client->RecvFrag) >= header->frag_length) - { - /* complete fragment received */ - - Stream_Length(rpc->client->RecvFrag) = Stream_GetPosition(rpc->client->RecvFrag); - Stream_SetPosition(rpc->client->RecvFrag, 0); - - Queue_Enqueue(rpc->client->FragmentQueue, rpc->client->RecvFrag); - rpc->client->RecvFrag = NULL; - - if (rpc_client_on_fragment_received_event(rpc) < 0) + if (status < 0) return -1; + + status = Stream_GetPosition(rpc->client->RecvFrag) - position; + + if (Stream_GetPosition(rpc->client->RecvFrag) >= header->frag_length) + { + /* complete fragment received */ + + Stream_Length(rpc->client->RecvFrag) = Stream_GetPosition(rpc->client->RecvFrag); + Stream_SetPosition(rpc->client->RecvFrag, 0); + + Queue_Enqueue(rpc->client->FragmentQueue, rpc->client->RecvFrag); + rpc->client->RecvFrag = NULL; + + if (rpc_client_on_fragment_received_event(rpc) < 0) + return -1; + } } - return status; + return 0; } /** @@ -349,13 +354,12 @@ RpcClientCall* rpc_client_call_new(UINT32 CallId, UINT32 OpNum) RpcClientCall* clientCall; clientCall = (RpcClientCall*) malloc(sizeof(RpcClientCall)); + if (!clientCall) + return NULL; - if (clientCall) - { - clientCall->CallId = CallId; - clientCall->OpNum = OpNum; - clientCall->State = RPC_CLIENT_CALL_STATE_SEND_PDUS; - } + clientCall->CallId = CallId; + clientCall->OpNum = OpNum; + clientCall->State = RPC_CLIENT_CALL_STATE_SEND_PDUS; return clientCall; } @@ -371,16 +375,22 @@ int rpc_send_enqueue_pdu(rdpRpc* rpc, BYTE* buffer, UINT32 length) int status; pdu = (RPC_PDU*) malloc(sizeof(RPC_PDU)); - pdu->s = Stream_New(buffer, length); + if (!pdu) + return -1; - Queue_Enqueue(rpc->client->SendQueue, pdu); + pdu->s = Stream_New(buffer, length); + if (!pdu->s) + goto out_free; + + if (!Queue_Enqueue(rpc->client->SendQueue, pdu)) + goto out_free_stream; if (rpc->client->SynchronousSend) { status = WaitForSingleObject(rpc->client->PduSentEvent, SYNCHRONOUS_TIMEOUT); if (status == WAIT_TIMEOUT) { - fprintf(stderr, "rpc_send_enqueue_pdu: timed out waiting for pdu sent event\n"); + fprintf(stderr, "%s: timed out waiting for pdu sent event %p\n", __FUNCTION__, rpc->client->PduSentEvent); return -1; } @@ -388,6 +398,12 @@ int rpc_send_enqueue_pdu(rdpRpc* rpc, BYTE* buffer, UINT32 length) } return 0; + +out_free_stream: + Stream_Free(pdu->s, TRUE); +out_free: + free(pdu); + return -1; } int rpc_send_dequeue_pdu(rdpRpc* rpc) @@ -396,13 +412,14 @@ int rpc_send_dequeue_pdu(rdpRpc* rpc) RPC_PDU* pdu; RpcClientCall* clientCall; rpcconn_common_hdr_t* header; + RpcInChannel *inChannel; pdu = (RPC_PDU*) Queue_Dequeue(rpc->client->SendQueue); - if (!pdu) return 0; - WaitForSingleObject(rpc->VirtualConnection->DefaultInChannel->Mutex, INFINITE); + inChannel = rpc->VirtualConnection->DefaultInChannel; + WaitForSingleObject(inChannel->Mutex, INFINITE); status = rpc_in_write(rpc, Stream_Buffer(pdu->s), Stream_Length(pdu->s)); @@ -410,7 +427,7 @@ int rpc_send_dequeue_pdu(rdpRpc* rpc) clientCall = rpc_client_call_find_by_id(rpc, header->call_id); clientCall->State = RPC_CLIENT_CALL_STATE_DISPATCHED; - ReleaseMutex(rpc->VirtualConnection->DefaultInChannel->Mutex); + ReleaseMutex(inChannel->Mutex); /* * This protocol specifies that only RPC PDUs are subject to the flow control abstract @@ -421,8 +438,8 @@ int rpc_send_dequeue_pdu(rdpRpc* rpc) if (header->ptype == PTYPE_REQUEST) { - rpc->VirtualConnection->DefaultInChannel->BytesSent += status; - rpc->VirtualConnection->DefaultInChannel->SenderAvailableWindow -= status; + inChannel->BytesSent += status; + inChannel->SenderAvailableWindow -= status; } Stream_Free(pdu->s, TRUE); @@ -440,57 +457,48 @@ RPC_PDU* rpc_recv_dequeue_pdu(rdpRpc* rpc) DWORD dwMilliseconds; DWORD result; - pdu = NULL; - dwMilliseconds = rpc->client->SynchronousReceive ? SYNCHRONOUS_TIMEOUT : 0; + dwMilliseconds = rpc->client->SynchronousReceive ? SYNCHRONOUS_TIMEOUT * 4 : 0; result = WaitForSingleObject(Queue_Event(rpc->client->ReceiveQueue), dwMilliseconds); if (result == WAIT_TIMEOUT) { - fprintf(stderr, "rpc_recv_dequeue_pdu: timed out waiting for receive event\n"); + fprintf(stderr, "%s: timed out waiting for receive event\n", __FUNCTION__); return NULL; } - if (result == WAIT_OBJECT_0) - { - pdu = (RPC_PDU*) Queue_Dequeue(rpc->client->ReceiveQueue); + if (result != WAIT_OBJECT_0) + return NULL; + + pdu = (RPC_PDU *)Queue_Dequeue(rpc->client->ReceiveQueue); #ifdef WITH_DEBUG_TSG - if (pdu) - { - fprintf(stderr, "Receiving PDU (length: %d, CallId: %d)\n", pdu->s->length, pdu->CallId); - winpr_HexDump(Stream_Buffer(pdu->s), Stream_Length(pdu->s)); - fprintf(stderr, "\n"); - } -#endif - - return pdu; + if (pdu) + { + fprintf(stderr, "Receiving PDU (length: %d, CallId: %d)\n", pdu->s->length, pdu->CallId); + winpr_HexDump(Stream_Buffer(pdu->s), Stream_Length(pdu->s)); + fprintf(stderr, "\n"); } + else + { + fprintf(stderr, "Receiving a NULL PDU\n"); + } +#endif return pdu; } RPC_PDU* rpc_recv_peek_pdu(rdpRpc* rpc) { - RPC_PDU* pdu; DWORD dwMilliseconds; DWORD result; - pdu = NULL; dwMilliseconds = rpc->client->SynchronousReceive ? SYNCHRONOUS_TIMEOUT : 0; result = WaitForSingleObject(Queue_Event(rpc->client->ReceiveQueue), dwMilliseconds); - if (result == WAIT_TIMEOUT) - { + if (result != WAIT_OBJECT_0) return NULL; - } - if (result == WAIT_OBJECT_0) - { - pdu = (RPC_PDU*) Queue_Peek(rpc->client->ReceiveQueue); - return pdu; - } - - return pdu; + return (RPC_PDU *)Queue_Peek(rpc->client->ReceiveQueue); } static void* rpc_client_thread(void* arg) @@ -500,40 +508,52 @@ static void* rpc_client_thread(void* arg) DWORD nCount; HANDLE events[3]; HANDLE ReadEvent; + int fd; rpc = (rdpRpc*) arg; + fd = BIO_get_fd(rpc->TlsOut->bio, NULL); - ReadEvent = CreateFileDescriptorEvent(NULL, TRUE, FALSE, rpc->TlsOut->sockfd); + ReadEvent = CreateFileDescriptorEvent(NULL, TRUE, FALSE, fd); nCount = 0; events[nCount++] = rpc->client->StopEvent; events[nCount++] = Queue_Event(rpc->client->SendQueue); events[nCount++] = ReadEvent; + /* Do a first free run in case some bytes were set from the HTTP headers. + * We also have to do it because most of the time the underlying socket has notified, + * and the ssl layer has eaten all bytes, so we won't be notified any more even if the + * bytes are buffered locally + */ + if (rpc_client_on_read_event(rpc) < 0) + { + fprintf(stderr, "%s: an error occured when treating first packet\n", __FUNCTION__); + goto out; + } + while (rpc->transport->layer != TRANSPORT_LAYER_CLOSED) { status = WaitForMultipleObjects(nCount, events, FALSE, 100); - if (status != WAIT_TIMEOUT) + if (status == WAIT_TIMEOUT) + continue; + + if (WaitForSingleObject(rpc->client->StopEvent, 0) == WAIT_OBJECT_0) + break; + + if (WaitForSingleObject(ReadEvent, 0) == WAIT_OBJECT_0) { - if (WaitForSingleObject(rpc->client->StopEvent, 0) == WAIT_OBJECT_0) - { + if (rpc_client_on_read_event(rpc) < 0) break; - } + } - if (WaitForSingleObject(ReadEvent, 0) == WAIT_OBJECT_0) - { - if (rpc_client_on_read_event(rpc) < 0) - break; - } - - if (WaitForSingleObject(Queue_Event(rpc->client->SendQueue), 0) == WAIT_OBJECT_0) - { - rpc_send_dequeue_pdu(rpc); - } + if (WaitForSingleObject(Queue_Event(rpc->client->SendQueue), 0) == WAIT_OBJECT_0) + { + rpc_send_dequeue_pdu(rpc); } } +out: CloseHandle(ReadEvent); return NULL; @@ -541,6 +561,9 @@ static void* rpc_client_thread(void* arg) static void rpc_pdu_free(RPC_PDU* pdu) { + if (!pdu) + return; + Stream_Free(pdu->s, TRUE); free(pdu); } @@ -554,35 +577,55 @@ int rpc_client_new(rdpRpc* rpc) { RpcClient* client = NULL; - client = (RpcClient*) calloc(1, sizeof(RpcClient)); - - if (client) - { - client->StopEvent = CreateEvent(NULL, TRUE, FALSE, NULL); - client->PduSentEvent = CreateEvent(NULL, TRUE, FALSE, NULL); - - client->SendQueue = Queue_New(TRUE, -1, -1); - Queue_Object(client->SendQueue)->fnObjectFree = (OBJECT_FREE_FN) rpc_pdu_free; - - client->pdu = NULL; - client->ReceivePool = Queue_New(TRUE, -1, -1); - client->ReceiveQueue = Queue_New(TRUE, -1, -1); - Queue_Object(client->ReceivePool)->fnObjectFree = (OBJECT_FREE_FN) rpc_pdu_free; - Queue_Object(client->ReceiveQueue)->fnObjectFree = (OBJECT_FREE_FN) rpc_pdu_free; - - client->RecvFrag = NULL; - client->FragmentPool = Queue_New(TRUE, -1, -1); - client->FragmentQueue = Queue_New(TRUE, -1, -1); - - Queue_Object(client->FragmentPool)->fnObjectFree = (OBJECT_FREE_FN) rpc_fragment_free; - Queue_Object(client->FragmentQueue)->fnObjectFree = (OBJECT_FREE_FN) rpc_fragment_free; - - client->ClientCallList = ArrayList_New(TRUE); - ArrayList_Object(client->ClientCallList)->fnObjectFree = (OBJECT_FREE_FN) rpc_client_call_free; - } - + client = (RpcClient *)calloc(1, sizeof(RpcClient)); rpc->client = client; + if (!client) + return -1; + client->Thread = CreateThread(NULL, 0, + (LPTHREAD_START_ROUTINE) rpc_client_thread, + rpc, CREATE_SUSPENDED, NULL); + if (!client->Thread) + return -1; + + client->StopEvent = CreateEvent(NULL, TRUE, FALSE, NULL); + if (!client->StopEvent) + return -1; + client->PduSentEvent = CreateEvent(NULL, TRUE, FALSE, NULL); + if (!client->PduSentEvent) + return -1; + + client->SendQueue = Queue_New(TRUE, -1, -1); + if (!client->SendQueue) + return -1; + Queue_Object(client->SendQueue)->fnObjectFree = (OBJECT_FREE_FN) rpc_pdu_free; + + client->pdu = NULL; + client->ReceivePool = Queue_New(TRUE, -1, -1); + if (!client->ReceivePool) + return -1; + Queue_Object(client->ReceivePool)->fnObjectFree = (OBJECT_FREE_FN) rpc_pdu_free; + + client->ReceiveQueue = Queue_New(TRUE, -1, -1); + if (!client->ReceiveQueue) + return -1; + Queue_Object(client->ReceiveQueue)->fnObjectFree = (OBJECT_FREE_FN) rpc_pdu_free; + + client->RecvFrag = NULL; + client->FragmentPool = Queue_New(TRUE, -1, -1); + if (!client->FragmentPool) + return -1; + Queue_Object(client->FragmentPool)->fnObjectFree = (OBJECT_FREE_FN) rpc_fragment_free; + + client->FragmentQueue = Queue_New(TRUE, -1, -1); + if (!client->FragmentQueue) + return -1; + Queue_Object(client->FragmentQueue)->fnObjectFree = (OBJECT_FREE_FN) rpc_fragment_free; + + client->ClientCallList = ArrayList_New(TRUE); + if (!client->ClientCallList) + return -1; + ArrayList_Object(client->ClientCallList)->fnObjectFree = (OBJECT_FREE_FN) rpc_client_call_free; return 0; } @@ -604,9 +647,7 @@ int rpc_client_stop(rdpRpc* rpc) rpc->client->Thread = NULL; } - rpc_client_free(rpc); - - return 0; + return rpc_client_free(rpc); } int rpc_client_free(rdpRpc* rpc) @@ -615,31 +656,39 @@ int rpc_client_free(rdpRpc* rpc) client = rpc->client; - if (client) - { + if (!client) + return 0; + + if (client->SendQueue) Queue_Free(client->SendQueue); - if (client->RecvFrag) - rpc_fragment_free(client->RecvFrag); + if (client->RecvFrag) + rpc_fragment_free(client->RecvFrag); + if (client->FragmentPool) Queue_Free(client->FragmentPool); + if (client->FragmentQueue) Queue_Free(client->FragmentQueue); - if (client->pdu) - rpc_pdu_free(client->pdu); + if (client->pdu) + rpc_pdu_free(client->pdu); + if (client->ReceivePool) Queue_Free(client->ReceivePool); + if (client->ReceiveQueue) Queue_Free(client->ReceiveQueue); + if (client->ClientCallList) ArrayList_Free(client->ClientCallList); + if (client->StopEvent) CloseHandle(client->StopEvent); + if (client->PduSentEvent) CloseHandle(client->PduSentEvent); + if (client->Thread) CloseHandle(client->Thread); - free(client); - } - + free(client); return 0; } diff --git a/libfreerdp/core/gateway/rts.c b/libfreerdp/core/gateway/rts.c index 42ce2ad4e..d57a4240d 100644 --- a/libfreerdp/core/gateway/rts.c +++ b/libfreerdp/core/gateway/rts.c @@ -93,25 +93,25 @@ BOOL rts_connect(rdpRpc* rpc) if (!rpc_ntlm_http_out_connect(rpc)) { - fprintf(stderr, "rpc_out_connect_http error!\n"); + fprintf(stderr, "%s: rpc_out_connect_http error!\n", __FUNCTION__); return FALSE; } if (rts_send_CONN_A1_pdu(rpc) != 0) { - fprintf(stderr, "rpc_send_CONN_A1_pdu error!\n"); + fprintf(stderr, "%s: rpc_send_CONN_A1_pdu error!\n", __FUNCTION__); return FALSE; } if (!rpc_ntlm_http_in_connect(rpc)) { - fprintf(stderr, "rpc_in_connect_http error!\n"); + fprintf(stderr, "%s: rpc_in_connect_http error!\n", __FUNCTION__); return FALSE; } - if (rts_send_CONN_B1_pdu(rpc) != 0) + if (rts_send_CONN_B1_pdu(rpc) < 0) { - fprintf(stderr, "rpc_send_CONN_B1_pdu error!\n"); + fprintf(stderr, "%s: rpc_send_CONN_B1_pdu error!\n", __FUNCTION__); return FALSE; } @@ -147,10 +147,15 @@ BOOL rts_connect(rdpRpc* rpc) */ http_response = http_response_recv(rpc->TlsOut); + if (!http_response) + { + fprintf(stderr, "%s: unable to retrieve OUT Channel Response!\n", __FUNCTION__); + return FALSE; + } if (http_response->StatusCode != HTTP_STATUS_OK) { - fprintf(stderr, "rts_connect error! Status Code: %d\n", http_response->StatusCode); + fprintf(stderr, "%s: error! Status Code: %d\n", __FUNCTION__, http_response->StatusCode); http_response_print(http_response); http_response_free(http_response); @@ -170,6 +175,14 @@ BOOL rts_connect(rdpRpc* rpc) return FALSE; } + if (http_response->bodyLen) + { + /* inject bytes we have read in the body as a received packet for the RPC client */ + rpc->client->RecvFrag = rpc_client_fragment_pool_take(rpc); + Stream_EnsureCapacity(rpc->client->RecvFrag, http_response->bodyLen); + CopyMemory(rpc->client->RecvFrag, http_response->BodyContent, http_response->bodyLen); + } + //http_response_print(http_response); http_response_free(http_response); @@ -195,7 +208,6 @@ BOOL rts_connect(rdpRpc* rpc) rpc_client_start(rpc); pdu = rpc_recv_dequeue_pdu(rpc); - if (!pdu) return FALSE; @@ -203,7 +215,7 @@ BOOL rts_connect(rdpRpc* rpc) if (!rts_match_pdu_signature(rpc, &RTS_PDU_CONN_A3_SIGNATURE, rts)) { - fprintf(stderr, "Unexpected RTS PDU: Expected CONN/A3\n"); + fprintf(stderr, "%s: unexpected RTS PDU: Expected CONN/A3\n", __FUNCTION__); return FALSE; } @@ -236,7 +248,6 @@ BOOL rts_connect(rdpRpc* rpc) */ pdu = rpc_recv_dequeue_pdu(rpc); - if (!pdu) return FALSE; @@ -244,7 +255,7 @@ BOOL rts_connect(rdpRpc* rpc) if (!rts_match_pdu_signature(rpc, &RTS_PDU_CONN_C2_SIGNATURE, rts)) { - fprintf(stderr, "Unexpected RTS PDU: Expected CONN/C2\n"); + fprintf(stderr, "%s: unexpected RTS PDU: Expected CONN/C2\n", __FUNCTION__); return FALSE; } @@ -261,7 +272,7 @@ BOOL rts_connect(rdpRpc* rpc) return TRUE; } -#if defined WITH_DEBUG_RTS && 0 +#ifdef WITH_DEBUG_RTS static const char* const RTS_CMD_STRINGS[] = { @@ -317,6 +328,7 @@ static const char* const RTS_CMD_STRINGS[] = void rts_pdu_header_init(rpcconn_rts_hdr_t* header) { + ZeroMemory(header, sizeof(*header)); header->rpc_vers = 5; header->rpc_vers_minor = 0; header->ptype = PTYPE_RTS; @@ -681,6 +693,8 @@ int rts_send_CONN_A1_pdu(rdpRpc* rpc) ReceiveWindowSize = rpc->VirtualConnection->DefaultOutChannel->ReceiveWindow; buffer = (BYTE*) malloc(header.frag_length); + if (!buffer) + return -1; CopyMemory(buffer, ((BYTE*) &header), 20); /* RTS Header (20 bytes) */ rts_version_command_write(&buffer[20]); /* Version (8 bytes) */ @@ -718,6 +732,7 @@ int rts_send_CONN_B1_pdu(rdpRpc* rpc) BYTE* INChannelCookie; BYTE* AssociationGroupId; BYTE* VirtualConnectionCookie; + int status; rts_pdu_header_init(&header); header.frag_length = 104; @@ -734,6 +749,8 @@ int rts_send_CONN_B1_pdu(rdpRpc* rpc) AssociationGroupId = (BYTE*) &(rpc->VirtualConnection->AssociationGroupId); buffer = (BYTE*) malloc(header.frag_length); + if (!buffer) + return -1; CopyMemory(buffer, ((BYTE*) &header), 20); /* RTS Header (20 bytes) */ rts_version_command_write(&buffer[20]); /* Version (8 bytes) */ @@ -745,11 +762,11 @@ int rts_send_CONN_B1_pdu(rdpRpc* rpc) length = header.frag_length; - rpc_in_write(rpc, buffer, length); + status = rpc_in_write(rpc, buffer, length); free(buffer); - return 0; + return status; } /* CONN/C Sequence */ @@ -795,12 +812,15 @@ int rts_send_keep_alive_pdu(rdpRpc* rpc) DEBUG_RPC("Sending Keep-Alive RTS PDU"); buffer = (BYTE*) malloc(header.frag_length); + if (!buffer) + return -1; CopyMemory(buffer, ((BYTE*) &header), 20); /* RTS Header (20 bytes) */ rts_client_keepalive_command_write(&buffer[20], rpc->CurrentKeepAliveInterval); /* ClientKeepAlive (8 bytes) */ length = header.frag_length; - rpc_in_write(rpc, buffer, length); + if (rpc_in_write(rpc, buffer, length) < 0) + return -1; free(buffer); return length; @@ -830,6 +850,8 @@ int rts_send_flow_control_ack_pdu(rdpRpc* rpc) rpc->VirtualConnection->DefaultOutChannel->AvailableWindowAdvertised; buffer = (BYTE*) malloc(header.frag_length); + if (!buffer) + return -1; CopyMemory(buffer, ((BYTE*) &header), 20); /* RTS Header (20 bytes) */ rts_destination_command_write(&buffer[20], FDOutProxy); /* Destination Command (8 bytes) */ @@ -839,7 +861,8 @@ int rts_send_flow_control_ack_pdu(rdpRpc* rpc) length = header.frag_length; - rpc_in_write(rpc, buffer, length); + if (rpc_in_write(rpc, buffer, length) < 0) + return -1; free(buffer); return 0; @@ -923,12 +946,15 @@ int rts_send_ping_pdu(rdpRpc* rpc) DEBUG_RPC("Sending Ping RTS PDU"); buffer = (BYTE*) malloc(header.frag_length); + if (!buffer) + return -1; CopyMemory(buffer, ((BYTE*) &header), 20); /* RTS Header (20 bytes) */ length = header.frag_length; - rpc_in_write(rpc, buffer, length); + if (rpc_in_write(rpc, buffer, length) < 0) + return -1; free(buffer); return length; @@ -1020,22 +1046,18 @@ int rts_recv_out_of_sequence_pdu(rdpRpc* rpc, BYTE* buffer, UINT32 length) rts_extract_pdu_signature(rpc, &signature, rts); SignatureId = rts_identify_pdu_signature(rpc, &signature, NULL); - if (SignatureId == RTS_PDU_FLOW_CONTROL_ACK) + switch (SignatureId) { - return rts_recv_flow_control_ack_pdu(rpc, buffer, length); - } - else if (SignatureId == RTS_PDU_FLOW_CONTROL_ACK_WITH_DESTINATION) - { - return rts_recv_flow_control_ack_with_destination_pdu(rpc, buffer, length); - } - else if (SignatureId == RTS_PDU_PING) - { - rts_send_ping_pdu(rpc); - } - else - { - fprintf(stderr, "Unimplemented signature id: 0x%08X\n", SignatureId); - rts_print_pdu_signature(rpc, &signature); + case RTS_PDU_FLOW_CONTROL_ACK: + return rts_recv_flow_control_ack_pdu(rpc, buffer, length); + case RTS_PDU_FLOW_CONTROL_ACK_WITH_DESTINATION: + return rts_recv_flow_control_ack_with_destination_pdu(rpc, buffer, length); + case RTS_PDU_PING: + return rts_send_ping_pdu(rpc); + default: + fprintf(stderr, "%s: unimplemented signature id: 0x%08X\n", __FUNCTION__, SignatureId); + rts_print_pdu_signature(rpc, &signature); + break; } return 0; diff --git a/libfreerdp/core/gateway/rts_signature.c b/libfreerdp/core/gateway/rts_signature.c index 34598fe71..47242ca63 100644 --- a/libfreerdp/core/gateway/rts_signature.c +++ b/libfreerdp/core/gateway/rts_signature.c @@ -234,7 +234,6 @@ BOOL rts_match_pdu_signature(rdpRpc* rpc, RtsPduSignature* signature, rpcconn_rt return FALSE; status = rts_command_length(rpc, CommandType, &buffer[offset], length); - if (status < 0) return FALSE; @@ -272,7 +271,6 @@ int rts_extract_pdu_signature(rdpRpc* rpc, RtsPduSignature* signature, rpcconn_r signature->CommandTypes[i] = CommandType; status = rts_command_length(rpc, CommandType, &buffer[offset], length); - if (status < 0) return FALSE; @@ -294,22 +292,22 @@ UINT32 rts_identify_pdu_signature(rdpRpc* rpc, RtsPduSignature* signature, RTS_P { pSignature = RTS_PDU_SIGNATURE_TABLE[i].Signature; - if (signature->Flags == pSignature->Flags) + if (signature->Flags != pSignature->Flags) + continue; + + if (signature->NumberOfCommands != pSignature->NumberOfCommands) + continue; + + for (j = 0; j < signature->NumberOfCommands; j++) { - if (signature->NumberOfCommands == pSignature->NumberOfCommands) - { - for (j = 0; j < signature->NumberOfCommands; j++) - { - if (signature->CommandTypes[j] != pSignature->CommandTypes[j]) - continue; - } - - if (entry) - *entry = &RTS_PDU_SIGNATURE_TABLE[i]; - - return RTS_PDU_SIGNATURE_TABLE[i].SignatureId; - } + if (signature->CommandTypes[j] != pSignature->CommandTypes[j]) + continue; } + + if (entry) + *entry = &RTS_PDU_SIGNATURE_TABLE[i]; + + return RTS_PDU_SIGNATURE_TABLE[i].SignatureId; } return 0; diff --git a/libfreerdp/core/gateway/tsg.c b/libfreerdp/core/gateway/tsg.c index f130f73ab..5dd68886d 100644 --- a/libfreerdp/core/gateway/tsg.c +++ b/libfreerdp/core/gateway/tsg.c @@ -33,9 +33,9 @@ #include #include "rpc_client.h" - #include "tsg.h" + /** * RPC Functions: http://msdn.microsoft.com/en-us/library/windows/desktop/aa378623/ * Remote Procedure Call: http://msdn.microsoft.com/en-us/library/windows/desktop/aa378651/ @@ -96,7 +96,9 @@ DWORD TsProxySendToServer(handle_t IDL_handle, byte pRpcMessage[], UINT32 count, } length = 28 + totalDataBytes; - buffer = (BYTE*) malloc(length); + buffer = (BYTE*) calloc(1, length); + if (!buffer) + return -1; s = Stream_New(buffer, length); @@ -228,8 +230,9 @@ BOOL TsProxyCreateTunnelReadResponse(rdpTsg* tsg, RPC_PDU* pdu) if (!(pdu->Flags & RPC_PDU_FLAG_STUB)) buffer = &buffer[24]; - packet = (PTSG_PACKET) malloc(sizeof(TSG_PACKET)); - ZeroMemory(packet, sizeof(TSG_PACKET)); + packet = (PTSG_PACKET) calloc(1, sizeof(TSG_PACKET)); + if (!packet) + return FALSE; offset = 4; // Skip Packet Pointer packet->packetId = *((UINT32*) &buffer[offset]); /* PacketId */ @@ -237,8 +240,9 @@ BOOL TsProxyCreateTunnelReadResponse(rdpTsg* tsg, RPC_PDU* pdu) if ((packet->packetId == TSG_PACKET_TYPE_CAPS_RESPONSE) && (SwitchValue == TSG_PACKET_TYPE_CAPS_RESPONSE)) { - packetCapsResponse = (PTSG_PACKET_CAPS_RESPONSE) malloc(sizeof(TSG_PACKET_CAPS_RESPONSE)); - ZeroMemory(packetCapsResponse, sizeof(TSG_PACKET_CAPS_RESPONSE)); + packetCapsResponse = (PTSG_PACKET_CAPS_RESPONSE) calloc(1, sizeof(TSG_PACKET_CAPS_RESPONSE)); + if (!packetCapsResponse) // TODO: correct cleanup + return FALSE; packet->tsgPacket.packetCapsResponse = packetCapsResponse; /* PacketQuarResponsePtr (4 bytes) */ @@ -258,8 +262,7 @@ BOOL TsProxyCreateTunnelReadResponse(rdpTsg* tsg, RPC_PDU* pdu) IsMessagePresent = *((UINT32*) &buffer[offset]); offset += 4; MessageSwitchValue = *((UINT32*) &buffer[offset]); - DEBUG_TSG("IsMessagePresent %d MessageSwitchValue %d", - IsMessagePresent, MessageSwitchValue); + DEBUG_TSG("IsMessagePresent %d MessageSwitchValue %d", IsMessagePresent, MessageSwitchValue); offset += 4; } @@ -289,8 +292,9 @@ BOOL TsProxyCreateTunnelReadResponse(rdpTsg* tsg, RPC_PDU* pdu) offset += 4; } - versionCaps = (PTSG_PACKET_VERSIONCAPS) malloc(sizeof(TSG_PACKET_VERSIONCAPS)); - ZeroMemory(versionCaps, sizeof(TSG_PACKET_VERSIONCAPS)); + versionCaps = (PTSG_PACKET_VERSIONCAPS) calloc(1, sizeof(TSG_PACKET_VERSIONCAPS)); + if (!versionCaps) // TODO: correct cleanup + return FALSE; packetCapsResponse->pktQuarEncResponse.versionCaps = versionCaps; versionCaps->tsgHeader.ComponentId = *((UINT16*) &buffer[offset]); /* ComponentId */ @@ -317,8 +321,10 @@ BOOL TsProxyCreateTunnelReadResponse(rdpTsg* tsg, RPC_PDU* pdu) /* 4-byte alignment */ rpc_offset_align(&offset, 4); - tsgCaps = (PTSG_PACKET_CAPABILITIES) malloc(sizeof(TSG_PACKET_CAPABILITIES)); - ZeroMemory(tsgCaps, sizeof(TSG_PACKET_CAPABILITIES)); + tsgCaps = (PTSG_PACKET_CAPABILITIES) calloc(1, sizeof(TSG_PACKET_CAPABILITIES)); + if (!tsgCaps) + return FALSE; + versionCaps->tsgCaps = tsgCaps; offset += 4; /* MaxCount (4 bytes) */ @@ -406,8 +412,9 @@ BOOL TsProxyCreateTunnelReadResponse(rdpTsg* tsg, RPC_PDU* pdu) } else if ((packet->packetId == TSG_PACKET_TYPE_QUARENC_RESPONSE) && (SwitchValue == TSG_PACKET_TYPE_QUARENC_RESPONSE)) { - packetQuarEncResponse = (PTSG_PACKET_QUARENC_RESPONSE) malloc(sizeof(TSG_PACKET_QUARENC_RESPONSE)); - ZeroMemory(packetQuarEncResponse, sizeof(TSG_PACKET_QUARENC_RESPONSE)); + packetQuarEncResponse = (PTSG_PACKET_QUARENC_RESPONSE) calloc(1, sizeof(TSG_PACKET_QUARENC_RESPONSE)); + if (!packetQuarEncResponse) // TODO: handle cleanup + return FALSE; packet->tsgPacket.packetQuarEncResponse = packetQuarEncResponse; /* PacketQuarResponsePtr (4 bytes) */ @@ -443,8 +450,9 @@ BOOL TsProxyCreateTunnelReadResponse(rdpTsg* tsg, RPC_PDU* pdu) offset += 4; } - versionCaps = (PTSG_PACKET_VERSIONCAPS) malloc(sizeof(TSG_PACKET_VERSIONCAPS)); - ZeroMemory(versionCaps, sizeof(TSG_PACKET_VERSIONCAPS)); + versionCaps = (PTSG_PACKET_VERSIONCAPS) calloc(1, sizeof(TSG_PACKET_VERSIONCAPS)); + if (!versionCaps) // TODO: handle cleanup + return FALSE; packetQuarEncResponse->versionCaps = versionCaps; versionCaps->tsgHeader.ComponentId = *((UINT16*) &buffer[offset]); /* ComponentId */ @@ -779,8 +787,9 @@ BOOL TsProxyMakeTunnelCallReadResponse(rdpTsg* tsg, RPC_PDU* pdu) if (!(pdu->Flags & RPC_PDU_FLAG_STUB)) buffer = &buffer[24]; - packet = (PTSG_PACKET) malloc(sizeof(TSG_PACKET)); - ZeroMemory(packet, sizeof(TSG_PACKET)); + packet = (PTSG_PACKET) calloc(1, sizeof(TSG_PACKET)); + if (!packet) + return FALSE; offset = 4; packet->packetId = *((UINT32*) &buffer[offset]); /* PacketId */ @@ -923,6 +932,8 @@ BOOL TsProxyCreateChannelWriteRequest(rdpTsg* tsg, PTUNNEL_CONTEXT_HANDLE_NOSERI length = 60 + (count * 2); buffer = (BYTE*) malloc(length); + if (!buffer) + return FALSE; /* TunnelContext */ handle = (CONTEXT_HANDLE*) tunnelContext; @@ -1526,48 +1537,53 @@ int tsg_read(rdpTsg* tsg, BYTE* data, UINT32 length) return CopyLength; } - else + + + tsg->pdu = rpc_recv_peek_pdu(rpc); + if (!tsg->pdu) { - tsg->pdu = rpc_recv_peek_pdu(rpc); + if (!tsg->rpc->client->SynchronousReceive) + return 0; - if (!tsg->pdu) - { - if (tsg->rpc->client->SynchronousReceive) - return tsg_read(tsg, data, length); - else - return 0; - } - - tsg->PendingPdu = TRUE; - tsg->BytesAvailable = Stream_Length(tsg->pdu->s); - tsg->BytesRead = 0; - - CopyLength = (length < tsg->BytesAvailable) ? length : tsg->BytesAvailable; - - CopyMemory(data, &tsg->pdu->s->buffer[tsg->BytesRead], CopyLength); - tsg->BytesAvailable -= CopyLength; - tsg->BytesRead += CopyLength; - - if (tsg->BytesAvailable < 1) - { - tsg->PendingPdu = FALSE; - rpc_recv_dequeue_pdu(rpc); - rpc_client_receive_pool_return(rpc, tsg->pdu); - } - - return CopyLength; + // weird !!!! + return tsg_read(tsg, data, length); } + + tsg->PendingPdu = TRUE; + tsg->BytesAvailable = Stream_Length(tsg->pdu->s); + tsg->BytesRead = 0; + + CopyLength = (length < tsg->BytesAvailable) ? length : tsg->BytesAvailable; + + CopyMemory(data, &tsg->pdu->s->buffer[tsg->BytesRead], CopyLength); + tsg->BytesAvailable -= CopyLength; + tsg->BytesRead += CopyLength; + + if (tsg->BytesAvailable < 1) + { + tsg->PendingPdu = FALSE; + rpc_recv_dequeue_pdu(rpc); + rpc_client_receive_pool_return(rpc, tsg->pdu); + } + + return CopyLength; + } int tsg_write(rdpTsg* tsg, BYTE* data, UINT32 length) { + int status; + if (tsg->rpc->transport->layer == TRANSPORT_LAYER_CLOSED) { - fprintf(stderr, "tsg_write error: connection lost\n"); + fprintf(stderr, "%s: error, connection lost\n", __FUNCTION__); return -1; } - return TsProxySendToServer((handle_t) tsg, data, 1, &length); + status = TsProxySendToServer((handle_t) tsg, data, 1, &length); + if (status < 0) + return -1; + return length; } BOOL tsg_set_blocking_mode(rdpTsg* tsg, BOOL blocking) @@ -1584,18 +1600,21 @@ rdpTsg* tsg_new(rdpTransport* transport) { rdpTsg* tsg; - tsg = (rdpTsg*) malloc(sizeof(rdpTsg)); - ZeroMemory(tsg, sizeof(rdpTsg)); - - if (tsg != NULL) - { - tsg->transport = transport; - tsg->settings = transport->settings; - tsg->rpc = rpc_new(tsg->transport); - tsg->PendingPdu = FALSE; - } + tsg = (rdpTsg*) calloc(1, sizeof(rdpTsg)); + if (!tsg) + return NULL; + tsg->transport = transport; + tsg->settings = transport->settings; + tsg->rpc = rpc_new(tsg->transport); + if (!tsg->rpc) + goto out_free; + tsg->PendingPdu = FALSE; return tsg; + +out_free: + free(tsg); + return NULL; } void tsg_free(rdpTsg* tsg) diff --git a/libfreerdp/core/peer.c b/libfreerdp/core/peer.c index e1662d335..bc7431f47 100644 --- a/libfreerdp/core/peer.c +++ b/libfreerdp/core/peer.c @@ -52,13 +52,13 @@ static BOOL freerdp_peer_initialize(freerdp_peer* client) fprintf(stderr, "%s: inavlid RDP key file %s\n", __FUNCTION__, settings->RdpKeyFile); return FALSE; } + if (settings->RdpServerRsaKey->ModulusLength > 256) { fprintf(stderr, "%s: Key sizes > 2048 are currently not supported for RDP security.\n", __FUNCTION__); fprintf(stderr, "%s: Set a different key file than %s\n", __FUNCTION__, settings->RdpKeyFile); exit(1); } - } return TRUE; @@ -77,12 +77,13 @@ static HANDLE freerdp_peer_get_event_handle(freerdp_peer* client) return client->context->rdp->transport->TcpIn->event; } -static BOOL freerdp_peer_check_fds(freerdp_peer* client) + +static BOOL freerdp_peer_check_fds(freerdp_peer* peer) { int status; rdpRdp* rdp; - rdp = client->context->rdp; + rdp = peer->context->rdp; status = rdp_check_fds(rdp); @@ -413,6 +414,19 @@ static int freerdp_peer_send_channel_data(freerdp_peer* client, UINT16 channelId return rdp_send_channel_data(client->context->rdp, channelId, data, size); } +static BOOL freerdp_peer_is_write_blocked(freerdp_peer* peer) +{ + return tranport_is_write_blocked(peer->context->rdp->transport); +} + +static int freerdp_peer_drain_output_buffer(freerdp_peer* peer) +{ + + rdpTransport *transport = peer->context->rdp->transport; + + return tranport_drain_output_buffer(transport); +} + void freerdp_peer_context_new(freerdp_peer* client) { rdpRdp* rdp; @@ -445,6 +459,9 @@ void freerdp_peer_context_new(freerdp_peer* client) rdp->transport->ReceiveExtra = client; transport_set_blocking_mode(rdp->transport, FALSE); + client->IsWriteBlocked = freerdp_peer_is_write_blocked; + client->DrainOutputBuffer = freerdp_peer_drain_output_buffer; + IFCALL(client->ContextNew, client, client->context); } @@ -473,6 +490,8 @@ freerdp_peer* freerdp_peer_new(int sockfd) client->Close = freerdp_peer_close; client->Disconnect = freerdp_peer_disconnect; client->SendChannelData = freerdp_peer_send_channel_data; + client->IsWriteBlocked = freerdp_peer_is_write_blocked; + client->DrainOutputBuffer = freerdp_peer_drain_output_buffer; } return client; @@ -480,10 +499,10 @@ freerdp_peer* freerdp_peer_new(int sockfd) void freerdp_peer_free(freerdp_peer* client) { - if (client) - { - rdp_free(client->context->rdp); - free(client->context); - free(client); - } + if (!client) + return; + + rdp_free(client->context->rdp); + free(client->context); + free(client); } diff --git a/libfreerdp/core/settings.c b/libfreerdp/core/settings.c index 6538ec7cf..6bc2515f8 100644 --- a/libfreerdp/core/settings.c +++ b/libfreerdp/core/settings.c @@ -209,6 +209,7 @@ rdpSettings* freerdp_settings_new(DWORD flags) ZeroMemory(settings, sizeof(rdpSettings)); settings->ServerMode = (flags & FREERDP_SETTINGS_SERVER_MODE) ? TRUE : FALSE; + settings->WaitForOutputBufferFlush = TRUE; settings->DesktopWidth = 1024; settings->DesktopHeight = 768; @@ -579,6 +580,7 @@ rdpSettings* freerdp_settings_clone(rdpSettings* settings) /* BOOL values */ _settings->ServerMode = settings->ServerMode; /* 16 */ + _settings->WaitForOutputBufferFlush = settings->WaitForOutputBufferFlush; /* 25 */ _settings->NetworkAutoDetect = settings->NetworkAutoDetect; /* 137 */ _settings->SupportAsymetricKeys = settings->SupportAsymetricKeys; /* 138 */ _settings->SupportErrorInfoPdu = settings->SupportErrorInfoPdu; /* 139 */ diff --git a/libfreerdp/core/tcp.c b/libfreerdp/core/tcp.c index 15c417616..6676382fc 100644 --- a/libfreerdp/core/tcp.c +++ b/libfreerdp/core/tcp.c @@ -66,6 +66,165 @@ #include "tcp.h" +long transport_bio_buffered_callback(BIO* bio, int mode, const char* argp, int argi, long argl, long ret) +{ + return 1; +} + +static int transport_bio_buffered_write(BIO* bio, const char* buf, int num) +{ + int status, ret; + rdpTcp *tcp = (rdpTcp *)bio->ptr; + int nchunks, committedBytes, i; + DataChunk chunks[2]; + + ret = num; + BIO_clear_retry_flags(bio); + tcp->writeBlocked = FALSE; + + /* we directly append extra bytes in the xmit buffer, this could be prevented + * but for now it makes the code more simple. + */ + if (buf && num && !ringbuffer_write(&tcp->xmitBuffer, buf, num)) + { + fprintf(stderr, "%s: an error occured when writing(toWrite=%d)\n", __FUNCTION__, num); + return -1; + } + + committedBytes = 0; + nchunks = ringbuffer_peek(&tcp->xmitBuffer, chunks, ringbuffer_used(&tcp->xmitBuffer)); + for (i = 0; i < nchunks; i++) + { + while (chunks[i].size) + { + status = BIO_write(bio->next_bio, chunks[i].data, chunks[i].size); + /*fprintf(stderr, "%s: i=%d/%d size=%d/%d status=%d retry=%d\n", __FUNCTION__, i, nchunks, + chunks[i].size, ringbuffer_used(&tcp->xmitBuffer), status, + BIO_should_retry(bio->next_bio) + );*/ + if (status <= 0) + { + if (BIO_should_retry(bio->next_bio)) + { + tcp->writeBlocked = TRUE; + goto out; /* EWOULDBLOCK */ + } + + /* any other is an error, but we still have to commit written bytes */ + ret = -1; + goto out; + } + + committedBytes += status; + chunks[i].size -= status; + chunks[i].data += status; + } + } + +out: + ringbuffer_commit_read_bytes(&tcp->xmitBuffer, committedBytes); + return ret; +} + +static int transport_bio_buffered_read(BIO* bio, char* buf, int size) +{ + int status; + rdpTcp *tcp = (rdpTcp *)bio->ptr; + + tcp->readBlocked = FALSE; + BIO_clear_retry_flags(bio); + + status = BIO_read(bio->next_bio, buf, size); + /*fprintf(stderr, "%s: size=%d status=%d shouldRetry=%d\n", __FUNCTION__, size, status, BIO_should_retry(bio->next_bio)); */ + + if (status <= 0 && BIO_should_retry(bio->next_bio)) + { + BIO_set_retry_read(bio); + tcp->readBlocked = TRUE; + } + + return status; +} + +static int transport_bio_buffered_puts(BIO* bio, const char* str) +{ + return 1; +} + +static int transport_bio_buffered_gets(BIO* bio, char* str, int size) +{ + return 1; +} + +static long transport_bio_buffered_ctrl(BIO* bio, int cmd, long arg1, void* arg2) +{ + rdpTcp *tcp = (rdpTcp *)bio->ptr; + + switch (cmd) + { + case BIO_CTRL_FLUSH: + return 1; + case BIO_CTRL_WPENDING: + return ringbuffer_used(&tcp->xmitBuffer); + case BIO_CTRL_PENDING: + return 0; + default: + /*fprintf(stderr, "%s: passing to next BIO, bio=%p cmd=%d arg1=%d arg2=%p\n", __FUNCTION__, bio, cmd, arg1, arg2); */ + return BIO_ctrl(bio->next_bio, cmd, arg1, arg2); + } + + return 0; +} + +static int transport_bio_buffered_new(BIO* bio) +{ + bio->init = 1; + bio->num = 0; + bio->ptr = NULL; + bio->flags = 0; + + return 1; +} + +static int transport_bio_buffered_free(BIO* bio) +{ + return 1; +} + + +static BIO_METHOD transport_bio_buffered_socket_methods = +{ + BIO_TYPE_BUFFERED, + "BufferedSocket", + transport_bio_buffered_write, + transport_bio_buffered_read, + transport_bio_buffered_puts, + transport_bio_buffered_gets, + transport_bio_buffered_ctrl, + transport_bio_buffered_new, + transport_bio_buffered_free, + NULL, +}; + +BIO_METHOD* BIO_s_buffered_socket(void) +{ + return &transport_bio_buffered_socket_methods; +} + +BOOL transport_bio_buffered_drain(BIO *bio) +{ + rdpTcp *tcp = (rdpTcp *)bio->ptr; + int status; + + if (!ringbuffer_used(&tcp->xmitBuffer)) + return 1; + + status = transport_bio_buffered_write(bio, NULL, 0); + return status >= 0; +} + + + void tcp_get_ip_address(rdpTcp* tcp) { BYTE* ip; @@ -136,62 +295,65 @@ BOOL tcp_connect(rdpTcp* tcp, const char* hostname, int port) if (hostname[0] == '/') { tcp->sockfd = freerdp_uds_connect(hostname); - if (tcp->sockfd < 0) return FALSE; + + tcp->socketBio = BIO_new_fd(tcp->sockfd, 1); + if (!tcp->socketBio) + return FALSE; } else { - tcp->sockfd = freerdp_tcp_connect(hostname, port); - - if (tcp->sockfd < 0) + tcp->socketBio = BIO_new(BIO_s_connect()); + if (!tcp->socketBio) return FALSE; - SetEventFileDescriptor(tcp->event, tcp->sockfd); + if (BIO_set_conn_hostname(tcp->socketBio, hostname) < 0 || BIO_set_conn_int_port(tcp->socketBio, &port) < 0) + return FALSE; - tcp_get_ip_address(tcp); - tcp_get_mac_address(tcp); + if (BIO_do_connect(tcp->socketBio) <= 0) + return FALSE; - option_value = 1; - option_len = sizeof(option_value); - setsockopt(tcp->sockfd, IPPROTO_TCP, TCP_NODELAY, (void*) &option_value, option_len); - - /* receive buffer must be a least 32 K */ - if (getsockopt(tcp->sockfd, SOL_SOCKET, SO_RCVBUF, (void*) &option_value, &option_len) == 0) - { - if (option_value < (1024 * 32)) - { - option_value = 1024 * 32; - option_len = sizeof(option_value); - setsockopt(tcp->sockfd, SOL_SOCKET, SO_RCVBUF, (void*) &option_value, option_len); - } - } - - tcp_set_keep_alive_mode(tcp); + tcp->sockfd = BIO_get_fd(tcp->socketBio, NULL); } + SetEventFileDescriptor(tcp->event, tcp->sockfd); + + tcp_get_ip_address(tcp); + tcp_get_mac_address(tcp); + + option_value = 1; + option_len = sizeof(option_value); + if (setsockopt(tcp->sockfd, IPPROTO_TCP, TCP_NODELAY, (void*) &option_value, option_len) < 0) + fprintf(stderr, "%s: unable to set TCP_NODELAY\n", __FUNCTION__); + + /* receive buffer must be a least 32 K */ + if (getsockopt(tcp->sockfd, SOL_SOCKET, SO_RCVBUF, (void*) &option_value, &option_len) == 0) + { + if (option_value < (1024 * 32)) + { + option_value = 1024 * 32; + option_len = sizeof(option_value); + if (setsockopt(tcp->sockfd, SOL_SOCKET, SO_RCVBUF, (void*) &option_value, option_len) < 0) + { + fprintf(stderr, "%s: unable to set receive buffer len\n", __FUNCTION__); + return FALSE; + } + } + } + + if (!tcp_set_keep_alive_mode(tcp)) + return FALSE; + + tcp->bufferedBio = BIO_new(BIO_s_buffered_socket()); + if (!tcp->bufferedBio) + return FALSE; + tcp->bufferedBio->ptr = tcp; + + tcp->bufferedBio = BIO_push(tcp->bufferedBio, tcp->socketBio); return TRUE; } -int tcp_read(rdpTcp* tcp, BYTE* data, int length) -{ - return freerdp_tcp_read(tcp->sockfd, data, length); -} - -int tcp_write(rdpTcp* tcp, BYTE* data, int length) -{ - return freerdp_tcp_write(tcp->sockfd, data, length); -} - -int tcp_wait_read(rdpTcp* tcp) -{ - return freerdp_tcp_wait_read(tcp->sockfd); -} - -int tcp_wait_write(rdpTcp* tcp) -{ - return freerdp_tcp_wait_write(tcp->sockfd); -} BOOL tcp_disconnect(rdpTcp* tcp) { @@ -209,7 +371,7 @@ BOOL tcp_set_blocking_mode(rdpTcp* tcp, BOOL blocking) if (flags == -1) { - fprintf(stderr, "tcp_set_blocking_mode: fcntl failed.\n"); + fprintf(stderr, "%s: fcntl failed, %s.\n", __FUNCTION__, strerror(errno)); return FALSE; } @@ -297,6 +459,31 @@ int tcp_attach(rdpTcp* tcp, int sockfd) { tcp->sockfd = sockfd; SetEventFileDescriptor(tcp->event, tcp->sockfd); + + ringbuffer_commit_read_bytes(&tcp->xmitBuffer, ringbuffer_used(&tcp->xmitBuffer)); + + if (tcp->socketBio) + { + if (BIO_set_fd(tcp->socketBio, sockfd, 1) < 0) + return -1; + } + else + { + tcp->socketBio = BIO_new_socket(sockfd, 1); + if (!tcp->socketBio) + return -1; + } + + if (!tcp->bufferedBio) + { + tcp->bufferedBio = BIO_new(BIO_s_buffered_socket()); + if (!tcp->bufferedBio) + return FALSE; + tcp->bufferedBio->ptr = tcp; + + tcp->bufferedBio = BIO_push(tcp->bufferedBio, tcp->socketBio); + } + return 0; } @@ -316,25 +503,34 @@ rdpTcp* tcp_new(rdpSettings* settings) { rdpTcp* tcp; - tcp = (rdpTcp*) malloc(sizeof(rdpTcp)); + tcp = (rdpTcp *)calloc(1, sizeof(rdpTcp)); + if (!tcp) + return NULL; - if (tcp) - { - ZeroMemory(tcp, sizeof(rdpTcp)); + if (!ringbuffer_init(&tcp->xmitBuffer, 0x10000)) + goto out_free; - tcp->sockfd = -1; - tcp->settings = settings; - tcp->event = CreateFileDescriptorEvent(NULL, FALSE, FALSE, tcp->sockfd); - } + tcp->sockfd = -1; + tcp->settings = settings; + + tcp->event = CreateFileDescriptorEvent(NULL, FALSE, FALSE, tcp->sockfd); + if (!tcp->event || tcp->event == INVALID_HANDLE_VALUE) + goto out_ringbuffer; return tcp; +out_ringbuffer: + ringbuffer_destroy(&tcp->xmitBuffer); +out_free: + free(tcp); + return NULL; } void tcp_free(rdpTcp* tcp) { - if (tcp) - { - CloseHandle(tcp->event); - free(tcp); - } + if (!tcp) + return; + + ringbuffer_destroy(&tcp->xmitBuffer); + CloseHandle(tcp->event); + free(tcp); } diff --git a/libfreerdp/core/tcp.h b/libfreerdp/core/tcp.h index b43fbaf1c..a8b3153b9 100644 --- a/libfreerdp/core/tcp.h +++ b/libfreerdp/core/tcp.h @@ -31,10 +31,15 @@ #include #include +#include +#include + #ifndef MSG_NOSIGNAL #define MSG_NOSIGNAL 0 #endif +#define BIO_TYPE_BUFFERED 66 + typedef struct rdp_tcp rdpTcp; struct rdp_tcp @@ -46,6 +51,12 @@ struct rdp_tcp #ifdef _WIN32 WSAEVENT wsa_event; #endif + BIO *socketBio; + BIO *bufferedBio; + RingBuffer xmitBuffer; + BOOL writeBlocked; + BOOL readBlocked; + HANDLE event; }; diff --git a/libfreerdp/core/transport.c b/libfreerdp/core/transport.c index c194c292c..f79d51aa5 100644 --- a/libfreerdp/core/transport.c +++ b/libfreerdp/core/transport.c @@ -33,7 +33,9 @@ #include #include +#include +#include #include #include #include @@ -41,6 +43,12 @@ #ifndef _WIN32 #include #include +#include +#include +#endif + +#ifdef HAVE_VALGRIND_MEMCHECK_H +#include #endif #include "tpkt.h" @@ -48,6 +56,7 @@ #include "transport.h" #include "rdp.h" + #define BUFFER_SIZE 16384 static void* transport_client_thread(void* arg); @@ -69,6 +78,7 @@ void transport_attach(rdpTransport* transport, int sockfd) tcp_attach(transport->TcpIn, sockfd); transport->SplitInputOutput = FALSE; transport->TcpOut = transport->TcpIn; + transport->frontBio = transport->TcpIn->bufferedBio; } void transport_stop(rdpTransport* transport) @@ -98,18 +108,9 @@ BOOL transport_disconnect(rdpTransport* transport) transport_stop(transport); - if (transport->layer == TRANSPORT_LAYER_TLS) - status &= tls_disconnect(transport->TlsIn); - - if ((transport->layer == TRANSPORT_LAYER_TSG) || (transport->layer == TRANSPORT_LAYER_TSG_TLS)) - { - status &= tsg_disconnect(transport->tsg); - } - else - { - status &= tcp_disconnect(transport->TcpIn); - } + BIO_free_all(transport->frontBio); + transport->frontBio = 0; return status; } @@ -131,16 +132,16 @@ static int transport_bio_tsg_write(BIO* bio, const char* buf, int num) rdpTsg* tsg; tsg = (rdpTsg*) bio->ptr; - status = tsg_write(tsg, (BYTE*) buf, num); BIO_clear_retry_flags(bio); + status = tsg_write(tsg, (BYTE*) buf, num); + if (status > 0) + return status; if (status == 0) - { BIO_set_retry_write(bio); - } - return status < 0 ? 0 : num; + return -1; } static int transport_bio_tsg_read(BIO* bio, char* buf, int size) @@ -222,8 +223,13 @@ BIO_METHOD* BIO_s_tsg(void) return &transport_bio_tsg_methods; } + + BOOL transport_connect_tls(rdpTransport* transport) { + rdpSettings *settings = transport->settings; + rdpTls *targetTls; + BIO *targetBio; int tls_status; freerdp* instance; rdpContext* context; @@ -234,61 +240,33 @@ BOOL transport_connect_tls(rdpTransport* transport) if (transport->layer == TRANSPORT_LAYER_TSG) { transport->TsgTls = tls_new(transport->settings); - - transport->TsgTls->methods = BIO_s_tsg(); - transport->TsgTls->tsg = (void*) transport->tsg; - transport->layer = TRANSPORT_LAYER_TSG_TLS; - transport->TsgTls->hostname = transport->settings->ServerHostname; - transport->TsgTls->port = transport->settings->ServerPort; + targetTls = transport->TsgTls; + targetBio = transport->frontBio; + } + else + { + if (!transport->TlsIn) + transport->TlsIn = tls_new(settings); - if (transport->TsgTls->port == 0) - transport->TsgTls->port = 3389; + if (!transport->TlsOut) + transport->TlsOut = transport->TlsIn; - tls_status = tls_connect(transport->TsgTls); + targetTls = transport->TlsIn; + targetBio = transport->TcpIn->bufferedBio; - if (tls_status < 1) - { - if (tls_status < 0) - { - if (!connectErrorCode) - connectErrorCode = TLSCONNECTERROR; - - if (!freerdp_get_last_error(context)) - freerdp_set_last_error(context, FREERDP_ERROR_TLS_CONNECT_FAILED); - } - else - { - if (!freerdp_get_last_error(context)) - freerdp_set_last_error(context, FREERDP_ERROR_CONNECT_CANCELLED); - } - - tls_free(transport->TsgTls); - transport->TsgTls = NULL; - - return FALSE; - } - - return TRUE; + transport->layer = TRANSPORT_LAYER_TLS; } - if (!transport->TlsIn) - transport->TlsIn = tls_new(transport->settings); - if (!transport->TlsOut) - transport->TlsOut = transport->TlsIn; + targetTls->hostname = settings->ServerHostname; + targetTls->port = settings->ServerPort; - transport->layer = TRANSPORT_LAYER_TLS; - transport->TlsIn->sockfd = transport->TcpIn->sockfd; + if (targetTls->port == 0) + targetTls->port = 3389; - transport->TlsIn->hostname = transport->settings->ServerHostname; - transport->TlsIn->port = transport->settings->ServerPort; - - if (transport->TlsIn->port == 0) - transport->TlsIn->port = 3389; - - tls_status = tls_connect(transport->TlsIn); + tls_status = tls_connect(targetTls, targetBio); if (tls_status < 1) { @@ -306,13 +284,13 @@ BOOL transport_connect_tls(rdpTransport* transport) freerdp_set_last_error(context, FREERDP_ERROR_CONNECT_CANCELLED); } - tls_free(transport->TlsIn); - - if (transport->TlsIn == transport->TlsOut) - transport->TlsIn = transport->TlsOut = NULL; - else - transport->TlsIn = NULL; + return FALSE; + } + transport->frontBio = targetTls->bio; + if (!transport->frontBio) + { + fprintf(stderr, "%s: unable to prepend a filtering TLS bio"); return FALSE; } @@ -323,6 +301,7 @@ BOOL transport_connect_nla(rdpTransport* transport) { freerdp* instance; rdpSettings* settings; + rdpCredssp *credSsp; settings = transport->settings; instance = (freerdp*) settings->instance; @@ -338,16 +317,22 @@ BOOL transport_connect_nla(rdpTransport* transport) if (!transport->credssp) { transport->credssp = credssp_new(instance, transport, settings); + if (!transport->credssp) + return FALSE; + transport_set_nla_mode(transport, TRUE); if (settings->AuthenticationServiceClass) { transport->credssp->ServicePrincipalName = credssp_make_spn(settings->AuthenticationServiceClass, settings->ServerHostname); + if (!transport->credssp->ServicePrincipalName) + return FALSE; } } - if (credssp_authenticate(transport->credssp) < 0) + credSsp = transport->credssp; + if (credssp_authenticate(credSsp) < 0) { if (!connectErrorCode) connectErrorCode = AUTHENTICATIONERROR; @@ -361,14 +346,14 @@ BOOL transport_connect_nla(rdpTransport* transport) "If credentials are valid, the NTLMSSP implementation may be to blame.\n"); transport_set_nla_mode(transport, FALSE); - credssp_free(transport->credssp); + credssp_free(credSsp); transport->credssp = NULL; return FALSE; } transport_set_nla_mode(transport, FALSE); - credssp_free(transport->credssp); + credssp_free(credSsp); transport->credssp = NULL; return TRUE; @@ -380,38 +365,41 @@ BOOL transport_tsg_connect(rdpTransport* transport, const char* hostname, UINT16 int tls_status; freerdp* instance; rdpContext* context; + rdpSettings *settings = transport->settings; instance = (freerdp*) transport->settings->instance; context = instance->context; tsg = tsg_new(transport); + if (!tsg) + return FALSE; tsg->transport = transport; transport->tsg = tsg; transport->SplitInputOutput = TRUE; if (!transport->TlsIn) - transport->TlsIn = tls_new(transport->settings); - - transport->TlsIn->sockfd = transport->TcpIn->sockfd; - transport->TlsIn->hostname = transport->settings->GatewayHostname; - transport->TlsIn->port = transport->settings->GatewayPort; - - if (transport->TlsIn->port == 0) - transport->TlsIn->port = 443; - + { + transport->TlsIn = tls_new(settings); + if (!transport->TlsIn) + return FALSE; + } if (!transport->TlsOut) - transport->TlsOut = tls_new(transport->settings); + { + transport->TlsOut = tls_new(settings); + if (!transport->TlsOut) + return FALSE; + } - transport->TlsOut->sockfd = transport->TcpOut->sockfd; - transport->TlsOut->hostname = transport->settings->GatewayHostname; - transport->TlsOut->port = transport->settings->GatewayPort; + /* put a decent default value for gateway port */ + if (!settings->GatewayPort) + settings->GatewayPort = 443; - if (transport->TlsOut->port == 0) - transport->TlsOut->port = 443; + transport->TlsIn->hostname = transport->TlsOut->hostname = settings->GatewayHostname; + transport->TlsIn->port = transport->TlsOut->port = settings->GatewayPort; - tls_status = tls_connect(transport->TlsIn); + tls_status = tls_connect(transport->TlsIn, transport->TcpIn->bufferedBio); if (tls_status < 1) { if (tls_status < 0) @@ -428,8 +416,7 @@ BOOL transport_tsg_connect(rdpTransport* transport, const char* hostname, UINT16 return FALSE; } - tls_status = tls_connect(transport->TlsOut); - + tls_status = tls_connect(transport->TlsOut, transport->TcpOut->bufferedBio); if (tls_status < 1) { if (tls_status < 0) @@ -449,6 +436,8 @@ BOOL transport_tsg_connect(rdpTransport* transport, const char* hostname, UINT16 if (!tsg_connect(tsg, hostname, port)) return FALSE; + transport->frontBio = BIO_new(BIO_s_tsg()); + transport->frontBio->ptr = tsg; return TRUE; } @@ -462,15 +451,20 @@ BOOL transport_connect(rdpTransport* transport, const char* hostname, UINT16 por if (transport->GatewayEnabled) { transport->layer = TRANSPORT_LAYER_TSG; + transport->SplitInputOutput = TRUE; transport->TcpOut = tcp_new(settings); - status = tcp_connect(transport->TcpIn, settings->GatewayHostname, settings->GatewayPort); + if (!tcp_connect(transport->TcpIn, settings->GatewayHostname, settings->GatewayPort) || + !tcp_set_blocking_mode(transport->TcpIn, FALSE)) + return FALSE; - if (status) - status = tcp_connect(transport->TcpOut, settings->GatewayHostname, settings->GatewayPort); + if (!tcp_connect(transport->TcpOut, settings->GatewayHostname, settings->GatewayPort) || + !tcp_set_blocking_mode(transport->TcpOut, FALSE)) + return FALSE; - if (status) - status = transport_tsg_connect(transport, hostname, port); + if (!transport_tsg_connect(transport, hostname, port)) + return FALSE; + status = TRUE; } else { @@ -478,6 +472,7 @@ BOOL transport_connect(rdpTransport* transport, const char* hostname, UINT16 por transport->SplitInputOutput = FALSE; transport->TcpOut = transport->TcpIn; + transport->frontBio = transport->TcpIn->bufferedBio; } if (status) @@ -510,11 +505,11 @@ BOOL transport_accept_tls(rdpTransport* transport) transport->TlsOut = transport->TlsIn; transport->layer = TRANSPORT_LAYER_TLS; - transport->TlsIn->sockfd = transport->TcpIn->sockfd; - if (!tls_accept(transport->TlsIn, transport->settings->CertificateFile, transport->settings->PrivateKeyFile)) + if (!tls_accept(transport->TlsIn, transport->TcpIn->bufferedBio, transport->settings->CertificateFile, transport->settings->PrivateKeyFile)) return FALSE; + transport->frontBio = transport->TlsIn->bio; return TRUE; } @@ -533,10 +528,10 @@ BOOL transport_accept_nla(rdpTransport* transport) transport->TlsOut = transport->TlsIn; transport->layer = TRANSPORT_LAYER_TLS; - transport->TlsIn->sockfd = transport->TcpIn->sockfd; - if (!tls_accept(transport->TlsIn, transport->settings->CertificateFile, transport->settings->PrivateKeyFile)) + if (!tls_accept(transport->TlsIn, transport->TcpIn->bufferedBio, settings->CertificateFile, settings->PrivateKeyFile)) return FALSE; + transport->frontBio = transport->TlsIn->bio; /* Network Level Authentication */ @@ -630,56 +625,131 @@ UINT32 nla_header_length(wStream* s) return length; } +static int transport_wait_for_read(rdpTransport* transport) +{ + struct timeval tv; + fd_set rset, wset; + fd_set *rsetPtr = NULL, *wsetPtr = NULL; + rdpTcp *tcpIn; + + tcpIn = transport->TcpIn; + if (tcpIn->readBlocked) + { + rsetPtr = &rset; + FD_ZERO(rsetPtr); + FD_SET(tcpIn->sockfd, rsetPtr); + } + else if (tcpIn->writeBlocked) + { + wsetPtr = &wset; + FD_ZERO(wsetPtr); + FD_SET(tcpIn->sockfd, wsetPtr); + } + + if (!wsetPtr && !rsetPtr) + { + USleep(1000); + return 0; + } + + tv.tv_sec = 0; + tv.tv_usec = 1000; + + return select(tcpIn->sockfd + 1, rsetPtr, wsetPtr, NULL, &tv); +} + + +static int transport_wait_for_write(rdpTransport* transport) +{ + struct timeval tv; + fd_set rset, wset; + fd_set *rsetPtr = NULL, *wsetPtr = NULL; + rdpTcp *tcpOut; + + tcpOut = transport->SplitInputOutput ? transport->TcpOut : transport->TcpIn; + if (tcpOut->writeBlocked) + { + wsetPtr = &wset; + FD_ZERO(wsetPtr); + FD_SET(tcpOut->sockfd, wsetPtr); + } + else if (tcpOut->readBlocked) + { + rsetPtr = &rset; + FD_ZERO(rsetPtr); + FD_SET(tcpOut->sockfd, rsetPtr); + } + + if (!wsetPtr && !rsetPtr) + { + USleep(1000); + return 0; + } + + tv.tv_sec = 0; + tv.tv_usec = 1000; + + return select(tcpOut->sockfd + 1, rsetPtr, wsetPtr, NULL, &tv); +} + + int transport_read_layer(rdpTransport* transport, BYTE* data, int bytes) { int read = 0; int status = -1; + while (read < bytes) { - if (transport->layer == TRANSPORT_LAYER_TLS) - status = tls_read(transport->TlsIn, data + read, bytes - read); - else if (transport->layer == TRANSPORT_LAYER_TCP) - status = tcp_read(transport->TcpIn, data + read, bytes - read); - else if (transport->layer == TRANSPORT_LAYER_TSG) - status = tsg_read(transport->tsg, data + read, bytes - read); - else if (transport->layer == TRANSPORT_LAYER_TSG_TLS) { - status = tls_read(transport->TsgTls, data + read, bytes - read); + status = BIO_read(transport->frontBio, data + read, bytes - read); + + if (!status) + { + transport->layer = TRANSPORT_LAYER_CLOSED; + return -1; } - /* blocking means that we can't continue until this is read */ - - if (!transport->blocking) - return status; - if (status < 0) { - /* A read error indicates that the peer has dropped the connection */ - transport->layer = TRANSPORT_LAYER_CLOSED; - return status; + if (!BIO_should_retry(transport->frontBio)) + { + /* something unexpected happened, let's close */ + transport->layer = TRANSPORT_LAYER_CLOSED; + return -1; + } + + /* non blocking will survive a partial read */ + if (!transport->blocking) + return read; + + /* blocking means that we can't continue until we have read the number of + * requested bytes */ + if (transport_wait_for_read(transport) < 0) + { + fprintf(stderr, "%s: error when selecting for read\n", __FUNCTION__); + return -1; + } + continue; } +#ifdef HAVE_VALGRIND_MEMCHECK_H + VALGRIND_MAKE_MEM_DEFINED(data + read, bytes - read); +#endif + read += status; - - if (status == 0) - { - /* - * instead of sleeping, we should wait timeout on the - * socket but this only happens on initial connection - */ - USleep(transport->SleepInterval); - } } return read; } + + int transport_read(rdpTransport* transport, wStream* s) { int status; int position; int pduLength; - BYTE header[4]; + BYTE *header; int transport_status; position = 0; @@ -710,7 +780,7 @@ int transport_read(rdpTransport* transport, wStream* s) position += status; } - CopyMemory(header, Stream_Buffer(s), 4); /* peek at first 4 bytes */ + header = Stream_Buffer(s); /* if header is present, read exactly one PDU */ @@ -802,6 +872,8 @@ static int transport_read_nonblocking(rdpTransport* transport) return status; } +BOOL transport_bio_buffered_drain(BIO *bio); + int transport_write(rdpTransport* transport, wStream* s) { int length; @@ -827,36 +899,48 @@ int transport_write(rdpTransport* transport, wStream* s) while (length > 0) { - if (transport->layer == TRANSPORT_LAYER_TLS) - status = tls_write(transport->TlsOut, Stream_Pointer(s), length); - else if (transport->layer == TRANSPORT_LAYER_TCP) - status = tcp_write(transport->TcpOut, Stream_Pointer(s), length); - else if (transport->layer == TRANSPORT_LAYER_TSG) - status = tsg_write(transport->tsg, Stream_Pointer(s), length); - else if (transport->layer == TRANSPORT_LAYER_TSG_TLS) - status = tls_write(transport->TsgTls, Stream_Pointer(s), length); + status = BIO_write(transport->frontBio, Stream_Pointer(s), length); - if (status < 0) - break; /* error occurred */ - - if (status == 0) + if (status <= 0) { - /* when sending is blocked in nonblocking mode, the receiving buffer should be checked */ - if (!transport->blocking) - { - /* and in case we do have buffered some data, we set the event so next loop will get it */ - if (transport_read_nonblocking(transport) > 0) - SetEvent(transport->ReceiveEvent); - } + /* the buffered BIO that is at the end of the chain always says OK for writing, + * so a retry means that for any reason we need to read. The most probable + * is a SSL or TSG BIO in the chain. + */ + if (!BIO_should_retry(transport->frontBio)) + return status; - if (transport->layer == TRANSPORT_LAYER_TLS) - tls_wait_write(transport->TlsOut); - else if (transport->layer == TRANSPORT_LAYER_TCP) - tcp_wait_write(transport->TcpOut); - else if (transport->layer == TRANSPORT_LAYER_TSG_TLS) - tls_wait_write(transport->TsgTls); - else - USleep(transport->SleepInterval); + /* non-blocking can live with blocked IOs */ + if (!transport->blocking) + return status; + + if (transport_wait_for_write(transport) < 0) + { + fprintf(stderr, "%s: error when selecting for write\n", __FUNCTION__); + return -1; + } + continue; + } + + if (transport->blocking || transport->settings->WaitForOutputBufferFlush) + { + /* blocking transport, we must ensure the write buffer is really empty */ + rdpTcp *out = transport->TcpOut; + + while (out->writeBlocked) + { + if (transport_wait_for_write(transport) < 0) + { + fprintf(stderr, "%s: error when selecting for write\n", __FUNCTION__); + return -1; + } + + if (!transport_bio_buffered_drain(out->bufferedBio)) + { + fprintf(stderr, "%s: error when draining outputBuffer\n", __FUNCTION__); + return -1; + } + } } length -= status; @@ -945,6 +1029,38 @@ void transport_get_read_handles(rdpTransport* transport, HANDLE* events, DWORD* } } +BOOL tranport_is_write_blocked(rdpTransport* transport) +{ + if (transport->TcpIn->writeBlocked) + return TRUE; + + return transport->SplitInputOutput && + transport->TcpOut && + transport->TcpOut->writeBlocked; +} + +int tranport_drain_output_buffer(rdpTransport* transport) +{ + BOOL ret = FALSE; + + /* First try to send some accumulated bytes in the send buffer */ + if (transport->TcpIn->writeBlocked) + { + if (!transport_bio_buffered_drain(transport->TcpIn->bufferedBio)) + return -1; + ret |= transport->TcpIn->writeBlocked; + } + + if (transport->SplitInputOutput && transport->TcpOut && transport->TcpOut->writeBlocked) + { + if (!transport_bio_buffered_drain(transport->TcpOut->bufferedBio)) + return -1; + ret |= transport->TcpOut->writeBlocked; + } + + return ret; +} + int transport_check_fds(rdpTransport* transport) { int pos; @@ -1079,15 +1195,14 @@ int transport_check_fds(rdpTransport* transport) recv_status = transport->ReceiveCallback(transport, received, transport->ReceiveExtra); - Stream_Release(received); - - if (recv_status < 0) - return -1; - if (recv_status == 1) { return 1; /* session redirection */ } + Stream_Release(received); + + if (recv_status < 0) + return -1; } return 0; @@ -1198,80 +1313,107 @@ rdpTransport* transport_new(rdpSettings* settings) { rdpTransport* transport; - transport = (rdpTransport*) malloc(sizeof(rdpTransport)); + transport = (rdpTransport *)calloc(1, sizeof(rdpTransport)); + if (!transport) + return NULL; - if (transport) - { - ZeroMemory(transport, sizeof(rdpTransport)); + WLog_Init(); + transport->log = WLog_Get("com.freerdp.core.transport"); + if (!transport->log) + goto out_free; - WLog_Init(); - transport->log = WLog_Get("com.freerdp.core.transport"); + transport->TcpIn = tcp_new(settings); + if (!transport->TcpIn) + goto out_free; - transport->TcpIn = tcp_new(settings); + transport->settings = settings; - transport->settings = settings; + /* a small 0.1ms delay when transport is blocking. */ + transport->SleepInterval = 100; - /* a small 0.1ms delay when transport is blocking. */ - transport->SleepInterval = 100; + transport->ReceivePool = StreamPool_New(TRUE, BUFFER_SIZE); + if (!transport->ReceivePool) + goto out_free_tcpin; - transport->ReceivePool = StreamPool_New(TRUE, BUFFER_SIZE); + /* receive buffer for non-blocking read. */ + transport->ReceiveBuffer = StreamPool_Take(transport->ReceivePool, 0); + if (!transport->ReceiveBuffer) + goto out_free_receivepool; - /* receive buffer for non-blocking read. */ - transport->ReceiveBuffer = StreamPool_Take(transport->ReceivePool, 0); - transport->ReceiveEvent = CreateEvent(NULL, TRUE, FALSE, NULL); + transport->ReceiveEvent = CreateEvent(NULL, TRUE, FALSE, NULL); + if (!transport->ReceiveEvent || transport->ReceiveEvent == INVALID_HANDLE_VALUE) + goto out_free_receivebuffer; - transport->connectedEvent = CreateEvent(NULL, TRUE, FALSE, NULL); + transport->connectedEvent = CreateEvent(NULL, TRUE, FALSE, NULL); + if (!transport->connectedEvent || transport->connectedEvent == INVALID_HANDLE_VALUE) + goto out_free_receiveEvent; - transport->blocking = TRUE; - transport->GatewayEnabled = FALSE; + transport->blocking = TRUE; + transport->GatewayEnabled = FALSE; + transport->layer = TRANSPORT_LAYER_TCP; - InitializeCriticalSectionAndSpinCount(&(transport->ReadLock), 4000); - InitializeCriticalSectionAndSpinCount(&(transport->WriteLock), 4000); - - transport->layer = TRANSPORT_LAYER_TCP; - } + if (!InitializeCriticalSectionAndSpinCount(&(transport->ReadLock), 4000)) + goto out_free_connectedEvent; + if (!InitializeCriticalSectionAndSpinCount(&(transport->WriteLock), 4000)) + goto out_free_readlock; return transport; + +out_free_readlock: + DeleteCriticalSection(&(transport->ReadLock)); +out_free_connectedEvent: + CloseHandle(transport->connectedEvent); +out_free_receiveEvent: + CloseHandle(transport->ReceiveEvent); +out_free_receivebuffer: + StreamPool_Return(transport->ReceivePool, transport->ReceiveBuffer); +out_free_receivepool: + StreamPool_Free(transport->ReceivePool); +out_free_tcpin: + tcp_free(transport->TcpIn); +out_free: + free(transport); + return NULL; } void transport_free(rdpTransport* transport) { - if (transport) - { - transport_stop(transport); + if (!transport) + return; - if (transport->ReceiveBuffer) - Stream_Release(transport->ReceiveBuffer); + transport_stop(transport); - StreamPool_Free(transport->ReceivePool); + if (transport->ReceiveBuffer) + Stream_Release(transport->ReceiveBuffer); - CloseHandle(transport->ReceiveEvent); - CloseHandle(transport->connectedEvent); + StreamPool_Free(transport->ReceivePool); - if (transport->TlsIn) - tls_free(transport->TlsIn); + CloseHandle(transport->ReceiveEvent); + CloseHandle(transport->connectedEvent); - if (transport->TlsOut != transport->TlsIn) - tls_free(transport->TlsOut); + if (transport->TlsIn) + tls_free(transport->TlsIn); - transport->TlsIn = NULL; - transport->TlsOut = NULL; + if (transport->TlsOut != transport->TlsIn) + tls_free(transport->TlsOut); - if (transport->TcpIn) - tcp_free(transport->TcpIn); + transport->TlsIn = NULL; + transport->TlsOut = NULL; - if (transport->TcpOut != transport->TcpIn) - tcp_free(transport->TcpOut); + if (transport->TcpIn) + tcp_free(transport->TcpIn); - transport->TcpIn = NULL; - transport->TcpOut = NULL; + if (transport->TcpOut != transport->TcpIn) + tcp_free(transport->TcpOut); - tsg_free(transport->tsg); - transport->tsg = NULL; + transport->TcpIn = NULL; + transport->TcpOut = NULL; - DeleteCriticalSection(&(transport->ReadLock)); - DeleteCriticalSection(&(transport->WriteLock)); + tsg_free(transport->tsg); + transport->tsg = NULL; - free(transport); - } + DeleteCriticalSection(&(transport->ReadLock)); + DeleteCriticalSection(&(transport->WriteLock)); + + free(transport); } diff --git a/libfreerdp/core/transport.h b/libfreerdp/core/transport.h index b8834ce7a..829807405 100644 --- a/libfreerdp/core/transport.h +++ b/libfreerdp/core/transport.h @@ -49,11 +49,13 @@ typedef struct rdp_transport rdpTransport; #include #include + typedef int (*TransportRecv) (rdpTransport* transport, wStream* stream, void* extra); struct rdp_transport { TRANSPORT_LAYER layer; + BIO *frontBio; rdpTsg* tsg; rdpTcp* TcpIn; rdpTcp* TcpOut; @@ -102,6 +104,8 @@ BOOL transport_set_blocking_mode(rdpTransport* transport, BOOL blocking); void transport_set_gateway_enabled(rdpTransport* transport, BOOL GatewayEnabled); void transport_set_nla_mode(rdpTransport* transport, BOOL NlaMode); void transport_get_read_handles(rdpTransport* transport, HANDLE* events, DWORD* count); +BOOL tranport_is_write_blocked(rdpTransport* transport); +BOOL tranport_drain_output_buffer(rdpTransport* transport); wStream* transport_receive_pool_take(rdpTransport* transport); int transport_receive_pool_return(rdpTransport* transport, wStream* pdu); diff --git a/libfreerdp/crypto/tls.c b/libfreerdp/crypto/tls.c index 52c217782..016584fcc 100644 --- a/libfreerdp/crypto/tls.c +++ b/libfreerdp/crypto/tls.c @@ -28,34 +28,35 @@ #include #include +#include #include - -#ifdef HAVE_VALGRIND_MEMCHECK_H -#include -#endif +#include "../core/tcp.h" static CryptoCert tls_get_certificate(rdpTls* tls, BOOL peer) { CryptoCert cert; - X509* server_cert; + X509* remote_cert; if (peer) - server_cert = SSL_get_peer_certificate(tls->ssl); + remote_cert = SSL_get_peer_certificate(tls->ssl); else - server_cert = SSL_get_certificate(tls->ssl); + remote_cert = SSL_get_certificate(tls->ssl); - if (!server_cert) + if (!remote_cert) { - fprintf(stderr, "tls_get_certificate: failed to get the server TLS certificate\n"); - cert = NULL; - } - else - { - cert = malloc(sizeof(*cert)); - cert->px509 = server_cert; + fprintf(stderr, "%s: failed to get the server TLS certificate\n", __FUNCTION__); + return NULL; } + cert = malloc(sizeof(*cert)); + if (!cert) + { + X509_free(remote_cert); + return NULL; + } + + cert->px509 = remote_cert; return cert; } @@ -83,12 +84,14 @@ SecPkgContext_Bindings* tls_get_channel_bindings(X509* cert) PrefixLength = strlen(TLS_SERVER_END_POINT); ChannelBindingTokenLength = PrefixLength + CertificateHashLength; - ContextBindings = (SecPkgContext_Bindings*) malloc(sizeof(SecPkgContext_Bindings)); - ZeroMemory(ContextBindings, sizeof(SecPkgContext_Bindings)); + ContextBindings = (SecPkgContext_Bindings*) calloc(1, sizeof(SecPkgContext_Bindings)); + if (!ContextBindings) + return NULL; ContextBindings->BindingsLength = sizeof(SEC_CHANNEL_BINDINGS) + ChannelBindingTokenLength; - ChannelBindings = (SEC_CHANNEL_BINDINGS*) malloc(ContextBindings->BindingsLength); - ZeroMemory(ChannelBindings, ContextBindings->BindingsLength); + ChannelBindings = (SEC_CHANNEL_BINDINGS*) calloc(1, ContextBindings->BindingsLength); + if (!ChannelBindings) + goto out_free; ContextBindings->Bindings = ChannelBindings; ChannelBindings->cbApplicationDataLength = ChannelBindingTokenLength; @@ -99,32 +102,121 @@ SecPkgContext_Bindings* tls_get_channel_bindings(X509* cert) CopyMemory(&ChannelBindingToken[PrefixLength], CertificateHash, CertificateHashLength); return ContextBindings; + +out_free: + free(ContextBindings); + return NULL; } -static void tls_ssl_info_callback(const SSL* ssl, int type, int val) + +BOOL tls_prepare(rdpTls* tls, BIO *underlying, const SSL_METHOD *method, int options, BOOL clientMode) { - if (type & SSL_CB_HANDSHAKE_START) - { - - } -} - -int tls_connect(rdpTls* tls) -{ - CryptoCert cert; - long options = 0; - int verify_status; - int connection_status; - - tls->ctx = SSL_CTX_new(TLSv1_client_method()); - + tls->ctx = SSL_CTX_new(method); if (!tls->ctx) { - fprintf(stderr, "SSL_CTX_new failed\n"); + fprintf(stderr, "%s: SSL_CTX_new failed\n", __FUNCTION__); + return FALSE; + } + + SSL_CTX_set_mode(tls->ctx, SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER | SSL_MODE_ENABLE_PARTIAL_WRITE); + + SSL_CTX_set_options(tls->ctx, options); + SSL_CTX_set_read_ahead(tls->ctx, 1); + + tls->bio = BIO_new_ssl(tls->ctx, clientMode); + if (BIO_get_ssl(tls->bio, &tls->ssl) < 0) + { + fprintf(stderr, "%s: unable to retrieve the SSL of the connection\n", __FUNCTION__); + return FALSE; + } + + BIO_push(tls->bio, underlying); + return TRUE; +} + +int tls_do_handshake(rdpTls* tls, BOOL clientMode) +{ + CryptoCert cert; + int verify_status, status; + + do + { + struct timeval tv; + fd_set rset; + int fd; + + status = BIO_do_handshake(tls->bio); + if (status == 1) + break; + if (!BIO_should_retry(tls->bio)) + return -1; + + /* we select() only for read even if we should test both read and write + * depending of what have blocked */ + FD_ZERO(&rset); + + fd = BIO_get_fd(tls->bio, NULL); + if (fd < 0) + { + fprintf(stderr, "%s: unable to retrieve BIO fd\n", __FUNCTION__); + return -1; + } + + FD_SET(fd, &rset); + tv.tv_sec = 0; + tv.tv_usec = 10 * 1000; /* 10ms */ + + status = select(fd + 1, &rset, NULL, NULL, &tv); + if (status < 0) + { + fprintf(stderr, "%s: error during select()\n", __FUNCTION__); + return -1; + } + } + while (TRUE); + + if (!clientMode) + return 1; + + cert = tls_get_certificate(tls, clientMode); + if (!cert) + { + fprintf(stderr, "%s: tls_get_certificate failed to return the server certificate.\n", __FUNCTION__); return -1; } - //SSL_CTX_set_mode(tls->ctx, SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER | SSL_MODE_ENABLE_PARTIAL_WRITE); + tls->Bindings = tls_get_channel_bindings(cert->px509); + if (!tls->Bindings) + { + fprintf(stderr, "%s: unable to retrieve bindings\n", __FUNCTION__); + return -1; + } + + if (!crypto_cert_get_public_key(cert, &tls->PublicKey, &tls->PublicKeyLength)) + { + fprintf(stderr, "%s: crypto_cert_get_public_key failed to return the server public key.\n", __FUNCTION__); + tls_free_certificate(cert); + return -1; + } + + verify_status = tls_verify_certificate(tls, cert, tls->hostname, tls->port); + + if (verify_status < 1) + { + fprintf(stderr, "%s: certificate not trusted, aborting.\n", __FUNCTION__); + tls_disconnect(tls); + tls_free_certificate(cert); + return 0; + } + + tls_free_certificate(cert); + + return verify_status; +} + +int tls_connect(rdpTls* tls, BIO *underlying) +{ + int options = 0; /** * SSL_OP_NO_COMPRESSION: @@ -138,7 +230,7 @@ int tls_connect(rdpTls* tls) #ifdef SSL_OP_NO_COMPRESSION options |= SSL_OP_NO_COMPRESSION; #endif - + /** * SSL_OP_TLS_BLOCK_PADDING_BUG: * @@ -155,96 +247,19 @@ int tls_connect(rdpTls* tls) */ options |= SSL_OP_DONT_INSERT_EMPTY_FRAGMENTS; - SSL_CTX_set_options(tls->ctx, options); + if (!tls_prepare(tls, underlying, TLSv1_client_method(), options, TRUE)) + return FALSE; - tls->ssl = SSL_new(tls->ctx); - - if (!tls->ssl) - { - fprintf(stderr, "SSL_new failed\n"); - return -1; - } - - if (tls->tsg) - { - tls->bio = BIO_new(tls->methods); - - if (!tls->bio) - { - fprintf(stderr, "BIO_new failed\n"); - return -1; - } - - tls->bio->ptr = tls->tsg; - - SSL_set_bio(tls->ssl, tls->bio, tls->bio); - - SSL_CTX_set_info_callback(tls->ctx, tls_ssl_info_callback); - } - else - { - if (SSL_set_fd(tls->ssl, tls->sockfd) < 1) - { - fprintf(stderr, "SSL_set_fd failed\n"); - return -1; - } - } - - connection_status = SSL_connect(tls->ssl); - - if (connection_status <= 0) - { - if (tls_print_error("SSL_connect", tls->ssl, connection_status)) - { - return -1; - } - } - - cert = tls_get_certificate(tls, TRUE); - - if (!cert) - { - fprintf(stderr, "tls_connect: tls_get_certificate failed to return the server certificate.\n"); - return -1; - } - - tls->Bindings = tls_get_channel_bindings(cert->px509); - - if (!crypto_cert_get_public_key(cert, &tls->PublicKey, &tls->PublicKeyLength)) - { - fprintf(stderr, "tls_connect: crypto_cert_get_public_key failed to return the server public key.\n"); - tls_free_certificate(cert); - return -1; - } - - verify_status = tls_verify_certificate(tls, cert, tls->hostname, tls->port); - - if (verify_status < 1) - { - fprintf(stderr, "tls_connect: certificate not trusted, aborting.\n"); - tls_disconnect(tls); - } - - tls_free_certificate(cert); - - return verify_status; + return tls_do_handshake(tls, TRUE); } -BOOL tls_accept(rdpTls* tls, const char* cert_file, const char* privatekey_file) + + +BOOL tls_accept(rdpTls* tls, BIO *underlying, const char* cert_file, const char* privatekey_file) { - CryptoCert cert; long options = 0; - int connection_status; - tls->ctx = SSL_CTX_new(SSLv23_server_method()); - - if (tls->ctx == NULL) - { - fprintf(stderr, "SSL_CTX_new failed\n"); - return FALSE; - } - - /* + /** * SSL_OP_NO_SSLv2: * * We only want SSLv3 and TLSv1, so disable SSLv2. @@ -281,80 +296,23 @@ BOOL tls_accept(rdpTls* tls, const char* cert_file, const char* privatekey_file) */ options |= SSL_OP_DONT_INSERT_EMPTY_FRAGMENTS; - SSL_CTX_set_options(tls->ctx, options); - - if (SSL_CTX_use_RSAPrivateKey_file(tls->ctx, privatekey_file, SSL_FILETYPE_PEM) <= 0) - { - fprintf(stderr, "SSL_CTX_use_RSAPrivateKey_file failed\n"); - fprintf(stderr, "PrivateKeyFile: %s\n", privatekey_file); + if (!tls_prepare(tls, underlying, SSLv23_server_method(), options, FALSE)) return FALSE; - } - tls->ssl = SSL_new(tls->ctx); - - if (!tls->ssl) + if (SSL_use_RSAPrivateKey_file(tls->ssl, privatekey_file, SSL_FILETYPE_PEM) <= 0) { - fprintf(stderr, "SSL_new failed\n"); + fprintf(stderr, "%s: SSL_CTX_use_RSAPrivateKey_file failed\n", __FUNCTION__); + fprintf(stderr, "PrivateKeyFile: %s\n", privatekey_file); return FALSE; } if (SSL_use_certificate_file(tls->ssl, cert_file, SSL_FILETYPE_PEM) <= 0) { - fprintf(stderr, "SSL_use_certificate_file failed\n"); + fprintf(stderr, "%s: SSL_use_certificate_file failed\n", __FUNCTION__); return FALSE; } - if (SSL_set_fd(tls->ssl, tls->sockfd) < 1) - { - fprintf(stderr, "SSL_set_fd failed\n"); - return FALSE; - } - - while (1) - { - connection_status = SSL_accept(tls->ssl); - - if (connection_status <= 0) - { - switch (SSL_get_error(tls->ssl, connection_status)) - { - case SSL_ERROR_WANT_READ: - case SSL_ERROR_WANT_WRITE: - break; - - default: - if (tls_print_error("SSL_accept", tls->ssl, connection_status)) - return FALSE; - break; - - } - } - else - { - break; - } - } - - cert = tls_get_certificate(tls, FALSE); - - if (!cert) - { - fprintf(stderr, "tls_connect: tls_get_certificate failed to return the server certificate.\n"); - return FALSE; - } - - if (!crypto_cert_get_public_key(cert, &tls->PublicKey, &tls->PublicKeyLength)) - { - fprintf(stderr, "tls_connect: crypto_cert_get_public_key failed to return the server public key.\n"); - tls_free_certificate(cert); - return FALSE; - } - - free(cert); - - fprintf(stderr, "TLS connection accepted\n"); - - return TRUE; + return tls_do_handshake(tls, FALSE) > 0; } BOOL tls_disconnect(rdpTls* tls) @@ -362,256 +320,161 @@ BOOL tls_disconnect(rdpTls* tls) if (!tls) return FALSE; - if (tls->ssl) + if (!tls->ssl) + return TRUE; + + if (tls->alertDescription != TLS_ALERT_DESCRIPTION_CLOSE_NOTIFY) { - if (tls->alertDescription != TLS_ALERT_DESCRIPTION_CLOSE_NOTIFY) - { - /** - * OpenSSL doesn't really expose an API for sending a TLS alert manually. - * - * The following code disables the sending of the default "close notify" - * and then proceeds to force sending a custom TLS alert before shutting down. - * - * Manually sending a TLS alert is necessary in certain cases, - * like when server-side NLA results in an authentication failure. - */ + /** + * OpenSSL doesn't really expose an API for sending a TLS alert manually. + * + * The following code disables the sending of the default "close notify" + * and then proceeds to force sending a custom TLS alert before shutting down. + * + * Manually sending a TLS alert is necessary in certain cases, + * like when server-side NLA results in an authentication failure. + */ - SSL_set_quiet_shutdown(tls->ssl, 1); + SSL_set_quiet_shutdown(tls->ssl, 1); - if ((tls->alertLevel == TLS_ALERT_LEVEL_FATAL) && (tls->ssl->session)) - SSL_CTX_remove_session(tls->ssl->ctx, tls->ssl->session); + if ((tls->alertLevel == TLS_ALERT_LEVEL_FATAL) && (tls->ssl->session)) + SSL_CTX_remove_session(tls->ssl->ctx, tls->ssl->session); - tls->ssl->s3->alert_dispatch = 1; - tls->ssl->s3->send_alert[0] = tls->alertLevel; - tls->ssl->s3->send_alert[1] = tls->alertDescription; + tls->ssl->s3->alert_dispatch = 1; + tls->ssl->s3->send_alert[0] = tls->alertLevel; + tls->ssl->s3->send_alert[1] = tls->alertDescription; - if (tls->ssl->s3->wbuf.left == 0) - tls->ssl->method->ssl_dispatch_alert(tls->ssl); + if (tls->ssl->s3->wbuf.left == 0) + tls->ssl->method->ssl_dispatch_alert(tls->ssl); - SSL_shutdown(tls->ssl); - } - else - { - SSL_shutdown(tls->ssl); - } + SSL_shutdown(tls->ssl); + } + else + { + SSL_shutdown(tls->ssl); } return TRUE; } -int tls_read(rdpTls* tls, BYTE* data, int length) + +BIO *findBufferedBio(BIO *front) { - int error; - int status; + BIO *ret = front; - if (!tls) - return -1; - - if (!tls->ssl) - return -1; - - status = SSL_read(tls->ssl, data, length); - - if (status == 0) + while (ret) { - return -1; /* peer disconnected */ + if (BIO_method_type(ret) == BIO_TYPE_BUFFERED) + return ret; + ret = ret->next_bio; } - if (status <= 0) - { - error = SSL_get_error(tls->ssl, status); - - //fprintf(stderr, "tls_read: length: %d status: %d error: 0x%08X\n", - // length, status, error); - - switch (error) - { - case SSL_ERROR_NONE: - break; - - case SSL_ERROR_WANT_READ: - case SSL_ERROR_WANT_WRITE: - status = 0; - break; - - case SSL_ERROR_SYSCALL: -#ifdef _WIN32 - if (WSAGetLastError() == WSAEWOULDBLOCK) -#else - if ((errno == EAGAIN) || (errno == 0)) -#endif - { - status = 0; - } - else - { - if (tls_print_error("SSL_read", tls->ssl, status)) - { - status = -1; - } - else - { - status = 0; - } - } - break; - - default: - if (tls_print_error("SSL_read", tls->ssl, status)) - { - status = -1; - } - else - { - status = 0; - } - break; - } - } - -#ifdef HAVE_VALGRIND_MEMCHECK_H - VALGRIND_MAKE_MEM_DEFINED(data, status); -#endif - - return status; + return ret; } -int tls_write(rdpTls* tls, BYTE* data, int length) +int tls_write_all(rdpTls* tls, const BYTE* data, int length) { - int error; - int status; + int status, nchunks, commitedBytes; + rdpTcp *tcp; + fd_set rset, wset; + fd_set *rsetPtr, *wsetPtr; + struct timeval tv; + BIO *bio = tls->bio; + DataChunk chunks[2]; - if (!tls) - return -1; - - if (!tls->ssl) - return -1; - - status = SSL_write(tls->ssl, data, length); - - if (status == 0) + BIO *bufferedBio = findBufferedBio(bio); + if (!bufferedBio) { - return -1; /* peer disconnected */ + fprintf(stderr, "%s: error unable to retrieve the bufferedBio in the BIO chain\n", __FUNCTION__); + return -1; } - if (status < 0) - { - error = SSL_get_error(tls->ssl, status); - - //fprintf(stderr, "tls_write: length: %d status: %d error: 0x%08X\n", length, status, error); - - switch (error) - { - case SSL_ERROR_NONE: - break; - - case SSL_ERROR_WANT_READ: - case SSL_ERROR_WANT_WRITE: - status = 0; - break; - - case SSL_ERROR_SYSCALL: - if (errno == EAGAIN) - { - status = 0; - } - else - { - tls_print_error("SSL_write", tls->ssl, status); - status = -1; - } - break; - - default: - tls_print_error("SSL_write", tls->ssl, status); - status = -1; - break; - } - } - - return status; -} - -int tls_write_all(rdpTls* tls, BYTE* data, int length) -{ - int status; - int sent = 0; + tcp = (rdpTcp *)bufferedBio->ptr; do { - status = tls_write(tls, &data[sent], length - sent); - + status = BIO_write(bio, data, length); + /*fprintf(stderr, "%s: BIO_write(len=%d) = %d (retry=%d)\n", __FUNCTION__, length, status, BIO_should_retry(bio));*/ if (status > 0) - sent += status; - else if (status == 0) - tls_wait_write(tls); - - if (sent >= length) break; + + if (!BIO_should_retry(bio)) + return -1; + + /* we try to handle SSL want_read and want_write nicely */ + rsetPtr = wsetPtr = 0; + if (tcp->writeBlocked) + { + wsetPtr = &wset; + FD_ZERO(&wset); + FD_SET(tcp->sockfd, &wset); + } + else if (tcp->readBlocked) + { + rsetPtr = &rset; + FD_ZERO(&rset); + FD_SET(tcp->sockfd, &rset); + } + else + { + fprintf(stderr, "%s: weird we're blocked but the underlying is not read or write blocked !\n", __FUNCTION__); + USleep(10); + continue; + } + + tv.tv_sec = 0; + tv.tv_usec = 100 * 1000; + + status = select(tcp->sockfd + 1, rsetPtr, wsetPtr, NULL, &tv); + if (status < 0) + return -1; } - while (status >= 0); + while (TRUE); - if (status > 0) - return length; - else - return status; -} - -int tls_wait_read(rdpTls* tls) -{ - return freerdp_tcp_wait_read(tls->sockfd); -} - -int tls_wait_write(rdpTls* tls) -{ - return freerdp_tcp_wait_write(tls->sockfd); -} - -static void tls_errors(const char *prefix) -{ - unsigned long error; - - while ((error = ERR_get_error()) != 0) - fprintf(stderr, "%s: %s\n", prefix, ERR_error_string(error, NULL)); -} - -BOOL tls_print_error(char* func, SSL* connection, int value) -{ - switch (SSL_get_error(connection, value)) + /* make sure the output buffer is empty */ + commitedBytes = 0; + while ((nchunks = ringbuffer_peek(&tcp->xmitBuffer, chunks, ringbuffer_used(&tcp->xmitBuffer)))) { - case SSL_ERROR_ZERO_RETURN: - fprintf(stderr, "%s: Server closed TLS connection\n", func); - return TRUE; + int i; - case SSL_ERROR_WANT_READ: - fprintf(stderr, "%s: SSL_ERROR_WANT_READ\n", func); - return FALSE; + for (i = 0; i < nchunks; i++) + { + while (chunks[i].size) + { + status = BIO_write(tcp->socketBio, chunks[i].data, chunks[i].size); + if (status > 0) + { + chunks[i].size -= status; + chunks[i].data += status; + commitedBytes += status; + continue; + } - case SSL_ERROR_WANT_WRITE: - fprintf(stderr, "%s: SSL_ERROR_WANT_WRITE\n", func); - return FALSE; + if (!BIO_should_retry(tcp->socketBio)) + goto out_fail; + FD_ZERO(&rset); + FD_SET(tcp->sockfd, &rset); + tv.tv_sec = 0; + tv.tv_usec = 100 * 1000; - case SSL_ERROR_SYSCALL: -#ifdef _WIN32 - fprintf(stderr, "%s: I/O error: %d\n", func, WSAGetLastError()); -#else - fprintf(stderr, "%s: I/O error: %s (%d)\n", func, strerror(errno), errno); -#endif - tls_errors(func); - return TRUE; + status = select(tcp->sockfd + 1, &rset, NULL, NULL, &tv); + if (status < 0) + goto out_fail; + } - case SSL_ERROR_SSL: - fprintf(stderr, "%s: Failure in SSL library (protocol error?)\n", func); - tls_errors(func); - return TRUE; - - default: - fprintf(stderr, "%s: Unknown error\n", func); - tls_errors(func); - return TRUE; + } } + + ringbuffer_commit_read_bytes(&tcp->xmitBuffer, commitedBytes); + return length; + +out_fail: + ringbuffer_commit_read_bytes(&tcp->xmitBuffer, commitedBytes); + return -1; } + + int tls_set_alert_code(rdpTls* tls, int level, int description) { tls->alertLevel = level; @@ -672,7 +535,7 @@ int tls_verify_certificate(rdpTls* tls, CryptoCert cert, char* hostname, int por if (!bio) { - fprintf(stderr, "tls_verify_certificate: BIO_new() failure\n"); + fprintf(stderr, "%s: BIO_new() failure\n", __FUNCTION__); return -1; } @@ -680,7 +543,7 @@ int tls_verify_certificate(rdpTls* tls, CryptoCert cert, char* hostname, int por if (status < 0) { - fprintf(stderr, "tls_verify_certificate: PEM_write_bio_X509 failure: %d\n", status); + fprintf(stderr, "%s: PEM_write_bio_X509 failure: %d\n", __FUNCTION__, status); return -1; } @@ -692,7 +555,7 @@ int tls_verify_certificate(rdpTls* tls, CryptoCert cert, char* hostname, int por if (status < 0) { - fprintf(stderr, "tls_verify_certificate: failed to read certificate\n"); + fprintf(stderr, "%s: failed to read certificate\n", __FUNCTION__); return -1; } @@ -713,7 +576,7 @@ int tls_verify_certificate(rdpTls* tls, CryptoCert cert, char* hostname, int por if (status < 0) { - fprintf(stderr, "tls_verify_certificate: failed to read certificate\n"); + fprintf(stderr, "%s: failed to read certificate\n", __FUNCTION__); return -1; } @@ -727,8 +590,7 @@ int tls_verify_certificate(rdpTls* tls, CryptoCert cert, char* hostname, int por status = instance->VerifyX509Certificate(instance, pemCert, length, hostname, port, 0); } - fprintf(stderr, "VerifyX509Certificate: (length = %d) status: %d\n%s\n", - length, status, pemCert); + fprintf(stderr, "%s: (length = %d) status: %d\n%s\n", __FUNCTION__, length, status, pemCert); free(pemCert); BIO_free(bio); @@ -932,57 +794,53 @@ rdpTls* tls_new(rdpSettings* settings) { rdpTls* tls; - tls = (rdpTls*) malloc(sizeof(rdpTls)); + tls = (rdpTls *)calloc(1, sizeof(rdpTls)); + if (!tls) + return NULL; - if (tls) - { - ZeroMemory(tls, sizeof(rdpTls)); + SSL_load_error_strings(); + SSL_library_init(); - SSL_load_error_strings(); - SSL_library_init(); - - tls->settings = settings; - tls->certificate_store = certificate_store_new(settings); - - tls->alertLevel = TLS_ALERT_LEVEL_WARNING; - tls->alertDescription = TLS_ALERT_DESCRIPTION_CLOSE_NOTIFY; - } + tls->settings = settings; + tls->certificate_store = certificate_store_new(settings); + if (!tls->certificate_store) + goto out_free; + tls->alertLevel = TLS_ALERT_LEVEL_WARNING; + tls->alertDescription = TLS_ALERT_DESCRIPTION_CLOSE_NOTIFY; return tls; + +out_free: + free(tls); + return NULL; } void tls_free(rdpTls* tls) { - if (tls) + if (!tls) + return; + + if (tls->ctx) { - if (tls->ssl) - { - SSL_free(tls->ssl); - tls->ssl = NULL; - } - - if (tls->ctx) - { - SSL_CTX_free(tls->ctx); - tls->ctx = NULL; - } - - if (tls->PublicKey) - { - free(tls->PublicKey); - tls->PublicKey = NULL; - } - - if (tls->Bindings) - { - free(tls->Bindings->Bindings); - free(tls->Bindings); - tls->Bindings = NULL; - } - - certificate_store_free(tls->certificate_store); - tls->certificate_store = NULL; - - free(tls); + SSL_CTX_free(tls->ctx); + tls->ctx = NULL; } + + if (tls->PublicKey) + { + free(tls->PublicKey); + tls->PublicKey = NULL; + } + + if (tls->Bindings) + { + free(tls->Bindings->Bindings); + free(tls->Bindings); + tls->Bindings = NULL; + } + + certificate_store_free(tls->certificate_store); + tls->certificate_store = NULL; + + free(tls); } From 5234e05843941a6ba1dac2aec09a7592acd1fe60 Mon Sep 17 00:00:00 2001 From: Hardening Date: Wed, 21 May 2014 18:17:39 +0200 Subject: [PATCH 05/11] Make ringbuffer C89 aware for VC --- libfreerdp/utils/ringbuffer.c | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/libfreerdp/utils/ringbuffer.c b/libfreerdp/utils/ringbuffer.c index 04dcffef1..493445705 100644 --- a/libfreerdp/utils/ringbuffer.c +++ b/libfreerdp/utils/ringbuffer.c @@ -135,6 +135,9 @@ static BOOL ringbuffer_realloc(RingBuffer *rb, size_t targetSize) */ BOOL ringbuffer_write(RingBuffer *rb, const void *ptr, size_t sz) { + size_t toWrite; + size_t remaining; + if ((rb->freeSize <= sz) && !ringbuffer_realloc(rb, rb->size + sz)) return FALSE; @@ -144,8 +147,8 @@ BOOL ringbuffer_write(RingBuffer *rb, const void *ptr, size_t sz) * v v * [ ################ ] */ - size_t toWrite = sz; - size_t remaining = sz; + toWrite = sz; + remaining = sz; if (rb->size - rb->writePtr < sz) toWrite = rb->size - rb->writePtr; From a04843bc9e1aaf35c477ade632ec42bee5c84796 Mon Sep 17 00:00:00 2001 From: Hardening Date: Wed, 21 May 2014 18:32:59 +0200 Subject: [PATCH 06/11] Fix some corner cases in ringbuffer and make unitary test have no leak --- libfreerdp/utils/ringbuffer.c | 2 ++ libfreerdp/utils/test/TestRingBuffer.c | 4 ++++ 2 files changed, 6 insertions(+) diff --git a/libfreerdp/utils/ringbuffer.c b/libfreerdp/utils/ringbuffer.c index 493445705..95b5652fd 100644 --- a/libfreerdp/utils/ringbuffer.c +++ b/libfreerdp/utils/ringbuffer.c @@ -69,6 +69,7 @@ static BOOL ringbuffer_realloc(RingBuffer *rb, size_t targetSize) if (!newData) return FALSE; rb->readPtr = rb->writePtr = 0; + rb->buffer = newData; } else if ((rb->writePtr >= rb->readPtr) && (rb->writePtr < targetSize)) { @@ -118,6 +119,7 @@ static BOOL ringbuffer_realloc(RingBuffer *rb, size_t targetSize) } rb->writePtr = rb->size - rb->freeSize; rb->readPtr = 0; + free(rb->buffer); rb->buffer = newData; } diff --git a/libfreerdp/utils/test/TestRingBuffer.c b/libfreerdp/utils/test/TestRingBuffer.c index 36cbaa559..22e312e43 100644 --- a/libfreerdp/utils/test/TestRingBuffer.c +++ b/libfreerdp/utils/test/TestRingBuffer.c @@ -75,6 +75,7 @@ BOOL test_overlaps(void) if (ringbuffer_capacity(&rb) != 5) goto error; + ringbuffer_destroy(&rb); return TRUE; error: ringbuffer_destroy(&rb); @@ -220,6 +221,9 @@ int TestRingBuffer(int argc, char* argv[]) return -1; } fprintf(stderr, "ok\n"); + + ringbuffer_destroy(&ringBuffer); + free(tmpBuf); return 0; } From d8eb1f284f754d2e2f7ee3d8929bd0442fa6be1b Mon Sep 17 00:00:00 2001 From: Hardening Date: Wed, 21 May 2014 18:44:49 +0200 Subject: [PATCH 07/11] Updated license headers --- libfreerdp/utils/ringbuffer.c | 28 +++++++++++--------------- libfreerdp/utils/test/TestRingBuffer.c | 28 +++++++++++--------------- 2 files changed, 24 insertions(+), 32 deletions(-) diff --git a/libfreerdp/utils/ringbuffer.c b/libfreerdp/utils/ringbuffer.c index 95b5652fd..1611490dc 100644 --- a/libfreerdp/utils/ringbuffer.c +++ b/libfreerdp/utils/ringbuffer.c @@ -1,24 +1,20 @@ /** + * FreeRDP: A Remote Desktop Protocol Implementation + * * Copyright © 2014 Thincast Technologies GmbH * Copyright © 2014 Hardening * - * Permission to use, copy, modify, distribute, and sell this software and - * its documentation for any purpose is hereby granted without fee, provided - * that the above copyright notice appear in all copies and that both that - * copyright notice and this permission notice appear in supporting - * documentation, and that the name of the copyright holders not be used in - * advertising or publicity pertaining to distribution of the software - * without specific, written prior permission. The copyright holders make - * no representations about the suitability of this software for any - * purpose. It is provided "as is" without express or implied warranty. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at * - * THE COPYRIGHT HOLDERS DISCLAIM ALL WARRANTIES WITH REGARD TO THIS - * SOFTWARE, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY AND - * FITNESS, IN NO EVENT SHALL THE COPYRIGHT HOLDERS BE LIABLE FOR ANY - * SPECIAL, INDIRECT OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER - * RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF - * CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN - * CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. */ #include diff --git a/libfreerdp/utils/test/TestRingBuffer.c b/libfreerdp/utils/test/TestRingBuffer.c index 22e312e43..67497726e 100644 --- a/libfreerdp/utils/test/TestRingBuffer.c +++ b/libfreerdp/utils/test/TestRingBuffer.c @@ -1,24 +1,20 @@ /** + * FreeRDP: A Remote Desktop Protocol Implementation + * * Copyright © 2014 Thincast Technologies GmbH * Copyright © 2014 Hardening * - * Permission to use, copy, modify, distribute, and sell this software and - * its documentation for any purpose is hereby granted without fee, provided - * that the above copyright notice appear in all copies and that both that - * copyright notice and this permission notice appear in supporting - * documentation, and that the name of the copyright holders not be used in - * advertising or publicity pertaining to distribution of the software - * without specific, written prior permission. The copyright holders make - * no representations about the suitability of this software for any - * purpose. It is provided "as is" without express or implied warranty. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at * - * THE COPYRIGHT HOLDERS DISCLAIM ALL WARRANTIES WITH REGARD TO THIS - * SOFTWARE, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY AND - * FITNESS, IN NO EVENT SHALL THE COPYRIGHT HOLDERS BE LIABLE FOR ANY - * SPECIAL, INDIRECT OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER - * RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF - * CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN - * CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. */ #include From de1c08736f45958e08d795d0e22e12e522d5536c Mon Sep 17 00:00:00 2001 From: Hardening Date: Wed, 21 May 2014 19:12:51 +0200 Subject: [PATCH 08/11] Fix ringbuffer_write() to use const BYTE * instead of const void * --- include/freerdp/utils/ringbuffer.h | 2 +- libfreerdp/utils/ringbuffer.c | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/include/freerdp/utils/ringbuffer.h b/include/freerdp/utils/ringbuffer.h index 099ba8ba1..1ec00c514 100644 --- a/include/freerdp/utils/ringbuffer.h +++ b/include/freerdp/utils/ringbuffer.h @@ -76,7 +76,7 @@ size_t ringbuffer_capacity(const RingBuffer *ringbuffer); * @param sz the size of the data to add * @return if the operation was successful, it could fail in case of OOM during realloc() */ -BOOL ringbuffer_write(RingBuffer *rb, const void *ptr, size_t sz); +BOOL ringbuffer_write(RingBuffer *rb, const BYTE *ptr, size_t sz); /** ensures that we have sz bytes available at the write head, and return a pointer diff --git a/libfreerdp/utils/ringbuffer.c b/libfreerdp/utils/ringbuffer.c index 1611490dc..07e770164 100644 --- a/libfreerdp/utils/ringbuffer.c +++ b/libfreerdp/utils/ringbuffer.c @@ -131,7 +131,7 @@ static BOOL ringbuffer_realloc(RingBuffer *rb, size_t targetSize) * @param sz * @return */ -BOOL ringbuffer_write(RingBuffer *rb, const void *ptr, size_t sz) +BOOL ringbuffer_write(RingBuffer *rb, const BYTE *ptr, size_t sz) { size_t toWrite; size_t remaining; From 5c9a6408cfadaf89e5661239fef006691b5e9cdd Mon Sep 17 00:00:00 2001 From: Hardening Date: Wed, 21 May 2014 19:13:40 +0200 Subject: [PATCH 09/11] Fixed invalid declaration and missing argument --- libfreerdp/core/transport.c | 2 +- libfreerdp/core/transport.h | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/libfreerdp/core/transport.c b/libfreerdp/core/transport.c index f79d51aa5..bb455a927 100644 --- a/libfreerdp/core/transport.c +++ b/libfreerdp/core/transport.c @@ -290,7 +290,7 @@ BOOL transport_connect_tls(rdpTransport* transport) transport->frontBio = targetTls->bio; if (!transport->frontBio) { - fprintf(stderr, "%s: unable to prepend a filtering TLS bio"); + fprintf(stderr, "%s: unable to prepend a filtering TLS bio", __FUNCTION__); return FALSE; } diff --git a/libfreerdp/core/transport.h b/libfreerdp/core/transport.h index 829807405..4e9f7e5a4 100644 --- a/libfreerdp/core/transport.h +++ b/libfreerdp/core/transport.h @@ -105,7 +105,7 @@ void transport_set_gateway_enabled(rdpTransport* transport, BOOL GatewayEnabled) void transport_set_nla_mode(rdpTransport* transport, BOOL NlaMode); void transport_get_read_handles(rdpTransport* transport, HANDLE* events, DWORD* count); BOOL tranport_is_write_blocked(rdpTransport* transport); -BOOL tranport_drain_output_buffer(rdpTransport* transport); +int tranport_drain_output_buffer(rdpTransport* transport); wStream* transport_receive_pool_take(rdpTransport* transport); int transport_receive_pool_return(rdpTransport* transport, wStream* pdu); From 2b1a27b9b64020bba5fcf6b0b61c3153b42c1109 Mon Sep 17 00:00:00 2001 From: Hardening Date: Wed, 21 May 2014 22:18:38 +0200 Subject: [PATCH 10/11] Add .gitignore files for test generated files --- libfreerdp/crypto/test/.gitignore | 1 + libfreerdp/utils/test/.gitignore | 1 + 2 files changed, 2 insertions(+) create mode 100644 libfreerdp/crypto/test/.gitignore create mode 100644 libfreerdp/utils/test/.gitignore diff --git a/libfreerdp/crypto/test/.gitignore b/libfreerdp/crypto/test/.gitignore new file mode 100644 index 000000000..d425a5a86 --- /dev/null +++ b/libfreerdp/crypto/test/.gitignore @@ -0,0 +1 @@ +TestFreeRDPCrypto.c diff --git a/libfreerdp/utils/test/.gitignore b/libfreerdp/utils/test/.gitignore new file mode 100644 index 000000000..0e7faad57 --- /dev/null +++ b/libfreerdp/utils/test/.gitignore @@ -0,0 +1 @@ +TestFreeRDPutils.c From 3200baca4b12395571adaa3b29c48d020a798474 Mon Sep 17 00:00:00 2001 From: Hardening Date: Wed, 21 May 2014 22:20:38 +0200 Subject: [PATCH 11/11] Correctly export ringbuffer function and fix a warning --- include/freerdp/utils/ringbuffer.h | 28 +++++++++++++++++++--------- libfreerdp/core/tcp.c | 2 +- 2 files changed, 20 insertions(+), 10 deletions(-) diff --git a/include/freerdp/utils/ringbuffer.h b/include/freerdp/utils/ringbuffer.h index 1ec00c514..b22b30786 100644 --- a/include/freerdp/utils/ringbuffer.h +++ b/include/freerdp/utils/ringbuffer.h @@ -25,6 +25,8 @@ #define __RINGBUFFER_H___ #include +#include + /** @brief ring buffer meta data */ struct _RingBuffer { @@ -45,28 +47,32 @@ struct _DataChunk { }; typedef struct _DataChunk DataChunk; +#ifdef __cplusplus +extern "C" { +#endif + /** initialise a ringbuffer * @param initialSize the initial capacity of the ringBuffer * @return if the initialisation was successful */ -BOOL ringbuffer_init(RingBuffer *rb, size_t initialSize); +FREERDP_API BOOL ringbuffer_init(RingBuffer *rb, size_t initialSize); /** destroys internal data used by this ringbuffer * @param ringbuffer */ -void ringbuffer_destroy(RingBuffer *ringbuffer); +FREERDP_API void ringbuffer_destroy(RingBuffer *ringbuffer); /** computes the space used in this ringbuffer * @param ringbuffer * @return the number of bytes stored in that ringbuffer */ -size_t ringbuffer_used(const RingBuffer *ringbuffer); +FREERDP_API size_t ringbuffer_used(const RingBuffer *ringbuffer); /** returns the capacity of the ring buffer * @param ringbuffer * @return the capacity of this ring buffer */ -size_t ringbuffer_capacity(const RingBuffer *ringbuffer); +FREERDP_API size_t ringbuffer_capacity(const RingBuffer *ringbuffer); /** writes some bytes in the ringbuffer, if the data doesn't fit, the ringbuffer * is resized automatically @@ -76,7 +82,7 @@ size_t ringbuffer_capacity(const RingBuffer *ringbuffer); * @param sz the size of the data to add * @return if the operation was successful, it could fail in case of OOM during realloc() */ -BOOL ringbuffer_write(RingBuffer *rb, const BYTE *ptr, size_t sz); +FREERDP_API BOOL ringbuffer_write(RingBuffer *rb, const BYTE *ptr, size_t sz); /** ensures that we have sz bytes available at the write head, and return a pointer @@ -86,7 +92,7 @@ BOOL ringbuffer_write(RingBuffer *rb, const BYTE *ptr, size_t sz); * @param sz the size to ensure * @return a pointer on the write head, or NULL in case of OOM */ -BYTE *ringbuffer_ensure_linear_write(RingBuffer *rb, size_t sz); +FREERDP_API BYTE *ringbuffer_ensure_linear_write(RingBuffer *rb, size_t sz); /** move ahead the write head in case some byte were written directly by using * a pointer retrieved via ringbuffer_ensure_linear_write(). This function is @@ -97,7 +103,7 @@ BYTE *ringbuffer_ensure_linear_write(RingBuffer *rb, size_t sz); * @param sz the number of bytes that have been written * @return if the operation was successful, FALSE is sz is too big */ -BOOL ringbuffer_commit_written_bytes(RingBuffer *rb, size_t sz); +FREERDP_API BOOL ringbuffer_commit_written_bytes(RingBuffer *rb, size_t sz); /** peeks the buffer chunks for sz bytes and returns how many chunks are filled. @@ -108,7 +114,7 @@ BOOL ringbuffer_commit_written_bytes(RingBuffer *rb, size_t sz); * @param sz the requested size * @return the number of chunks used for reading sz bytes */ -int ringbuffer_peek(const RingBuffer *rb, DataChunk chunks[2], size_t sz); +FREERDP_API int ringbuffer_peek(const RingBuffer *rb, DataChunk chunks[2], size_t sz); /** move ahead the read head in case some byte were read using ringbuffer_peek() * This function is used to commit the bytes that were effectively consumed. @@ -116,7 +122,11 @@ int ringbuffer_peek(const RingBuffer *rb, DataChunk chunks[2], size_t sz); * @param rb the ring buffer * @param sz the */ -void ringbuffer_commit_read_bytes(RingBuffer *rb, size_t sz); +FREERDP_API void ringbuffer_commit_read_bytes(RingBuffer *rb, size_t sz); +#ifdef __cplusplus +} +#endif + #endif /* __RINGBUFFER_H___ */ diff --git a/libfreerdp/core/tcp.c b/libfreerdp/core/tcp.c index 6676382fc..ee9e5099f 100644 --- a/libfreerdp/core/tcp.c +++ b/libfreerdp/core/tcp.c @@ -85,7 +85,7 @@ static int transport_bio_buffered_write(BIO* bio, const char* buf, int num) /* we directly append extra bytes in the xmit buffer, this could be prevented * but for now it makes the code more simple. */ - if (buf && num && !ringbuffer_write(&tcp->xmitBuffer, buf, num)) + if (buf && num && !ringbuffer_write(&tcp->xmitBuffer, (const BYTE *)buf, num)) { fprintf(stderr, "%s: an error occured when writing(toWrite=%d)\n", __FUNCTION__, num); return -1;