Fixed calling of dump functions, updated API

This commit is contained in:
Armin Novak 2014-08-18 17:22:43 +02:00
parent 28ece6bb46
commit f8eae11bf3
6 changed files with 559 additions and 732 deletions

View File

@ -2871,70 +2871,70 @@ static unsigned long next = 1;
static int simple_rand(void) static int simple_rand(void)
{ {
next = next * 1103515245 + 12345; next = next * 1103515245 + 12345;
return ((unsigned int) (next / 65536) % 32768); return ((unsigned int)(next / 65536) % 32768);
} }
static void fill_bitmap_alpha_channel(BYTE* data, int width, int height, BYTE value) static void fill_bitmap_alpha_channel(BYTE *data, int width, int height, BYTE value)
{ {
int i, j; int i, j;
UINT32* pixel; UINT32 *pixel;
for (i = 0; i < height; i++) for (i = 0; i < height; i++)
{ {
for (j = 0; j < width; j++) for (j = 0; j < width; j++)
{ {
pixel = (UINT32*) &data[((i * width) + j) * 4]; pixel = (UINT32 *) &data[((i * width) + j) * 4];
*pixel = ((*pixel & 0x00FFFFFF) | (value << 24)); *pixel = ((*pixel & 0x00FFFFFF) | (value << 24));
} }
} }
} }
void fill_bitmap_red_channel(BYTE* data, int width, int height, BYTE value) void fill_bitmap_red_channel(BYTE *data, int width, int height, BYTE value)
{ {
int i, j; int i, j;
UINT32* pixel; UINT32 *pixel;
for (i = 0; i < height; i++) for (i = 0; i < height; i++)
{ {
for (j = 0; j < width; j++) for (j = 0; j < width; j++)
{ {
pixel = (UINT32*) &data[((i * width) + j) * 4]; pixel = (UINT32 *) &data[((i * width) + j) * 4];
*pixel = ((*pixel & 0xFF00FFFF) | (value << 16)); *pixel = ((*pixel & 0xFF00FFFF) | (value << 16));
} }
} }
} }
void fill_bitmap_green_channel(BYTE* data, int width, int height, BYTE value) void fill_bitmap_green_channel(BYTE *data, int width, int height, BYTE value)
{ {
int i, j; int i, j;
UINT32* pixel; UINT32 *pixel;
for (i = 0; i < height; i++) for (i = 0; i < height; i++)
{ {
for (j = 0; j < width; j++) for (j = 0; j < width; j++)
{ {
pixel = (UINT32*) &data[((i * width) + j) * 4]; pixel = (UINT32 *) &data[((i * width) + j) * 4];
*pixel = ((*pixel & 0xFFFF00FF) | (value << 8)); *pixel = ((*pixel & 0xFFFF00FF) | (value << 8));
} }
} }
} }
void fill_bitmap_blue_channel(BYTE* data, int width, int height, BYTE value) void fill_bitmap_blue_channel(BYTE *data, int width, int height, BYTE value)
{ {
int i, j; int i, j;
UINT32* pixel; UINT32 *pixel;
for (i = 0; i < height; i++) for (i = 0; i < height; i++)
{ {
for (j = 0; j < width; j++) for (j = 0; j < width; j++)
{ {
pixel = (UINT32*) &data[((i * width) + j) * 4]; pixel = (UINT32 *) &data[((i * width) + j) * 4];
*pixel = ((*pixel & 0xFFFFFF00) | (value)); *pixel = ((*pixel & 0xFFFFFF00) | (value));
} }
} }
} }
void dump_color_channel(BYTE* data, int width, int height) void dump_color_channel(BYTE *data, int width, int height)
{ {
int i, j; int i, j;
@ -2943,7 +2943,7 @@ void dump_color_channel(BYTE* data, int width, int height)
for (j = 0; j < width; j++) for (j = 0; j < width; j++)
{ {
printf("%02X%s", *data, printf("%02X%s", *data,
((j + 1) == width)? "\n" : " "); ((j + 1) == width)? "\n" : " ");
data += 4; data += 4;
} }
} }
@ -2953,35 +2953,28 @@ int test_individual_planes_encoding_rle()
{ {
int width; int width;
int height; int height;
BYTE* pOutput; BYTE *pOutput;
int planeSize; int planeSize;
int compareSize; int compareSize;
int dstSizes[4]; int dstSizes[4];
int availableSize; int availableSize;
DWORD planarFlags; DWORD planarFlags;
BITMAP_PLANAR_CONTEXT* planar; BITMAP_PLANAR_CONTEXT *planar;
planarFlags = PLANAR_FORMAT_HEADER_NA; planarFlags = PLANAR_FORMAT_HEADER_NA;
planarFlags |= PLANAR_FORMAT_HEADER_RLE; planarFlags |= PLANAR_FORMAT_HEADER_RLE;
width = 64; width = 64;
height = 64; height = 64;
planeSize = width * height; planeSize = width * height;
planar = freerdp_bitmap_planar_context_new(planarFlags, width, height); planar = freerdp_bitmap_planar_context_new(planarFlags, width, height);
CopyMemory(planar->planes[1], (BYTE *) TEST_64X64_RED_PLANE, planeSize); /* Red */
CopyMemory(planar->planes[1], (BYTE*) TEST_64X64_RED_PLANE, planeSize); /* Red */ CopyMemory(planar->planes[2], (BYTE *) TEST_64X64_GREEN_PLANE, planeSize); /* Green */
CopyMemory(planar->planes[2], (BYTE*) TEST_64X64_GREEN_PLANE, planeSize); /* Green */ CopyMemory(planar->planes[3], (BYTE *) TEST_64X64_BLUE_PLANE, planeSize); /* Blue */
CopyMemory(planar->planes[3], (BYTE*) TEST_64X64_BLUE_PLANE, planeSize); /* Blue */
freerdp_bitmap_planar_delta_encode_plane(planar->planes[1], width, height, planar->deltaPlanes[1]); /* Red */ freerdp_bitmap_planar_delta_encode_plane(planar->planes[1], width, height, planar->deltaPlanes[1]); /* Red */
freerdp_bitmap_planar_delta_encode_plane(planar->planes[2], width, height, planar->deltaPlanes[2]); /* Green */ freerdp_bitmap_planar_delta_encode_plane(planar->planes[2], width, height, planar->deltaPlanes[2]); /* Green */
freerdp_bitmap_planar_delta_encode_plane(planar->planes[3], width, height, planar->deltaPlanes[3]); /* Blue */ freerdp_bitmap_planar_delta_encode_plane(planar->planes[3], width, height, planar->deltaPlanes[3]); /* Blue */
pOutput = planar->rlePlanesBuffer; pOutput = planar->rlePlanesBuffer;
availableSize = planeSize * 3; availableSize = planeSize * 3;
/* Red */ /* Red */
dstSizes[1] = availableSize; dstSizes[1] = availableSize;
if (!freerdp_bitmap_planar_compress_plane_rle(planar->deltaPlanes[1], width, height, pOutput, &dstSizes[1])) if (!freerdp_bitmap_planar_compress_plane_rle(planar->deltaPlanes[1], width, height, pOutput, &dstSizes[1]))
@ -2997,27 +2990,23 @@ int test_individual_planes_encoding_rle()
if (dstSizes[1] != sizeof(TEST_64X64_RED_PLANE_RLE)) if (dstSizes[1] != sizeof(TEST_64X64_RED_PLANE_RLE))
{ {
printf("RedPlaneRle unexpected size: actual: %d, expected: %d\n", printf("RedPlaneRle unexpected size: actual: %d, expected: %d\n",
dstSizes[1], (int) sizeof(TEST_64X64_RED_PLANE_RLE)); dstSizes[1], (int) sizeof(TEST_64X64_RED_PLANE_RLE));
//return -1; //return -1;
} }
compareSize = (dstSizes[1] > sizeof(TEST_64X64_RED_PLANE_RLE)) ? sizeof(TEST_64X64_RED_PLANE_RLE) : dstSizes[1]; compareSize = (dstSizes[1] > sizeof(TEST_64X64_RED_PLANE_RLE)) ? sizeof(TEST_64X64_RED_PLANE_RLE) : dstSizes[1];
if (memcmp(planar->rlePlanes[1], (BYTE*) TEST_64X64_RED_PLANE_RLE, compareSize) != 0) if (memcmp(planar->rlePlanes[1], (BYTE *) TEST_64X64_RED_PLANE_RLE, compareSize) != 0)
{ {
printf("RedPlaneRle doesn't match expected output\n"); printf("RedPlaneRle doesn't match expected output\n");
printf("RedPlaneRle Expected (%d):\n", (int) sizeof(TEST_64X64_RED_PLANE_RLE)); printf("RedPlaneRle Expected (%d):\n", (int) sizeof(TEST_64X64_RED_PLANE_RLE));
//winpr_HexDump((BYTE*) TEST_64X64_RED_PLANE_RLE, sizeof(TEST_64X64_RED_PLANE_RLE)); //winpr_HexDump("codec.test", WLOG_DEBUG, (BYTE*) TEST_64X64_RED_PLANE_RLE, sizeof(TEST_64X64_RED_PLANE_RLE));
printf("RedPlaneRle Actual (%d):\n", dstSizes[1]); printf("RedPlaneRle Actual (%d):\n", dstSizes[1]);
//winpr_HexDump(planar->rlePlanes[1], dstSizes[1]); //winpr_HexDump("codec.test", WLOG_DEBUG, planar->rlePlanes[1], dstSizes[1]);
return -1; return -1;
} }
/* Green */ /* Green */
dstSizes[2] = availableSize; dstSizes[2] = availableSize;
if (!freerdp_bitmap_planar_compress_plane_rle(planar->deltaPlanes[2], width, height, pOutput, &dstSizes[2])) if (!freerdp_bitmap_planar_compress_plane_rle(planar->deltaPlanes[2], width, height, pOutput, &dstSizes[2]))
@ -3033,27 +3022,23 @@ int test_individual_planes_encoding_rle()
if (dstSizes[2] != sizeof(TEST_64X64_GREEN_PLANE_RLE)) if (dstSizes[2] != sizeof(TEST_64X64_GREEN_PLANE_RLE))
{ {
printf("GreenPlaneRle unexpected size: actual: %d, expected: %d\n", printf("GreenPlaneRle unexpected size: actual: %d, expected: %d\n",
dstSizes[1], (int) sizeof(TEST_64X64_GREEN_PLANE_RLE)); dstSizes[1], (int) sizeof(TEST_64X64_GREEN_PLANE_RLE));
return -1; return -1;
} }
compareSize = (dstSizes[2] > sizeof(TEST_64X64_GREEN_PLANE_RLE)) ? sizeof(TEST_64X64_GREEN_PLANE_RLE) : dstSizes[2]; compareSize = (dstSizes[2] > sizeof(TEST_64X64_GREEN_PLANE_RLE)) ? sizeof(TEST_64X64_GREEN_PLANE_RLE) : dstSizes[2];
if (memcmp(planar->rlePlanes[2], (BYTE*) TEST_64X64_GREEN_PLANE_RLE, compareSize) != 0) if (memcmp(planar->rlePlanes[2], (BYTE *) TEST_64X64_GREEN_PLANE_RLE, compareSize) != 0)
{ {
printf("GreenPlaneRle doesn't match expected output\n"); printf("GreenPlaneRle doesn't match expected output\n");
printf("GreenPlaneRle Expected (%d):\n", (int) sizeof(TEST_64X64_GREEN_PLANE_RLE)); printf("GreenPlaneRle Expected (%d):\n", (int) sizeof(TEST_64X64_GREEN_PLANE_RLE));
winpr_HexDump((BYTE*) TEST_64X64_GREEN_PLANE_RLE, (int) sizeof(TEST_64X64_GREEN_PLANE_RLE)); winpr_HexDump("codec.test", WLOG_DEBUG, (BYTE *) TEST_64X64_GREEN_PLANE_RLE, (int) sizeof(TEST_64X64_GREEN_PLANE_RLE));
printf("GreenPlaneRle Actual (%d):\n", dstSizes[2]); printf("GreenPlaneRle Actual (%d):\n", dstSizes[2]);
winpr_HexDump(planar->rlePlanes[2], dstSizes[2]); winpr_HexDump("codec.test", WLOG_DEBUG, planar->rlePlanes[2], dstSizes[2]);
return -1; return -1;
} }
/* Blue */ /* Blue */
dstSizes[3] = availableSize; dstSizes[3] = availableSize;
if (!freerdp_bitmap_planar_compress_plane_rle(planar->deltaPlanes[3], width, height, pOutput, &dstSizes[3])) if (!freerdp_bitmap_planar_compress_plane_rle(planar->deltaPlanes[3], width, height, pOutput, &dstSizes[3]))
@ -3069,82 +3054,66 @@ int test_individual_planes_encoding_rle()
if (dstSizes[3] != sizeof(TEST_64X64_BLUE_PLANE_RLE)) if (dstSizes[3] != sizeof(TEST_64X64_BLUE_PLANE_RLE))
{ {
printf("BluePlaneRle unexpected size: actual: %d, expected: %d\n", printf("BluePlaneRle unexpected size: actual: %d, expected: %d\n",
dstSizes[1], (int) sizeof(TEST_64X64_BLUE_PLANE_RLE)); dstSizes[1], (int) sizeof(TEST_64X64_BLUE_PLANE_RLE));
return -1; return -1;
} }
compareSize = (dstSizes[3] > sizeof(TEST_64X64_BLUE_PLANE_RLE)) ? sizeof(TEST_64X64_BLUE_PLANE_RLE) : dstSizes[3]; compareSize = (dstSizes[3] > sizeof(TEST_64X64_BLUE_PLANE_RLE)) ? sizeof(TEST_64X64_BLUE_PLANE_RLE) : dstSizes[3];
if (memcmp(planar->rlePlanes[3], (BYTE*) TEST_64X64_BLUE_PLANE_RLE, compareSize) != 0) if (memcmp(planar->rlePlanes[3], (BYTE *) TEST_64X64_BLUE_PLANE_RLE, compareSize) != 0)
{ {
printf("BluePlaneRle doesn't match expected output\n"); printf("BluePlaneRle doesn't match expected output\n");
printf("BluePlaneRle Expected (%d):\n", (int) sizeof(TEST_64X64_BLUE_PLANE_RLE)); printf("BluePlaneRle Expected (%d):\n", (int) sizeof(TEST_64X64_BLUE_PLANE_RLE));
winpr_HexDump((BYTE*) TEST_64X64_BLUE_PLANE_RLE, (int) sizeof(TEST_64X64_BLUE_PLANE_RLE)); winpr_HexDump("codec.test", WLOG_DEBUG, (BYTE *) TEST_64X64_BLUE_PLANE_RLE, (int) sizeof(TEST_64X64_BLUE_PLANE_RLE));
printf("BluePlaneRle Actual (%d):\n", dstSizes[3]); printf("BluePlaneRle Actual (%d):\n", dstSizes[3]);
winpr_HexDump(planar->rlePlanes[3], dstSizes[3]); winpr_HexDump("codec.test", WLOG_DEBUG, planar->rlePlanes[3], dstSizes[3]);
return -1; return -1;
} }
freerdp_bitmap_planar_context_free(planar); freerdp_bitmap_planar_context_free(planar);
return 0; return 0;
} }
int TestFreeRDPCodecPlanar(int argc, char* argv[]) int TestFreeRDPCodecPlanar(int argc, char *argv[])
{ {
int i, j; int i, j;
int dstSize; int dstSize;
UINT32 format; UINT32 format;
HCLRCONV clrconv; HCLRCONV clrconv;
DWORD planarFlags; DWORD planarFlags;
BYTE* srcBitmap32; BYTE *srcBitmap32;
BYTE* srcBitmap16; BYTE *srcBitmap16;
int width, height; int width, height;
BYTE* blackBitmap; BYTE *blackBitmap;
BYTE* whiteBitmap; BYTE *whiteBitmap;
BYTE* randomBitmap; BYTE *randomBitmap;
BYTE* compressedBitmap; BYTE *compressedBitmap;
BYTE* decompressedBitmap; BYTE *decompressedBitmap;
BITMAP_PLANAR_CONTEXT* planar; BITMAP_PLANAR_CONTEXT *planar;
planarFlags = PLANAR_FORMAT_HEADER_NA; planarFlags = PLANAR_FORMAT_HEADER_NA;
planarFlags |= PLANAR_FORMAT_HEADER_RLE; planarFlags |= PLANAR_FORMAT_HEADER_RLE;
planar = freerdp_bitmap_planar_context_new(planarFlags, 64, 64); planar = freerdp_bitmap_planar_context_new(planarFlags, 64, 64);
clrconv = freerdp_clrconv_new(0); clrconv = freerdp_clrconv_new(0);
srcBitmap16 = (BYTE*) TEST_RLE_UNCOMPRESSED_BITMAP_16BPP; srcBitmap16 = (BYTE *) TEST_RLE_UNCOMPRESSED_BITMAP_16BPP;
srcBitmap32 = freerdp_image_convert(srcBitmap16, NULL, 32, 32, 16, 32, clrconv); srcBitmap32 = freerdp_image_convert(srcBitmap16, NULL, 32, 32, 16, 32, clrconv);
format = PIXEL_FORMAT_ARGB32; format = PIXEL_FORMAT_ARGB32;
#if 0 #if 0
freerdp_bitmap_compress_planar(planar, srcBitmap32, format, 32, 32, 32 * 4, NULL, &dstSize); freerdp_bitmap_compress_planar(planar, srcBitmap32, format, 32, 32, 32 * 4, NULL, &dstSize);
freerdp_bitmap_planar_compress_plane_rle((BYTE *) TEST_RLE_SCANLINE_UNCOMPRESSED, 12, 1, NULL, &dstSize);
freerdp_bitmap_planar_compress_plane_rle((BYTE*) TEST_RLE_SCANLINE_UNCOMPRESSED, 12, 1, NULL, &dstSize); freerdp_bitmap_planar_delta_encode_plane((BYTE *) TEST_RDP6_SCANLINES_ABSOLUTE, 6, 3, NULL);
freerdp_bitmap_planar_compress_plane_rle((BYTE *) TEST_RDP6_SCANLINES_DELTA_2C_ENCODED_UNSIGNED, 6, 3, NULL, &dstSize);
freerdp_bitmap_planar_delta_encode_plane((BYTE*) TEST_RDP6_SCANLINES_ABSOLUTE, 6, 3, NULL);
freerdp_bitmap_planar_compress_plane_rle((BYTE*) TEST_RDP6_SCANLINES_DELTA_2C_ENCODED_UNSIGNED, 6, 3, NULL, &dstSize);
#endif #endif
#if 1 #if 1
for (i = 4; i < 64; i += 4) for (i = 4; i < 64; i += 4)
{ {
width = i; width = i;
height = i; height = i;
whiteBitmap = (BYTE *) malloc(width * height * 4);
whiteBitmap = (BYTE*) malloc(width * height * 4);
FillMemory(whiteBitmap, width * height * 4, 0xFF); FillMemory(whiteBitmap, width * height * 4, 0xFF);
fill_bitmap_alpha_channel(whiteBitmap, width, height, 0x00); fill_bitmap_alpha_channel(whiteBitmap, width, height, 0x00);
compressedBitmap = freerdp_bitmap_compress_planar(planar, whiteBitmap, format, width, height, width * 4, NULL, &dstSize); compressedBitmap = freerdp_bitmap_compress_planar(planar, whiteBitmap, format, width, height, width * 4, NULL, &dstSize);
decompressedBitmap = (BYTE *) malloc(width * height * 4);
decompressedBitmap = (BYTE*) malloc(width * height * 4);
ZeroMemory(decompressedBitmap, width * height * 4); ZeroMemory(decompressedBitmap, width * height * 4);
if (!bitmap_decompress(compressedBitmap, decompressedBitmap, width, height, dstSize, 32, 32)) if (!bitmap_decompress(compressedBitmap, decompressedBitmap, width, height, dstSize, 32, 32))
@ -3160,11 +3129,9 @@ int TestFreeRDPCodecPlanar(int argc, char* argv[])
if (memcmp(decompressedBitmap, whiteBitmap, width * height * 4) != 0) if (memcmp(decompressedBitmap, whiteBitmap, width * height * 4) != 0)
{ {
printf("white bitmap\n"); printf("white bitmap\n");
winpr_HexDump(whiteBitmap, width * height * 4); winpr_HexDump("codec.test", WLOG_DEBUG, whiteBitmap, width * height * 4);
printf("decompressed bitmap\n"); printf("decompressed bitmap\n");
winpr_HexDump(decompressedBitmap, width * height * 4); winpr_HexDump("codec.test", WLOG_DEBUG, decompressedBitmap, width * height * 4);
printf("error decompressed white bitmap corrupted: width: %d height: %d\n", width, height); printf("error decompressed white bitmap corrupted: width: %d height: %d\n", width, height);
return -1; return -1;
} }
@ -3177,14 +3144,11 @@ int TestFreeRDPCodecPlanar(int argc, char* argv[])
{ {
width = i; width = i;
height = i; height = i;
blackBitmap = (BYTE *) malloc(width * height * 4);
blackBitmap = (BYTE*) malloc(width * height * 4);
ZeroMemory(blackBitmap, width * height * 4); ZeroMemory(blackBitmap, width * height * 4);
fill_bitmap_alpha_channel(blackBitmap, width, height, 0x00); fill_bitmap_alpha_channel(blackBitmap, width, height, 0x00);
compressedBitmap = freerdp_bitmap_compress_planar(planar, blackBitmap, format, width, height, width * 4, NULL, &dstSize); compressedBitmap = freerdp_bitmap_compress_planar(planar, blackBitmap, format, width, height, width * 4, NULL, &dstSize);
decompressedBitmap = (BYTE *) malloc(width * height * 4);
decompressedBitmap = (BYTE*) malloc(width * height * 4);
ZeroMemory(decompressedBitmap, width * height * 4); ZeroMemory(decompressedBitmap, width * height * 4);
if (!bitmap_decompress(compressedBitmap, decompressedBitmap, width, height, dstSize, 32, 32)) if (!bitmap_decompress(compressedBitmap, decompressedBitmap, width, height, dstSize, 32, 32))
@ -3200,11 +3164,9 @@ int TestFreeRDPCodecPlanar(int argc, char* argv[])
if (memcmp(decompressedBitmap, blackBitmap, width * height * 4) != 0) if (memcmp(decompressedBitmap, blackBitmap, width * height * 4) != 0)
{ {
printf("black bitmap\n"); printf("black bitmap\n");
winpr_HexDump(blackBitmap, width * height * 4); winpr_HexDump("codec.test", WLOG_DEBUG, blackBitmap, width * height * 4);
printf("decompressed bitmap\n"); printf("decompressed bitmap\n");
winpr_HexDump(decompressedBitmap, width * height * 4); winpr_HexDump("codec.test", WLOG_DEBUG, decompressedBitmap, width * height * 4);
printf("error decompressed black bitmap corrupted: width: %d height: %d\n", width, height); printf("error decompressed black bitmap corrupted: width: %d height: %d\n", width, height);
return -1; return -1;
} }
@ -3217,19 +3179,16 @@ int TestFreeRDPCodecPlanar(int argc, char* argv[])
{ {
width = i; width = i;
height = i; height = i;
randomBitmap = (BYTE *) malloc(width * height * 4);
randomBitmap = (BYTE*) malloc(width * height * 4);
for (j = 0; j < width * height * 4; j++) for (j = 0; j < width * height * 4; j++)
{ {
randomBitmap[j] = (BYTE) (simple_rand() % 256); randomBitmap[j] = (BYTE)(simple_rand() % 256);
} }
fill_bitmap_alpha_channel(randomBitmap, width, height, 0x00); fill_bitmap_alpha_channel(randomBitmap, width, height, 0x00);
compressedBitmap = freerdp_bitmap_compress_planar(planar, randomBitmap, format, width, height, width * 4, NULL, &dstSize); compressedBitmap = freerdp_bitmap_compress_planar(planar, randomBitmap, format, width, height, width * 4, NULL, &dstSize);
decompressedBitmap = (BYTE *) malloc(width * height * 4);
decompressedBitmap = (BYTE*) malloc(width * height * 4);
ZeroMemory(decompressedBitmap, width * height * 4); ZeroMemory(decompressedBitmap, width * height * 4);
if (!bitmap_decompress(compressedBitmap, decompressedBitmap, width, height, dstSize, 32, 32)) if (!bitmap_decompress(compressedBitmap, decompressedBitmap, width, height, dstSize, 32, 32))
@ -3245,11 +3204,9 @@ int TestFreeRDPCodecPlanar(int argc, char* argv[])
if (memcmp(decompressedBitmap, randomBitmap, width * height * 4) != 0) if (memcmp(decompressedBitmap, randomBitmap, width * height * 4) != 0)
{ {
printf("random bitmap\n"); printf("random bitmap\n");
winpr_HexDump(randomBitmap, width * height * 4); winpr_HexDump("codec.test", WLOG_DEBUG, randomBitmap, width * height * 4);
printf("decompressed bitmap\n"); printf("decompressed bitmap\n");
winpr_HexDump(decompressedBitmap, width * height * 4); winpr_HexDump("codec.test", WLOG_DEBUG, decompressedBitmap, width * height * 4);
printf("error decompressed random bitmap corrupted: width: %d height: %d\n", width, height); printf("error decompressed random bitmap corrupted: width: %d height: %d\n", width, height);
return -1; return -1;
} }
@ -3259,14 +3216,11 @@ int TestFreeRDPCodecPlanar(int argc, char* argv[])
} }
/* Experimental Case 01 */ /* Experimental Case 01 */
width = 64; width = 64;
height = 64; height = 64;
compressedBitmap = freerdp_bitmap_compress_planar(planar, (BYTE *) TEST_RLE_BITMAP_EXPERIMENTAL_01,
compressedBitmap = freerdp_bitmap_compress_planar(planar, (BYTE*) TEST_RLE_BITMAP_EXPERIMENTAL_01, format, width, height, width * 4, NULL, &dstSize);
format, width, height, width * 4, NULL, &dstSize); decompressedBitmap = (BYTE *) malloc(width * height * 4);
decompressedBitmap = (BYTE*) malloc(width * height * 4);
ZeroMemory(decompressedBitmap, width * height * 4); ZeroMemory(decompressedBitmap, width * height * 4);
if (!bitmap_decompress(compressedBitmap, decompressedBitmap, width, height, dstSize, 32, 32)) if (!bitmap_decompress(compressedBitmap, decompressedBitmap, width, height, dstSize, 32, 32))
@ -3280,34 +3234,28 @@ int TestFreeRDPCodecPlanar(int argc, char* argv[])
} }
fill_bitmap_alpha_channel(decompressedBitmap, width, height, 0xFF); fill_bitmap_alpha_channel(decompressedBitmap, width, height, 0xFF);
fill_bitmap_alpha_channel((BYTE*) TEST_RLE_BITMAP_EXPERIMENTAL_01, width, height, 0xFF); fill_bitmap_alpha_channel((BYTE *) TEST_RLE_BITMAP_EXPERIMENTAL_01, width, height, 0xFF);
if (memcmp(decompressedBitmap, (BYTE*) TEST_RLE_BITMAP_EXPERIMENTAL_01, width * height * 4) != 0) if (memcmp(decompressedBitmap, (BYTE *) TEST_RLE_BITMAP_EXPERIMENTAL_01, width * height * 4) != 0)
{ {
#if 0 #if 0
printf("experimental bitmap 01\n"); printf("experimental bitmap 01\n");
winpr_HexDump((BYTE*) TEST_RLE_BITMAP_EXPERIMENTAL_01, width * height * 4); winpr_HexDump("codec.test", WLOG_DEBUG, (BYTE *) TEST_RLE_BITMAP_EXPERIMENTAL_01, width * height * 4);
printf("decompressed bitmap\n"); printf("decompressed bitmap\n");
winpr_HexDump(decompressedBitmap, width * height * 4); winpr_HexDump("codec.test", WLOG_DEBUG, decompressedBitmap, width * height * 4);
#endif #endif
printf("error: decompressed experimental bitmap 01 is corrupted\n"); printf("error: decompressed experimental bitmap 01 is corrupted\n");
return -1; return -1;
} }
free(compressedBitmap); free(compressedBitmap);
free(decompressedBitmap); free(decompressedBitmap);
/* Experimental Case 02 */ /* Experimental Case 02 */
width = 64; width = 64;
height = 64; height = 64;
compressedBitmap = freerdp_bitmap_compress_planar(planar, (BYTE *) TEST_RLE_BITMAP_EXPERIMENTAL_02,
compressedBitmap = freerdp_bitmap_compress_planar(planar, (BYTE*) TEST_RLE_BITMAP_EXPERIMENTAL_02, format, width, height, width * 4, NULL, &dstSize);
format, width, height, width * 4, NULL, &dstSize); decompressedBitmap = (BYTE *) malloc(width * height * 4);
decompressedBitmap = (BYTE*) malloc(width * height * 4);
ZeroMemory(decompressedBitmap, width * height * 4); ZeroMemory(decompressedBitmap, width * height * 4);
if (!bitmap_decompress(compressedBitmap, decompressedBitmap, width, height, dstSize, 32, 32)) if (!bitmap_decompress(compressedBitmap, decompressedBitmap, width, height, dstSize, 32, 32))
@ -3321,18 +3269,16 @@ int TestFreeRDPCodecPlanar(int argc, char* argv[])
} }
fill_bitmap_alpha_channel(decompressedBitmap, width, height, 0xFF); fill_bitmap_alpha_channel(decompressedBitmap, width, height, 0xFF);
fill_bitmap_alpha_channel((BYTE*) TEST_RLE_BITMAP_EXPERIMENTAL_02, width, height, 0xFF); fill_bitmap_alpha_channel((BYTE *) TEST_RLE_BITMAP_EXPERIMENTAL_02, width, height, 0xFF);
if (memcmp(decompressedBitmap, (BYTE*) TEST_RLE_BITMAP_EXPERIMENTAL_02, width * height * 4) != 0) if (memcmp(decompressedBitmap, (BYTE *) TEST_RLE_BITMAP_EXPERIMENTAL_02, width * height * 4) != 0)
{ {
#if 0 #if 0
printf("experimental bitmap 02\n"); printf("experimental bitmap 02\n");
winpr_HexDump((BYTE*) TEST_RLE_BITMAP_EXPERIMENTAL_02, width * height * 4); winpr_HexDump("codec.test", WLOG_DEBUG, (BYTE *) TEST_RLE_BITMAP_EXPERIMENTAL_02, width * height * 4);
printf("decompressed bitmap\n"); printf("decompressed bitmap\n");
winpr_HexDump(decompressedBitmap, width * height * 4); winpr_HexDump("codec.test", WLOG_DEBUG, decompressedBitmap, width * height * 4);
#endif #endif
printf("error: decompressed experimental bitmap 02 is corrupted\n"); printf("error: decompressed experimental bitmap 02 is corrupted\n");
return -1; return -1;
} }
@ -3347,14 +3293,11 @@ int TestFreeRDPCodecPlanar(int argc, char* argv[])
} }
/* Experimental Case 03 */ /* Experimental Case 03 */
width = 64; width = 64;
height = 64; height = 64;
compressedBitmap = freerdp_bitmap_compress_planar(planar, (BYTE *) TEST_RLE_BITMAP_EXPERIMENTAL_03,
compressedBitmap = freerdp_bitmap_compress_planar(planar, (BYTE*) TEST_RLE_BITMAP_EXPERIMENTAL_03, format, width, height, width * 4, NULL, &dstSize);
format, width, height, width * 4, NULL, &dstSize); decompressedBitmap = (BYTE *) malloc(width * height * 4);
decompressedBitmap = (BYTE*) malloc(width * height * 4);
ZeroMemory(decompressedBitmap, width * height * 4); ZeroMemory(decompressedBitmap, width * height * 4);
if (!bitmap_decompress(compressedBitmap, decompressedBitmap, width, height, dstSize, 32, 32)) if (!bitmap_decompress(compressedBitmap, decompressedBitmap, width, height, dstSize, 32, 32))
@ -3368,29 +3311,24 @@ int TestFreeRDPCodecPlanar(int argc, char* argv[])
} }
fill_bitmap_alpha_channel(decompressedBitmap, width, height, 0xFF); fill_bitmap_alpha_channel(decompressedBitmap, width, height, 0xFF);
fill_bitmap_alpha_channel((BYTE*) TEST_RLE_BITMAP_EXPERIMENTAL_03, width, height, 0xFF); fill_bitmap_alpha_channel((BYTE *) TEST_RLE_BITMAP_EXPERIMENTAL_03, width, height, 0xFF);
if (memcmp(decompressedBitmap, (BYTE*) TEST_RLE_BITMAP_EXPERIMENTAL_03, width * height * 4) != 0) if (memcmp(decompressedBitmap, (BYTE *) TEST_RLE_BITMAP_EXPERIMENTAL_03, width * height * 4) != 0)
{ {
#if 0 #if 0
printf("experimental bitmap 03\n"); printf("experimental bitmap 03\n");
winpr_HexDump((BYTE*) TEST_RLE_BITMAP_EXPERIMENTAL_03, width * height * 4); winpr_HexDump("codec.test", WLOG_DEBUG, (BYTE *) TEST_RLE_BITMAP_EXPERIMENTAL_03, width * height * 4);
printf("decompressed bitmap\n"); printf("decompressed bitmap\n");
winpr_HexDump(decompressedBitmap, width * height * 4); winpr_HexDump("codec.test", WLOG_DEBUG, decompressedBitmap, width * height * 4);
#endif #endif
printf("error: decompressed experimental bitmap 03 is corrupted\n"); printf("error: decompressed experimental bitmap 03 is corrupted\n");
return -1; return -1;
} }
free(compressedBitmap); free(compressedBitmap);
free(decompressedBitmap); free(decompressedBitmap);
freerdp_clrconv_free(clrconv); freerdp_clrconv_free(clrconv);
_aligned_free(srcBitmap32); _aligned_free(srcBitmap32);
freerdp_bitmap_planar_context_free(planar); freerdp_bitmap_planar_context_free(planar);
return 0; return 0;
} }

View File

@ -23,9 +23,11 @@
#include "bulk.h" #include "bulk.h"
#define TAG "com.freerdp.core"
//#define WITH_BULK_DEBUG 1 //#define WITH_BULK_DEBUG 1
const char* bulk_get_compression_flags_string(UINT32 flags) const char *bulk_get_compression_flags_string(UINT32 flags)
{ {
flags &= BULK_COMPRESSION_FLAGS_MASK; flags &= BULK_COMPRESSION_FLAGS_MASK;
@ -49,38 +51,32 @@ const char* bulk_get_compression_flags_string(UINT32 flags)
return "PACKET_UNKNOWN"; return "PACKET_UNKNOWN";
} }
UINT32 bulk_compression_level(rdpBulk* bulk) UINT32 bulk_compression_level(rdpBulk *bulk)
{ {
rdpSettings* settings = bulk->context->settings; rdpSettings *settings = bulk->context->settings;
bulk->CompressionLevel = (settings->CompressionLevel >= PACKET_COMPR_TYPE_RDP61) ? bulk->CompressionLevel = (settings->CompressionLevel >= PACKET_COMPR_TYPE_RDP61) ?
PACKET_COMPR_TYPE_RDP61 : settings->CompressionLevel; PACKET_COMPR_TYPE_RDP61 : settings->CompressionLevel;
return bulk->CompressionLevel; return bulk->CompressionLevel;
} }
UINT32 bulk_compression_max_size(rdpBulk* bulk) UINT32 bulk_compression_max_size(rdpBulk *bulk)
{ {
bulk_compression_level(bulk); bulk_compression_level(bulk);
bulk->CompressionMaxSize = (bulk->CompressionLevel < PACKET_COMPR_TYPE_64K) ? 8192 : 65536; bulk->CompressionMaxSize = (bulk->CompressionLevel < PACKET_COMPR_TYPE_64K) ? 8192 : 65536;
return bulk->CompressionMaxSize; return bulk->CompressionMaxSize;
} }
int bulk_compress_validate(rdpBulk* bulk, BYTE* pSrcData, UINT32 SrcSize, BYTE** ppDstData, UINT32* pDstSize, UINT32* pFlags) int bulk_compress_validate(rdpBulk *bulk, BYTE *pSrcData, UINT32 SrcSize, BYTE **ppDstData, UINT32 *pDstSize, UINT32 *pFlags)
{ {
int status; int status;
BYTE* _pSrcData = NULL; BYTE *_pSrcData = NULL;
BYTE* _pDstData = NULL; BYTE *_pDstData = NULL;
UINT32 _SrcSize = 0; UINT32 _SrcSize = 0;
UINT32 _DstSize = 0; UINT32 _DstSize = 0;
UINT32 _Flags = 0; UINT32 _Flags = 0;
_pSrcData = *ppDstData; _pSrcData = *ppDstData;
_SrcSize = *pDstSize; _SrcSize = *pDstSize;
_Flags = *pFlags | bulk->CompressionLevel; _Flags = *pFlags | bulk->CompressionLevel;
status = bulk_decompress(bulk, _pSrcData, _SrcSize, &_pDstData, &_DstSize, _Flags); status = bulk_decompress(bulk, _pSrcData, _SrcSize, &_pDstData, &_DstSize, _Flags);
if (status < 0) if (status < 0)
@ -98,32 +94,27 @@ int bulk_compress_validate(rdpBulk* bulk, BYTE* pSrcData, UINT32 SrcSize, BYTE**
if (memcmp(_pDstData, pSrcData, SrcSize) != 0) if (memcmp(_pDstData, pSrcData, SrcSize) != 0)
{ {
DEBUG_MSG("compression/decompression input/output mismatch! flags: 0x%04X\n", _Flags); DEBUG_MSG("compression/decompression input/output mismatch! flags: 0x%04X\n", _Flags);
#if 1 #if 1
DEBUG_MSG("Actual:\n"); DEBUG_MSG("Actual:\n");
winpr_HexDump(_pDstData, SrcSize); winpr_HexDump(TAG, WLOG_DEBUG, _pDstData, SrcSize);
DEBUG_MSG("Expected:\n"); DEBUG_MSG("Expected:\n");
winpr_HexDump(pSrcData, SrcSize); winpr_HexDump(TAG, WLOG_DEBUG, pSrcData, SrcSize);
#endif #endif
return -1; return -1;
} }
return status; return status;
} }
int bulk_decompress(rdpBulk* bulk, BYTE* pSrcData, UINT32 SrcSize, BYTE** ppDstData, UINT32* pDstSize, UINT32 flags) int bulk_decompress(rdpBulk *bulk, BYTE *pSrcData, UINT32 SrcSize, BYTE **ppDstData, UINT32 *pDstSize, UINT32 flags)
{ {
UINT32 type; UINT32 type;
int status = -1; int status = -1;
rdpMetrics* metrics; rdpMetrics *metrics;
UINT32 CompressedBytes; UINT32 CompressedBytes;
UINT32 UncompressedBytes; UINT32 UncompressedBytes;
double CompressionRatio; double CompressionRatio;
metrics = bulk->context->metrics; metrics = bulk->context->metrics;
bulk_compression_max_size(bulk); bulk_compression_max_size(bulk);
type = flags & BULK_COMPRESSION_TYPE_MASK; type = flags & BULK_COMPRESSION_TYPE_MASK;
@ -135,20 +126,16 @@ int bulk_decompress(rdpBulk* bulk, BYTE* pSrcData, UINT32 SrcSize, BYTE** ppDstD
mppc_set_compression_level(bulk->mppcRecv, 0); mppc_set_compression_level(bulk->mppcRecv, 0);
status = mppc_decompress(bulk->mppcRecv, pSrcData, SrcSize, ppDstData, pDstSize, flags); status = mppc_decompress(bulk->mppcRecv, pSrcData, SrcSize, ppDstData, pDstSize, flags);
break; break;
case PACKET_COMPR_TYPE_64K: case PACKET_COMPR_TYPE_64K:
mppc_set_compression_level(bulk->mppcRecv, 1); mppc_set_compression_level(bulk->mppcRecv, 1);
status = mppc_decompress(bulk->mppcRecv, pSrcData, SrcSize, ppDstData, pDstSize, flags); status = mppc_decompress(bulk->mppcRecv, pSrcData, SrcSize, ppDstData, pDstSize, flags);
break; break;
case PACKET_COMPR_TYPE_RDP6: case PACKET_COMPR_TYPE_RDP6:
status = ncrush_decompress(bulk->ncrushRecv, pSrcData, SrcSize, ppDstData, pDstSize, flags); status = ncrush_decompress(bulk->ncrushRecv, pSrcData, SrcSize, ppDstData, pDstSize, flags);
break; break;
case PACKET_COMPR_TYPE_RDP61: case PACKET_COMPR_TYPE_RDP61:
status = xcrush_decompress(bulk->xcrushRecv, pSrcData, SrcSize, ppDstData, pDstSize, flags); status = xcrush_decompress(bulk->xcrushRecv, pSrcData, SrcSize, ppDstData, pDstSize, flags);
break; break;
case PACKET_COMPR_TYPE_RDP8: case PACKET_COMPR_TYPE_RDP8:
status = -1; status = -1;
break; break;
@ -165,35 +152,32 @@ int bulk_decompress(rdpBulk* bulk, BYTE* pSrcData, UINT32 SrcSize, BYTE** ppDstD
{ {
CompressedBytes = SrcSize; CompressedBytes = SrcSize;
UncompressedBytes = *pDstSize; UncompressedBytes = *pDstSize;
CompressionRatio = metrics_write_bytes(metrics, UncompressedBytes, CompressedBytes); CompressionRatio = metrics_write_bytes(metrics, UncompressedBytes, CompressedBytes);
#ifdef WITH_BULK_DEBUG #ifdef WITH_BULK_DEBUG
{ {
DEBUG_MSG("Decompress Type: %d Flags: %s (0x%04X) Compression Ratio: %f (%d / %d), Total: %f (%u / %u)\n", DEBUG_MSG("Decompress Type: %d Flags: %s (0x%04X) Compression Ratio: %f (%d / %d), Total: %f (%u / %u)\n",
type, bulk_get_compression_flags_string(flags), flags, type, bulk_get_compression_flags_string(flags), flags,
CompressionRatio, CompressedBytes, UncompressedBytes, CompressionRatio, CompressedBytes, UncompressedBytes,
metrics->TotalCompressionRatio, (UINT32) metrics->TotalCompressedBytes, metrics->TotalCompressionRatio, (UINT32) metrics->TotalCompressedBytes,
(UINT32) metrics->TotalUncompressedBytes); (UINT32) metrics->TotalUncompressedBytes);
} }
#endif #endif
} }
else else
{ {
DEBUG_WARN( "Decompression failure!\n"); DEBUG_WARN("Decompression failure!\n");
} }
return status; return status;
} }
int bulk_compress(rdpBulk* bulk, BYTE* pSrcData, UINT32 SrcSize, BYTE** ppDstData, UINT32* pDstSize, UINT32* pFlags) int bulk_compress(rdpBulk *bulk, BYTE *pSrcData, UINT32 SrcSize, BYTE **ppDstData, UINT32 *pDstSize, UINT32 *pFlags)
{ {
int status = -1; int status = -1;
rdpMetrics* metrics; rdpMetrics *metrics;
UINT32 CompressedBytes; UINT32 CompressedBytes;
UINT32 UncompressedBytes; UINT32 UncompressedBytes;
double CompressionRatio; double CompressionRatio;
metrics = bulk->context->metrics; metrics = bulk->context->metrics;
if ((SrcSize <= 50) || (SrcSize >= 16384)) if ((SrcSize <= 50) || (SrcSize >= 16384))
@ -205,10 +189,9 @@ int bulk_compress(rdpBulk* bulk, BYTE* pSrcData, UINT32 SrcSize, BYTE** ppDstDat
*ppDstData = bulk->OutputBuffer; *ppDstData = bulk->OutputBuffer;
*pDstSize = sizeof(bulk->OutputBuffer); *pDstSize = sizeof(bulk->OutputBuffer);
bulk_compression_level(bulk); bulk_compression_level(bulk);
bulk_compression_max_size(bulk); bulk_compression_max_size(bulk);
if ((bulk->CompressionLevel == PACKET_COMPR_TYPE_8K) || if ((bulk->CompressionLevel == PACKET_COMPR_TYPE_8K) ||
(bulk->CompressionLevel == PACKET_COMPR_TYPE_64K)) (bulk->CompressionLevel == PACKET_COMPR_TYPE_64K))
{ {
@ -232,78 +215,67 @@ int bulk_compress(rdpBulk* bulk, BYTE* pSrcData, UINT32 SrcSize, BYTE** ppDstDat
{ {
CompressedBytes = *pDstSize; CompressedBytes = *pDstSize;
UncompressedBytes = SrcSize; UncompressedBytes = SrcSize;
CompressionRatio = metrics_write_bytes(metrics, UncompressedBytes, CompressedBytes); CompressionRatio = metrics_write_bytes(metrics, UncompressedBytes, CompressedBytes);
#ifdef WITH_BULK_DEBUG #ifdef WITH_BULK_DEBUG
{ {
DEBUG_MSG("Compress Type: %d Flags: %s (0x%04X) Compression Ratio: %f (%d / %d), Total: %f (%u / %u)\n", DEBUG_MSG("Compress Type: %d Flags: %s (0x%04X) Compression Ratio: %f (%d / %d), Total: %f (%u / %u)\n",
bulk->CompressionLevel, bulk_get_compression_flags_string(*pFlags), *pFlags, bulk->CompressionLevel, bulk_get_compression_flags_string(*pFlags), *pFlags,
CompressionRatio, CompressedBytes, UncompressedBytes, CompressionRatio, CompressedBytes, UncompressedBytes,
metrics->TotalCompressionRatio, (UINT32) metrics->TotalCompressedBytes, metrics->TotalCompressionRatio, (UINT32) metrics->TotalCompressedBytes,
(UINT32) metrics->TotalUncompressedBytes); (UINT32) metrics->TotalUncompressedBytes);
} }
#endif #endif
} }
#if 0 #if 0
if (bulk_compress_validate(bulk, pSrcData, SrcSize, ppDstData, pDstSize, pFlags) < 0) if (bulk_compress_validate(bulk, pSrcData, SrcSize, ppDstData, pDstSize, pFlags) < 0)
status = -1; status = -1;
#endif
#endif
return status; return status;
} }
void bulk_reset(rdpBulk* bulk) void bulk_reset(rdpBulk *bulk)
{ {
mppc_context_reset(bulk->mppcSend, FALSE); mppc_context_reset(bulk->mppcSend, FALSE);
mppc_context_reset(bulk->mppcRecv, FALSE); mppc_context_reset(bulk->mppcRecv, FALSE);
ncrush_context_reset(bulk->ncrushRecv, FALSE); ncrush_context_reset(bulk->ncrushRecv, FALSE);
ncrush_context_reset(bulk->ncrushSend, FALSE); ncrush_context_reset(bulk->ncrushSend, FALSE);
xcrush_context_reset(bulk->xcrushRecv, FALSE); xcrush_context_reset(bulk->xcrushRecv, FALSE);
xcrush_context_reset(bulk->xcrushSend, FALSE); xcrush_context_reset(bulk->xcrushSend, FALSE);
} }
rdpBulk* bulk_new(rdpContext* context) rdpBulk *bulk_new(rdpContext *context)
{ {
rdpBulk* bulk; rdpBulk *bulk;
bulk = (rdpBulk *) calloc(1, sizeof(rdpBulk));
bulk = (rdpBulk*) calloc(1, sizeof(rdpBulk));
if (bulk) if (bulk)
{ {
bulk->context = context; bulk->context = context;
bulk->mppcSend = mppc_context_new(1, TRUE); bulk->mppcSend = mppc_context_new(1, TRUE);
bulk->mppcRecv = mppc_context_new(1, FALSE); bulk->mppcRecv = mppc_context_new(1, FALSE);
bulk->ncrushRecv = ncrush_context_new(FALSE); bulk->ncrushRecv = ncrush_context_new(FALSE);
bulk->ncrushSend = ncrush_context_new(TRUE); bulk->ncrushSend = ncrush_context_new(TRUE);
bulk->xcrushRecv = xcrush_context_new(FALSE); bulk->xcrushRecv = xcrush_context_new(FALSE);
bulk->xcrushSend = xcrush_context_new(TRUE); bulk->xcrushSend = xcrush_context_new(TRUE);
bulk->CompressionLevel = context->settings->CompressionLevel; bulk->CompressionLevel = context->settings->CompressionLevel;
} }
return bulk; return bulk;
} }
void bulk_free(rdpBulk* bulk) void bulk_free(rdpBulk *bulk)
{ {
if (!bulk) if (!bulk)
return; return;
mppc_context_free(bulk->mppcSend); mppc_context_free(bulk->mppcSend);
mppc_context_free(bulk->mppcRecv); mppc_context_free(bulk->mppcRecv);
ncrush_context_free(bulk->ncrushRecv); ncrush_context_free(bulk->ncrushRecv);
ncrush_context_free(bulk->ncrushSend); ncrush_context_free(bulk->ncrushSend);
xcrush_context_free(bulk->xcrushRecv); xcrush_context_free(bulk->xcrushRecv);
xcrush_context_free(bulk->xcrushSend); xcrush_context_free(bulk->xcrushSend);
free(bulk); free(bulk);
} }

View File

@ -33,6 +33,8 @@
#include "certificate.h" #include "certificate.h"
#define TAG "com.freerdp.core"
/** /**
* *
* X.509 Certificate Structure * X.509 Certificate Structure
@ -121,7 +123,8 @@
* *
*/ */
static const char *certificate_read_errors[] = { static const char *certificate_read_errors[] =
{
"Certificate tag", "Certificate tag",
"TBSCertificate", "TBSCertificate",
"Explicit Contextual Tag [0]", "Explicit Contextual Tag [0]",
@ -150,84 +153,100 @@ static const char *certificate_read_errors[] = {
* @param cert X.509 certificate * @param cert X.509 certificate
*/ */
BOOL certificate_read_x509_certificate(rdpCertBlob* cert, rdpCertInfo* info) BOOL certificate_read_x509_certificate(rdpCertBlob *cert, rdpCertInfo *info)
{ {
wStream* s; wStream *s;
int length; int length;
BYTE padding; BYTE padding;
UINT32 version; UINT32 version;
int modulus_length; int modulus_length;
int exponent_length; int exponent_length;
int error = 0; int error = 0;
s = Stream_New(cert->data, cert->length); s = Stream_New(cert->data, cert->length);
if (!s) if (!s)
return FALSE; return FALSE;
info->Modulus = 0; info->Modulus = 0;
if (!ber_read_sequence_tag(s, &length)) /* Certificate (SEQUENCE) */ if (!ber_read_sequence_tag(s, &length)) /* Certificate (SEQUENCE) */
goto error1; goto error1;
error++; error++;
if (!ber_read_sequence_tag(s, &length)) /* TBSCertificate (SEQUENCE) */ if (!ber_read_sequence_tag(s, &length)) /* TBSCertificate (SEQUENCE) */
goto error1; goto error1;
error++; error++;
if (!ber_read_contextual_tag(s, 0, &length, TRUE)) /* Explicit Contextual Tag [0] */ if (!ber_read_contextual_tag(s, 0, &length, TRUE)) /* Explicit Contextual Tag [0] */
goto error1; goto error1;
error++; error++;
if (!ber_read_integer(s, &version)) /* version (INTEGER) */ if (!ber_read_integer(s, &version)) /* version (INTEGER) */
goto error1; goto error1;
error++; error++;
version++; version++;
/* serialNumber */ /* serialNumber */
if (!ber_read_integer(s, NULL)) /* CertificateSerialNumber (INTEGER) */ if (!ber_read_integer(s, NULL)) /* CertificateSerialNumber (INTEGER) */
goto error1; goto error1;
error++; error++;
/* signature */ /* signature */
if (!ber_read_sequence_tag(s, &length) || !Stream_SafeSeek(s, length)) /* AlgorithmIdentifier (SEQUENCE) */ if (!ber_read_sequence_tag(s, &length) || !Stream_SafeSeek(s, length)) /* AlgorithmIdentifier (SEQUENCE) */
goto error1; goto error1;
error++; error++;
/* issuer */ /* issuer */
if (!ber_read_sequence_tag(s, &length) || !Stream_SafeSeek(s, length)) /* Name (SEQUENCE) */ if (!ber_read_sequence_tag(s, &length) || !Stream_SafeSeek(s, length)) /* Name (SEQUENCE) */
goto error1; goto error1;
error++; error++;
/* validity */ /* validity */
if (!ber_read_sequence_tag(s, &length) || !Stream_SafeSeek(s, length)) /* Validity (SEQUENCE) */ if (!ber_read_sequence_tag(s, &length) || !Stream_SafeSeek(s, length)) /* Validity (SEQUENCE) */
goto error1; goto error1;
error++; error++;
/* subject */ /* subject */
if (!ber_read_sequence_tag(s, &length) || !Stream_SafeSeek(s, length)) /* Name (SEQUENCE) */ if (!ber_read_sequence_tag(s, &length) || !Stream_SafeSeek(s, length)) /* Name (SEQUENCE) */
goto error1; goto error1;
error++; error++;
/* subjectPublicKeyInfo */ /* subjectPublicKeyInfo */
if (!ber_read_sequence_tag(s, &length)) /* SubjectPublicKeyInfo (SEQUENCE) */ if (!ber_read_sequence_tag(s, &length)) /* SubjectPublicKeyInfo (SEQUENCE) */
goto error1; goto error1;
error++; error++;
/* subjectPublicKeyInfo::AlgorithmIdentifier */ /* subjectPublicKeyInfo::AlgorithmIdentifier */
if (!ber_read_sequence_tag(s, &length) || !Stream_SafeSeek(s, length)) /* AlgorithmIdentifier (SEQUENCE) */ if (!ber_read_sequence_tag(s, &length) || !Stream_SafeSeek(s, length)) /* AlgorithmIdentifier (SEQUENCE) */
goto error1; goto error1;
error++; error++;
/* subjectPublicKeyInfo::subjectPublicKey */ /* subjectPublicKeyInfo::subjectPublicKey */
if (!ber_read_bit_string(s, &length, &padding)) /* BIT_STRING */ if (!ber_read_bit_string(s, &length, &padding)) /* BIT_STRING */
goto error1; goto error1;
error++; error++;
/* RSAPublicKey (SEQUENCE) */ /* RSAPublicKey (SEQUENCE) */
if (!ber_read_sequence_tag(s, &length)) /* SEQUENCE */ if (!ber_read_sequence_tag(s, &length)) /* SEQUENCE */
goto error1; goto error1;
error++; error++;
if (!ber_read_integer_length(s, &modulus_length)) /* modulus (INTEGER) */ if (!ber_read_integer_length(s, &modulus_length)) /* modulus (INTEGER) */
goto error1; goto error1;
error++; error++;
/* skip zero padding, if any */ /* skip zero padding, if any */
@ -254,9 +273,11 @@ BOOL certificate_read_x509_certificate(rdpCertBlob* cert, rdpCertInfo* info)
goto error1; goto error1;
info->ModulusLength = modulus_length; info->ModulusLength = modulus_length;
info->Modulus = (BYTE*) malloc(info->ModulusLength); info->Modulus = (BYTE *) malloc(info->ModulusLength);
if (!info->Modulus) if (!info->Modulus)
goto error1; goto error1;
Stream_Read(s, info->Modulus, info->ModulusLength); Stream_Read(s, info->Modulus, info->ModulusLength);
error++; error++;
@ -271,15 +292,13 @@ BOOL certificate_read_x509_certificate(rdpCertBlob* cert, rdpCertInfo* info)
Stream_Read(s, &info->exponent[4 - exponent_length], exponent_length); Stream_Read(s, &info->exponent[4 - exponent_length], exponent_length);
crypto_reverse(info->Modulus, info->ModulusLength); crypto_reverse(info->Modulus, info->ModulusLength);
crypto_reverse(info->exponent, 4); crypto_reverse(info->exponent, 4);
Stream_Free(s, FALSE); Stream_Free(s, FALSE);
return TRUE; return TRUE;
error2: error2:
free(info->Modulus); free(info->Modulus);
info->Modulus = 0; info->Modulus = 0;
error1: error1:
DEBUG_WARN( "error reading when reading certificate: part=%s error=%d\n", certificate_read_errors[error], error); DEBUG_WARN("error reading when reading certificate: part=%s error=%d\n", certificate_read_errors[error], error);
Stream_Free(s, FALSE); Stream_Free(s, FALSE);
return FALSE; return FALSE;
} }
@ -290,21 +309,23 @@ error1:
* @return new X.509 certificate chain * @return new X.509 certificate chain
*/ */
rdpX509CertChain* certificate_new_x509_certificate_chain(UINT32 count) rdpX509CertChain *certificate_new_x509_certificate_chain(UINT32 count)
{ {
rdpX509CertChain* x509_cert_chain; rdpX509CertChain *x509_cert_chain;
x509_cert_chain = (rdpX509CertChain *)malloc(sizeof(rdpX509CertChain)); x509_cert_chain = (rdpX509CertChain *)malloc(sizeof(rdpX509CertChain));
if (!x509_cert_chain) if (!x509_cert_chain)
return NULL; return NULL;
x509_cert_chain->count = count; x509_cert_chain->count = count;
x509_cert_chain->array = (rdpCertBlob *)calloc(count, sizeof(rdpCertBlob)); x509_cert_chain->array = (rdpCertBlob *)calloc(count, sizeof(rdpCertBlob));
if (!x509_cert_chain->array) if (!x509_cert_chain->array)
{ {
free(x509_cert_chain); free(x509_cert_chain);
return NULL; return NULL;
} }
return x509_cert_chain; return x509_cert_chain;
} }
@ -313,7 +334,7 @@ rdpX509CertChain* certificate_new_x509_certificate_chain(UINT32 count)
* @param x509_cert_chain X.509 certificate chain to be freed * @param x509_cert_chain X.509 certificate chain to be freed
*/ */
void certificate_free_x509_certificate_chain(rdpX509CertChain* x509_cert_chain) void certificate_free_x509_certificate_chain(rdpX509CertChain *x509_cert_chain)
{ {
int i; int i;
@ -330,7 +351,7 @@ void certificate_free_x509_certificate_chain(rdpX509CertChain* x509_cert_chain)
free(x509_cert_chain); free(x509_cert_chain);
} }
static BOOL certificate_process_server_public_key(rdpCertificate* certificate, wStream* s, UINT32 length) static BOOL certificate_process_server_public_key(rdpCertificate *certificate, wStream *s, UINT32 length)
{ {
BYTE magic[4]; BYTE magic[4];
UINT32 keylen; UINT32 keylen;
@ -340,11 +361,12 @@ static BOOL certificate_process_server_public_key(rdpCertificate* certificate, w
if (Stream_GetRemainingLength(s) < 20) if (Stream_GetRemainingLength(s) < 20)
return FALSE; return FALSE;
Stream_Read(s, magic, 4); Stream_Read(s, magic, 4);
if (memcmp(magic, "RSA1", 4) != 0) if (memcmp(magic, "RSA1", 4) != 0)
{ {
DEBUG_WARN( "%s: magic error\n", __FUNCTION__); DEBUG_WARN("%s: magic error\n", __FUNCTION__);
return FALSE; return FALSE;
} }
@ -356,32 +378,34 @@ static BOOL certificate_process_server_public_key(rdpCertificate* certificate, w
if (Stream_GetRemainingLength(s) < modlen + 8) // count padding if (Stream_GetRemainingLength(s) < modlen + 8) // count padding
return FALSE; return FALSE;
certificate->cert_info.ModulusLength = modlen; certificate->cert_info.ModulusLength = modlen;
certificate->cert_info.Modulus = malloc(certificate->cert_info.ModulusLength); certificate->cert_info.Modulus = malloc(certificate->cert_info.ModulusLength);
if (!certificate->cert_info.Modulus) if (!certificate->cert_info.Modulus)
return FALSE; return FALSE;
Stream_Read(s, certificate->cert_info.Modulus, certificate->cert_info.ModulusLength); Stream_Read(s, certificate->cert_info.Modulus, certificate->cert_info.ModulusLength);
/* 8 bytes of zero padding */ /* 8 bytes of zero padding */
Stream_Seek(s, 8); Stream_Seek(s, 8);
return TRUE; return TRUE;
} }
static BOOL certificate_process_server_public_signature(rdpCertificate* certificate, static BOOL certificate_process_server_public_signature(rdpCertificate *certificate,
const BYTE* sigdata, int sigdatalen, wStream* s, UINT32 siglen) const BYTE *sigdata, int sigdatalen, wStream *s, UINT32 siglen)
{ {
int i, sum; int i, sum;
CryptoMd5 md5ctx; CryptoMd5 md5ctx;
BYTE sig[TSSK_KEY_LENGTH]; BYTE sig[TSSK_KEY_LENGTH];
BYTE encsig[TSSK_KEY_LENGTH + 8]; BYTE encsig[TSSK_KEY_LENGTH + 8];
BYTE md5hash[CRYPTO_MD5_DIGEST_LENGTH]; BYTE md5hash[CRYPTO_MD5_DIGEST_LENGTH];
md5ctx = crypto_md5_init(); md5ctx = crypto_md5_init();
if (!md5ctx) if (!md5ctx)
return FALSE; return FALSE;
crypto_md5_update(md5ctx, sigdata, sigdatalen); crypto_md5_update(md5ctx, sigdata, sigdatalen);
crypto_md5_final(md5ctx, md5hash); crypto_md5_final(md5ctx, md5hash);
Stream_Read(s, encsig, siglen); Stream_Read(s, encsig, siglen);
/* Last 8 bytes shall be all zero. */ /* Last 8 bytes shall be all zero. */
@ -391,19 +415,18 @@ static BOOL certificate_process_server_public_signature(rdpCertificate* certific
if (sum != 0) if (sum != 0)
{ {
DEBUG_WARN( "%s: invalid signature\n", __FUNCTION__); DEBUG_WARN("%s: invalid signature\n", __FUNCTION__);
//return FALSE; //return FALSE;
} }
siglen -= 8; siglen -= 8;
// TODO: check the result of decrypt // TODO: check the result of decrypt
crypto_rsa_public_decrypt(encsig, siglen, TSSK_KEY_LENGTH, tssk_modulus, tssk_exponent, sig); crypto_rsa_public_decrypt(encsig, siglen, TSSK_KEY_LENGTH, tssk_modulus, tssk_exponent, sig);
/* Verify signature. */ /* Verify signature. */
if (memcmp(md5hash, sig, sizeof(md5hash)) != 0) if (memcmp(md5hash, sig, sizeof(md5hash)) != 0)
{ {
DEBUG_WARN( "%s: invalid signature\n", __FUNCTION__); DEBUG_WARN("%s: invalid signature\n", __FUNCTION__);
//return FALSE; //return FALSE;
} }
@ -419,7 +442,7 @@ static BOOL certificate_process_server_public_signature(rdpCertificate* certific
if (sig[16] != 0x00 || sum != 0xFF * (62 - 17) || sig[62] != 0x01) if (sig[16] != 0x00 || sum != 0xFF * (62 - 17) || sig[62] != 0x01)
{ {
DEBUG_WARN( "%s: invalid signature\n", __FUNCTION__); DEBUG_WARN("%s: invalid signature\n", __FUNCTION__);
//return FALSE; //return FALSE;
} }
@ -432,7 +455,7 @@ static BOOL certificate_process_server_public_signature(rdpCertificate* certific
* @param s stream * @param s stream
*/ */
BOOL certificate_read_server_proprietary_certificate(rdpCertificate* certificate, wStream* s) BOOL certificate_read_server_proprietary_certificate(rdpCertificate *certificate, wStream *s)
{ {
UINT32 dwSigAlgId; UINT32 dwSigAlgId;
UINT32 dwKeyAlgId; UINT32 dwKeyAlgId;
@ -440,7 +463,7 @@ BOOL certificate_read_server_proprietary_certificate(rdpCertificate* certificate
UINT32 wPublicKeyBlobLen; UINT32 wPublicKeyBlobLen;
UINT32 wSignatureBlobType; UINT32 wSignatureBlobType;
UINT32 wSignatureBlobLen; UINT32 wSignatureBlobLen;
BYTE* sigdata; BYTE *sigdata;
int sigdatalen; int sigdatalen;
if (Stream_GetRemainingLength(s) < 12) if (Stream_GetRemainingLength(s) < 12)
@ -453,8 +476,8 @@ BOOL certificate_read_server_proprietary_certificate(rdpCertificate* certificate
if (!(dwSigAlgId == SIGNATURE_ALG_RSA && dwKeyAlgId == KEY_EXCHANGE_ALG_RSA)) if (!(dwSigAlgId == SIGNATURE_ALG_RSA && dwKeyAlgId == KEY_EXCHANGE_ALG_RSA))
{ {
DEBUG_WARN( "%s: unsupported signature or key algorithm, dwSigAlgId=%d dwKeyAlgId=%d\n", DEBUG_WARN("%s: unsupported signature or key algorithm, dwSigAlgId=%d dwKeyAlgId=%d\n",
__FUNCTION__, dwSigAlgId, dwKeyAlgId); __FUNCTION__, dwSigAlgId, dwKeyAlgId);
return FALSE; return FALSE;
} }
@ -462,17 +485,18 @@ BOOL certificate_read_server_proprietary_certificate(rdpCertificate* certificate
if (wPublicKeyBlobType != BB_RSA_KEY_BLOB) if (wPublicKeyBlobType != BB_RSA_KEY_BLOB)
{ {
DEBUG_WARN( "%s: unsupported public key blob type %d\n", __FUNCTION__, wPublicKeyBlobType); DEBUG_WARN("%s: unsupported public key blob type %d\n", __FUNCTION__, wPublicKeyBlobType);
return FALSE; return FALSE;
} }
Stream_Read_UINT16(s, wPublicKeyBlobLen); Stream_Read_UINT16(s, wPublicKeyBlobLen);
if (Stream_GetRemainingLength(s) < wPublicKeyBlobLen) if (Stream_GetRemainingLength(s) < wPublicKeyBlobLen)
return FALSE; return FALSE;
if (!certificate_process_server_public_key(certificate, s, wPublicKeyBlobLen)) if (!certificate_process_server_public_key(certificate, s, wPublicKeyBlobLen))
{ {
DEBUG_WARN( "%s: error in server public key\n", __FUNCTION__); DEBUG_WARN("%s: error in server public key\n", __FUNCTION__);
return FALSE; return FALSE;
} }
@ -484,26 +508,27 @@ BOOL certificate_read_server_proprietary_certificate(rdpCertificate* certificate
if (wSignatureBlobType != BB_RSA_SIGNATURE_BLOB) if (wSignatureBlobType != BB_RSA_SIGNATURE_BLOB)
{ {
DEBUG_WARN( "%s: unsupported blob signature %d\n", __FUNCTION__, wSignatureBlobType); DEBUG_WARN("%s: unsupported blob signature %d\n", __FUNCTION__, wSignatureBlobType);
return FALSE; return FALSE;
} }
Stream_Read_UINT16(s, wSignatureBlobLen); Stream_Read_UINT16(s, wSignatureBlobLen);
if (Stream_GetRemainingLength(s) < wSignatureBlobLen) if (Stream_GetRemainingLength(s) < wSignatureBlobLen)
{ {
DEBUG_WARN( "%s: not enought bytes for signature(len=%d)\n", __FUNCTION__, wSignatureBlobLen); DEBUG_WARN("%s: not enought bytes for signature(len=%d)\n", __FUNCTION__, wSignatureBlobLen);
return FALSE; return FALSE;
} }
if (wSignatureBlobLen != 72) if (wSignatureBlobLen != 72)
{ {
DEBUG_WARN( "%s: invalid signature length (got %d, expected %d)\n", __FUNCTION__, wSignatureBlobLen, 64); DEBUG_WARN("%s: invalid signature length (got %d, expected %d)\n", __FUNCTION__, wSignatureBlobLen, 64);
return FALSE; return FALSE;
} }
if (!certificate_process_server_public_signature(certificate, sigdata, sigdatalen, s, wSignatureBlobLen)) if (!certificate_process_server_public_signature(certificate, sigdata, sigdatalen, s, wSignatureBlobLen))
{ {
DEBUG_WARN( "%s: unable to parse server public signature\n", __FUNCTION__); DEBUG_WARN("%s: unable to parse server public signature\n", __FUNCTION__);
return FALSE; return FALSE;
} }
@ -516,20 +541,20 @@ BOOL certificate_read_server_proprietary_certificate(rdpCertificate* certificate
* @param s stream * @param s stream
*/ */
BOOL certificate_read_server_x509_certificate_chain(rdpCertificate* certificate, wStream* s) BOOL certificate_read_server_x509_certificate_chain(rdpCertificate *certificate, wStream *s)
{ {
int i; int i;
UINT32 certLength; UINT32 certLength;
UINT32 numCertBlobs; UINT32 numCertBlobs;
BOOL ret; BOOL ret;
DEBUG_CERTIFICATE("Server X.509 Certificate Chain"); DEBUG_CERTIFICATE("Server X.509 Certificate Chain");
if (Stream_GetRemainingLength(s) < 4) if (Stream_GetRemainingLength(s) < 4)
return FALSE; return FALSE;
Stream_Read_UINT32(s, numCertBlobs); /* numCertBlobs */
Stream_Read_UINT32(s, numCertBlobs); /* numCertBlobs */
certificate->x509_cert_chain = certificate_new_x509_certificate_chain(numCertBlobs); certificate->x509_cert_chain = certificate_new_x509_certificate_chain(numCertBlobs);
if (!certificate->x509_cert_chain) if (!certificate->x509_cert_chain)
return FALSE; return FALSE;
@ -544,10 +569,11 @@ BOOL certificate_read_server_x509_certificate_chain(rdpCertificate* certificate,
return FALSE; return FALSE;
DEBUG_CERTIFICATE("\nX.509 Certificate #%d, length:%d", i + 1, certLength); DEBUG_CERTIFICATE("\nX.509 Certificate #%d, length:%d", i + 1, certLength);
certificate->x509_cert_chain->array[i].data = (BYTE *) malloc(certLength);
certificate->x509_cert_chain->array[i].data = (BYTE*) malloc(certLength);
if (!certificate->x509_cert_chain->array[i].data) if (!certificate->x509_cert_chain->array[i].data)
return FALSE; return FALSE;
Stream_Read(s, certificate->x509_cert_chain->array[i].data, certLength); Stream_Read(s, certificate->x509_cert_chain->array[i].data, certLength);
certificate->x509_cert_chain->array[i].length = certLength; certificate->x509_cert_chain->array[i].length = certLength;
@ -557,19 +583,24 @@ BOOL certificate_read_server_x509_certificate_chain(rdpCertificate* certificate,
DEBUG_CERTIFICATE("License Server Certificate"); DEBUG_CERTIFICATE("License Server Certificate");
ret = certificate_read_x509_certificate(&certificate->x509_cert_chain->array[i], &cert_info); ret = certificate_read_x509_certificate(&certificate->x509_cert_chain->array[i], &cert_info);
DEBUG_LICENSE("modulus length:%d", (int) cert_info.ModulusLength); DEBUG_LICENSE("modulus length:%d", (int) cert_info.ModulusLength);
if (cert_info.Modulus) if (cert_info.Modulus)
free(cert_info.Modulus); free(cert_info.Modulus);
if (!ret) {
DEBUG_WARN( "failed to read License Server, content follows:\n"); if (!ret)
winpr_HexDump(certificate->x509_cert_chain->array[i].data, certificate->x509_cert_chain->array[i].length); {
DEBUG_WARN("failed to read License Server, content follows:\n");
winpr_HexDump(TAG, WLOG_ERROR, certificate->x509_cert_chain->array[i].data, certificate->x509_cert_chain->array[i].length);
return FALSE; return FALSE;
} }
} }
else if (numCertBlobs - i == 1) else if (numCertBlobs - i == 1)
{ {
DEBUG_CERTIFICATE("Terminal Server Certificate"); DEBUG_CERTIFICATE("Terminal Server Certificate");
if (!certificate_read_x509_certificate(&certificate->x509_cert_chain->array[i], &certificate->cert_info)) if (!certificate_read_x509_certificate(&certificate->x509_cert_chain->array[i], &certificate->cert_info))
return FALSE; return FALSE;
DEBUG_CERTIFICATE("modulus length:%d", (int) certificate->cert_info.ModulusLength); DEBUG_CERTIFICATE("modulus length:%d", (int) certificate->cert_info.ModulusLength);
} }
} }
@ -584,9 +615,9 @@ BOOL certificate_read_server_x509_certificate_chain(rdpCertificate* certificate,
* @param length certificate length * @param length certificate length
*/ */
BOOL certificate_read_server_certificate(rdpCertificate* certificate, BYTE* server_cert, int length) BOOL certificate_read_server_certificate(rdpCertificate *certificate, BYTE *server_cert, int length)
{ {
wStream* s; wStream *s;
UINT32 dwVersion; UINT32 dwVersion;
BOOL ret; BOOL ret;
@ -594,7 +625,6 @@ BOOL certificate_read_server_certificate(rdpCertificate* certificate, BYTE* serv
return TRUE; return TRUE;
s = Stream_New(server_cert, length); s = Stream_New(server_cert, length);
Stream_Read_UINT32(s, dwVersion); /* dwVersion (4 bytes) */ Stream_Read_UINT32(s, dwVersion); /* dwVersion (4 bytes) */
switch (dwVersion & CERT_CHAIN_VERSION_MASK) switch (dwVersion & CERT_CHAIN_VERSION_MASK)
@ -602,43 +632,42 @@ BOOL certificate_read_server_certificate(rdpCertificate* certificate, BYTE* serv
case CERT_CHAIN_VERSION_1: case CERT_CHAIN_VERSION_1:
ret = certificate_read_server_proprietary_certificate(certificate, s); ret = certificate_read_server_proprietary_certificate(certificate, s);
break; break;
case CERT_CHAIN_VERSION_2: case CERT_CHAIN_VERSION_2:
ret = certificate_read_server_x509_certificate_chain(certificate, s); ret = certificate_read_server_x509_certificate_chain(certificate, s);
break; break;
default: default:
DEBUG_WARN( "invalid certificate chain version:%d\n", dwVersion & CERT_CHAIN_VERSION_MASK); DEBUG_WARN("invalid certificate chain version:%d\n", dwVersion & CERT_CHAIN_VERSION_MASK);
ret = FALSE; ret = FALSE;
break; break;
} }
Stream_Free(s, FALSE); Stream_Free(s, FALSE);
return ret; return ret;
} }
rdpRsaKey* key_new(const char* keyfile) rdpRsaKey *key_new(const char *keyfile)
{ {
FILE* fp; FILE *fp;
RSA* rsa; RSA *rsa;
rdpRsaKey* key; rdpRsaKey *key;
key = (rdpRsaKey *)calloc(1, sizeof(rdpRsaKey)); key = (rdpRsaKey *)calloc(1, sizeof(rdpRsaKey));
if (!key) if (!key)
return NULL; return NULL;
fp = fopen(keyfile, "r"); fp = fopen(keyfile, "r");
if (fp == NULL) if (fp == NULL)
{ {
DEBUG_WARN( "%s: unable to open RSA key file %s: %s.", __FUNCTION__, keyfile, strerror(errno)); DEBUG_WARN("%s: unable to open RSA key file %s: %s.", __FUNCTION__, keyfile, strerror(errno));
goto out_free; goto out_free;
} }
rsa = PEM_read_RSAPrivateKey(fp, NULL, NULL, NULL); rsa = PEM_read_RSAPrivateKey(fp, NULL, NULL, NULL);
if (rsa == NULL) if (rsa == NULL)
{ {
DEBUG_WARN( "%s: unable to load RSA key from %s: %s.", __FUNCTION__, keyfile, strerror(errno)); DEBUG_WARN("%s: unable to load RSA key from %s: %s.", __FUNCTION__, keyfile, strerror(errno));
ERR_print_errors_fp(stderr); ERR_print_errors_fp(stderr);
fclose(fp); fclose(fp);
goto out_free; goto out_free;
@ -649,46 +678,44 @@ rdpRsaKey* key_new(const char* keyfile)
switch (RSA_check_key(rsa)) switch (RSA_check_key(rsa))
{ {
case 0: case 0:
DEBUG_WARN( "%s: invalid RSA key in %s\n", __FUNCTION__, keyfile); DEBUG_WARN("%s: invalid RSA key in %s\n", __FUNCTION__, keyfile);
goto out_free_rsa; goto out_free_rsa;
case 1: case 1:
/* Valid key. */ /* Valid key. */
break; break;
default: default:
DEBUG_WARN( "%s: unexpected error when checking RSA key from %s: %s.", __FUNCTION__, keyfile, strerror(errno)); DEBUG_WARN("%s: unexpected error when checking RSA key from %s: %s.", __FUNCTION__, keyfile, strerror(errno));
ERR_print_errors_fp(stderr); ERR_print_errors_fp(stderr);
goto out_free_rsa; goto out_free_rsa;
} }
if (BN_num_bytes(rsa->e) > 4) if (BN_num_bytes(rsa->e) > 4)
{ {
DEBUG_WARN( "%s: RSA public exponent too large in %s\n", __FUNCTION__, keyfile); DEBUG_WARN("%s: RSA public exponent too large in %s\n", __FUNCTION__, keyfile);
goto out_free_rsa; goto out_free_rsa;
} }
key->ModulusLength = BN_num_bytes(rsa->n); key->ModulusLength = BN_num_bytes(rsa->n);
key->Modulus = (BYTE *)malloc(key->ModulusLength); key->Modulus = (BYTE *)malloc(key->ModulusLength);
if (!key->Modulus) if (!key->Modulus)
goto out_free_rsa; goto out_free_rsa;
BN_bn2bin(rsa->n, key->Modulus); BN_bn2bin(rsa->n, key->Modulus);
crypto_reverse(key->Modulus, key->ModulusLength); crypto_reverse(key->Modulus, key->ModulusLength);
key->PrivateExponentLength = BN_num_bytes(rsa->d); key->PrivateExponentLength = BN_num_bytes(rsa->d);
key->PrivateExponent = (BYTE *)malloc(key->PrivateExponentLength); key->PrivateExponent = (BYTE *)malloc(key->PrivateExponentLength);
if (!key->PrivateExponent) if (!key->PrivateExponent)
goto out_free_modulus; goto out_free_modulus;
BN_bn2bin(rsa->d, key->PrivateExponent); BN_bn2bin(rsa->d, key->PrivateExponent);
crypto_reverse(key->PrivateExponent, key->PrivateExponentLength); crypto_reverse(key->PrivateExponent, key->PrivateExponentLength);
memset(key->exponent, 0, sizeof(key->exponent)); memset(key->exponent, 0, sizeof(key->exponent));
BN_bn2bin(rsa->e, key->exponent + sizeof(key->exponent) - BN_num_bytes(rsa->e)); BN_bn2bin(rsa->e, key->exponent + sizeof(key->exponent) - BN_num_bytes(rsa->e));
crypto_reverse(key->exponent, sizeof(key->exponent)); crypto_reverse(key->exponent, sizeof(key->exponent));
RSA_free(rsa); RSA_free(rsa);
return key; return key;
out_free_modulus: out_free_modulus:
free(key->Modulus); free(key->Modulus);
out_free_rsa: out_free_rsa:
@ -698,13 +725,14 @@ out_free:
return NULL; return NULL;
} }
void key_free(rdpRsaKey* key) void key_free(rdpRsaKey *key)
{ {
if (!key) if (!key)
return; return;
if (key->Modulus) if (key->Modulus)
free(key->Modulus); free(key->Modulus);
free(key->PrivateExponent); free(key->PrivateExponent);
free(key); free(key);
} }
@ -715,9 +743,9 @@ void key_free(rdpRsaKey* key)
* @return new certificate module * @return new certificate module
*/ */
rdpCertificate* certificate_new() rdpCertificate *certificate_new()
{ {
return (rdpCertificate*) calloc(1, sizeof(rdpCertificate)); return (rdpCertificate *) calloc(1, sizeof(rdpCertificate));
} }
/** /**
@ -725,7 +753,7 @@ rdpCertificate* certificate_new()
* @param certificate certificate module to be freed * @param certificate certificate module to be freed
*/ */
void certificate_free(rdpCertificate* certificate) void certificate_free(rdpCertificate *certificate)
{ {
if (!certificate) if (!certificate)
return; return;

View File

@ -34,12 +34,14 @@
#include "http.h" #include "http.h"
HttpContext* http_context_new() #define TAG "gateway"
HttpContext *http_context_new()
{ {
return (HttpContext *)calloc(1, sizeof(HttpContext)); return (HttpContext *)calloc(1, sizeof(HttpContext));
} }
void http_context_set_method(HttpContext* http_context, char* method) void http_context_set_method(HttpContext *http_context, char *method)
{ {
if (http_context->Method) if (http_context->Method)
free(http_context->Method); free(http_context->Method);
@ -48,7 +50,7 @@ void http_context_set_method(HttpContext* http_context, char* method)
// TODO: check result // TODO: check result
} }
void http_context_set_uri(HttpContext* http_context, char* uri) void http_context_set_uri(HttpContext *http_context, char *uri)
{ {
if (http_context->URI) if (http_context->URI)
free(http_context->URI); free(http_context->URI);
@ -57,7 +59,7 @@ void http_context_set_uri(HttpContext* http_context, char* uri)
// TODO: check result // TODO: check result
} }
void http_context_set_user_agent(HttpContext* http_context, char* user_agent) void http_context_set_user_agent(HttpContext *http_context, char *user_agent)
{ {
if (http_context->UserAgent) if (http_context->UserAgent)
free(http_context->UserAgent); free(http_context->UserAgent);
@ -66,7 +68,7 @@ void http_context_set_user_agent(HttpContext* http_context, char* user_agent)
// TODO: check result // TODO: check result
} }
void http_context_set_host(HttpContext* http_context, char* host) void http_context_set_host(HttpContext *http_context, char *host)
{ {
if (http_context->Host) if (http_context->Host)
free(http_context->Host); free(http_context->Host);
@ -75,7 +77,7 @@ void http_context_set_host(HttpContext* http_context, char* host)
// TODO: check result // TODO: check result
} }
void http_context_set_accept(HttpContext* http_context, char* accept) void http_context_set_accept(HttpContext *http_context, char *accept)
{ {
if (http_context->Accept) if (http_context->Accept)
free(http_context->Accept); free(http_context->Accept);
@ -84,7 +86,7 @@ void http_context_set_accept(HttpContext* http_context, char* accept)
// TODO: check result // TODO: check result
} }
void http_context_set_cache_control(HttpContext* http_context, char* cache_control) void http_context_set_cache_control(HttpContext *http_context, char *cache_control)
{ {
if (http_context->CacheControl) if (http_context->CacheControl)
free(http_context->CacheControl); free(http_context->CacheControl);
@ -93,7 +95,7 @@ void http_context_set_cache_control(HttpContext* http_context, char* cache_contr
// TODO: check result // TODO: check result
} }
void http_context_set_connection(HttpContext* http_context, char* connection) void http_context_set_connection(HttpContext *http_context, char *connection)
{ {
if (http_context->Connection) if (http_context->Connection)
free(http_context->Connection); free(http_context->Connection);
@ -102,7 +104,7 @@ void http_context_set_connection(HttpContext* http_context, char* connection)
// TODO: check result // TODO: check result
} }
void http_context_set_pragma(HttpContext* http_context, char* pragma) void http_context_set_pragma(HttpContext *http_context, char *pragma)
{ {
if (http_context->Pragma) if (http_context->Pragma)
free(http_context->Pragma); free(http_context->Pragma);
@ -111,7 +113,7 @@ void http_context_set_pragma(HttpContext* http_context, char* pragma)
// TODO: check result // TODO: check result
} }
void http_context_free(HttpContext* http_context) void http_context_free(HttpContext *http_context)
{ {
if (http_context != NULL) if (http_context != NULL)
{ {
@ -127,7 +129,7 @@ void http_context_free(HttpContext* http_context)
} }
} }
void http_request_set_method(HttpRequest* http_request, char* method) void http_request_set_method(HttpRequest *http_request, char *method)
{ {
if (http_request->Method) if (http_request->Method)
free(http_request->Method); free(http_request->Method);
@ -136,7 +138,7 @@ void http_request_set_method(HttpRequest* http_request, char* method)
// TODO: check result // TODO: check result
} }
void http_request_set_uri(HttpRequest* http_request, char* uri) void http_request_set_uri(HttpRequest *http_request, char *uri)
{ {
if (http_request->URI) if (http_request->URI)
free(http_request->URI); free(http_request->URI);
@ -145,7 +147,7 @@ void http_request_set_uri(HttpRequest* http_request, char* uri)
// TODO: check result // TODO: check result
} }
void http_request_set_auth_scheme(HttpRequest* http_request, char* auth_scheme) void http_request_set_auth_scheme(HttpRequest *http_request, char *auth_scheme)
{ {
if (http_request->AuthScheme) if (http_request->AuthScheme)
free(http_request->AuthScheme); free(http_request->AuthScheme);
@ -154,7 +156,7 @@ void http_request_set_auth_scheme(HttpRequest* http_request, char* auth_scheme)
// TODO: check result // TODO: check result
} }
void http_request_set_auth_param(HttpRequest* http_request, char* auth_param) void http_request_set_auth_param(HttpRequest *http_request, char *auth_param)
{ {
if (http_request->AuthParam) if (http_request->AuthParam)
free(http_request->AuthParam); free(http_request->AuthParam);
@ -163,73 +165,73 @@ void http_request_set_auth_param(HttpRequest* http_request, char* auth_param)
// TODO: check result // TODO: check result
} }
char* http_encode_body_line(char* param, char* value) char *http_encode_body_line(char *param, char *value)
{ {
char* line; char *line;
int length; int length;
length = strlen(param) + strlen(value) + 2; length = strlen(param) + strlen(value) + 2;
line = (char*) malloc(length + 1); line = (char *) malloc(length + 1);
if (!line) if (!line)
return NULL; return NULL;
sprintf_s(line, length + 1, "%s: %s", param, value);
sprintf_s(line, length + 1, "%s: %s", param, value);
return line; return line;
} }
char* http_encode_content_length_line(int ContentLength) char *http_encode_content_length_line(int ContentLength)
{ {
char* line; char *line;
int length; int length;
char str[32]; char str[32];
_itoa_s(ContentLength, str, sizeof(str), 10); _itoa_s(ContentLength, str, sizeof(str), 10);
length = strlen("Content-Length") + strlen(str) + 2; length = strlen("Content-Length") + strlen(str) + 2;
line = (char *)malloc(length + 1); line = (char *)malloc(length + 1);
if (!line) if (!line)
return NULL; return NULL;
sprintf_s(line, length + 1, "Content-Length: %s", str);
sprintf_s(line, length + 1, "Content-Length: %s", str);
return line; return line;
} }
char* http_encode_header_line(char* Method, char* URI) char *http_encode_header_line(char *Method, char *URI)
{ {
char* line; char *line;
int length; int length;
length = strlen("HTTP/1.1") + strlen(Method) + strlen(URI) + 2; length = strlen("HTTP/1.1") + strlen(Method) + strlen(URI) + 2;
line = (char *)malloc(length + 1); line = (char *)malloc(length + 1);
if (!line) if (!line)
return NULL; return NULL;
sprintf_s(line, length + 1, "%s %s HTTP/1.1", Method, URI); sprintf_s(line, length + 1, "%s %s HTTP/1.1", Method, URI);
return line; return line;
} }
char* http_encode_authorization_line(char* AuthScheme, char* AuthParam) char *http_encode_authorization_line(char *AuthScheme, char *AuthParam)
{ {
char* line; char *line;
int length; int length;
length = strlen("Authorization") + strlen(AuthScheme) + strlen(AuthParam) + 3; length = strlen("Authorization") + strlen(AuthScheme) + strlen(AuthParam) + 3;
line = (char*) malloc(length + 1); line = (char *) malloc(length + 1);
if (!line) if (!line)
return NULL; return NULL;
sprintf_s(line, length + 1, "Authorization: %s %s", AuthScheme, AuthParam);
sprintf_s(line, length + 1, "Authorization: %s %s", AuthScheme, AuthParam);
return line; return line;
} }
wStream* http_request_write(HttpContext* http_context, HttpRequest* http_request) wStream *http_request_write(HttpContext *http_context, HttpRequest *http_request)
{ {
int i, count; int i, count;
char **lines; char **lines;
wStream* s; wStream *s;
int length = 0; int length = 0;
count = 9; count = 9;
lines = (char **)calloc(count, sizeof(char *)); lines = (char **)calloc(count, sizeof(char *));
if (!lines) if (!lines)
return NULL; return NULL;
@ -252,12 +254,14 @@ wStream* http_request_write(HttpContext* http_context, HttpRequest* http_request
if (http_request->Authorization != NULL) if (http_request->Authorization != NULL)
{ {
lines[8] = http_encode_body_line("Authorization", http_request->Authorization); lines[8] = http_encode_body_line("Authorization", http_request->Authorization);
if (!lines[8]) if (!lines[8])
goto out_free; goto out_free;
} }
else if ((http_request->AuthScheme != NULL) && (http_request->AuthParam != NULL)) else if ((http_request->AuthScheme != NULL) && (http_request->AuthParam != NULL))
{ {
lines[8] = http_encode_authorization_line(http_request->AuthScheme, http_request->AuthParam); lines[8] = http_encode_authorization_line(http_request->AuthScheme, http_request->AuthParam);
if (!lines[8]) if (!lines[8])
goto out_free; goto out_free;
} }
@ -266,10 +270,11 @@ wStream* http_request_write(HttpContext* http_context, HttpRequest* http_request
{ {
length += (strlen(lines[i]) + 2); /* add +2 for each '\r\n' character */ length += (strlen(lines[i]) + 2); /* add +2 for each '\r\n' character */
} }
length += 2; /* empty line "\r\n" at end of header */ length += 2; /* empty line "\r\n" at end of header */
length += 1; /* null terminator */ length += 1; /* null terminator */
s = Stream_New(NULL, length); s = Stream_New(NULL, length);
if (!s) if (!s)
goto out_free; goto out_free;
@ -279,73 +284,79 @@ wStream* http_request_write(HttpContext* http_context, HttpRequest* http_request
Stream_Write(s, "\r\n", 2); Stream_Write(s, "\r\n", 2);
free(lines[i]); free(lines[i]);
} }
Stream_Write(s, "\r\n", 2); Stream_Write(s, "\r\n", 2);
free(lines); free(lines);
Stream_Write(s, "\0", 1); /* append null terminator */ Stream_Write(s, "\0", 1); /* append null terminator */
Stream_Rewind(s, 1); /* don't include null terminator in length */ Stream_Rewind(s, 1); /* don't include null terminator in length */
Stream_Length(s) = Stream_GetPosition(s); Stream_Length(s) = Stream_GetPosition(s);
return s; return s;
out_free: out_free:
for (i = 0; i < 9; i++) for (i = 0; i < 9; i++)
{ {
if (lines[i]) if (lines[i])
free(lines[i]); free(lines[i]);
} }
free(lines); free(lines);
return NULL; return NULL;
} }
HttpRequest* http_request_new() HttpRequest *http_request_new()
{ {
return (HttpRequest*) calloc(1, sizeof(HttpRequest)); return (HttpRequest *) calloc(1, sizeof(HttpRequest));
} }
void http_request_free(HttpRequest* http_request) void http_request_free(HttpRequest *http_request)
{ {
if (!http_request) if (!http_request)
return; return;
if (http_request->AuthParam) if (http_request->AuthParam)
free(http_request->AuthParam); free(http_request->AuthParam);
if (http_request->AuthScheme) if (http_request->AuthScheme)
free(http_request->AuthScheme); free(http_request->AuthScheme);
if (http_request->Authorization) if (http_request->Authorization)
free(http_request->Authorization); free(http_request->Authorization);
free(http_request->Content); free(http_request->Content);
free(http_request->Method); free(http_request->Method);
free(http_request->URI); free(http_request->URI);
free(http_request); free(http_request);
} }
BOOL http_response_parse_header_status_line(HttpResponse* http_response, char* status_line) BOOL http_response_parse_header_status_line(HttpResponse *http_response, char *status_line)
{ {
char* separator; char *separator;
char* status_code; char *status_code;
char* reason_phrase; char *reason_phrase;
separator = strchr(status_line, ' '); separator = strchr(status_line, ' ');
if (!separator) if (!separator)
return FALSE; return FALSE;
status_code = separator + 1; status_code = separator + 1;
separator = strchr(status_code, ' '); separator = strchr(status_code, ' ');
if (!separator) if (!separator)
return FALSE; return FALSE;
reason_phrase = separator + 1;
reason_phrase = separator + 1;
*separator = '\0'; *separator = '\0';
http_response->StatusCode = atoi(status_code); http_response->StatusCode = atoi(status_code);
http_response->ReasonPhrase = _strdup(reason_phrase); http_response->ReasonPhrase = _strdup(reason_phrase);
if (!http_response->ReasonPhrase) if (!http_response->ReasonPhrase)
return FALSE; return FALSE;
*separator = ' '; *separator = ' ';
return TRUE; return TRUE;
} }
BOOL http_response_parse_header_field(HttpResponse* http_response, char* name, char* value) BOOL http_response_parse_header_field(HttpResponse *http_response, char *name, char *value)
{ {
if (_stricmp(name, "Content-Length") == 0) if (_stricmp(name, "Content-Length") == 0)
{ {
@ -353,9 +364,8 @@ BOOL http_response_parse_header_field(HttpResponse* http_response, char* name, c
} }
else if (_stricmp(name, "WWW-Authenticate") == 0) else if (_stricmp(name, "WWW-Authenticate") == 0)
{ {
char* separator; char *separator;
char *authScheme, *authValue; char *authScheme, *authValue;
separator = strchr(value, ' '); separator = strchr(value, ' ');
if (separator != NULL) if (separator != NULL)
@ -367,34 +377,38 @@ BOOL http_response_parse_header_field(HttpResponse* http_response, char* name, c
* opaque="5ccc069c403ebaf9f0171e9517f40e41" * opaque="5ccc069c403ebaf9f0171e9517f40e41"
*/ */
*separator = '\0'; *separator = '\0';
authScheme = _strdup(value); authScheme = _strdup(value);
authValue = _strdup(separator + 1); authValue = _strdup(separator + 1);
if (!authScheme || !authValue) if (!authScheme || !authValue)
return FALSE; return FALSE;
*separator = ' '; *separator = ' ';
} }
else else
{ {
authScheme = _strdup(value); authScheme = _strdup(value);
if (!authScheme) if (!authScheme)
return FALSE; return FALSE;
authValue = NULL; authValue = NULL;
} }
return ListDictionary_Add(http_response->Authenticates, authScheme, authValue); return ListDictionary_Add(http_response->Authenticates, authScheme, authValue);
} }
return TRUE; return TRUE;
} }
BOOL http_response_parse_header(HttpResponse* http_response) BOOL http_response_parse_header(HttpResponse *http_response)
{ {
int count; int count;
char* line; char *line;
char* name; char *name;
char* value; char *value;
char* colon_pos; char *colon_pos;
char* end_of_header; char *end_of_header;
char end_of_header_char; char end_of_header_char;
char c; char c;
@ -410,7 +424,6 @@ BOOL http_response_parse_header(HttpResponse* http_response)
for (count = 1; count < http_response->count; count++) for (count = 1; count < http_response->count; count++)
{ {
line = http_response->lines[count]; line = http_response->lines[count];
/** /**
* name end_of_header * name end_of_header
* | | * | |
@ -421,21 +434,24 @@ BOOL http_response_parse_header(HttpResponse* http_response)
* colon_pos value * colon_pos value
*/ */
colon_pos = strchr(line, ':'); colon_pos = strchr(line, ':');
if ((colon_pos == NULL) || (colon_pos == line)) if ((colon_pos == NULL) || (colon_pos == line))
return FALSE; return FALSE;
/* retrieve the position just after header name */ /* retrieve the position just after header name */
for(end_of_header = colon_pos; end_of_header != line; end_of_header--) for (end_of_header = colon_pos; end_of_header != line; end_of_header--)
{ {
c = end_of_header[-1]; c = end_of_header[-1];
if (c != ' ' && c != '\t' && c != ':') if (c != ' ' && c != '\t' && c != ':')
break; break;
} }
if (end_of_header == line) if (end_of_header == line)
return FALSE; return FALSE;
end_of_header_char = *end_of_header; end_of_header_char = *end_of_header;
*end_of_header = '\0'; *end_of_header = '\0';
name = line; name = line;
/* eat space and tabs before header value */ /* eat space and tabs before header value */
@ -450,39 +466,42 @@ BOOL http_response_parse_header(HttpResponse* http_response)
*end_of_header = end_of_header_char; *end_of_header = end_of_header_char;
} }
return TRUE; return TRUE;
} }
void http_response_print(HttpResponse* http_response) void http_response_print(HttpResponse *http_response)
{ {
int i; int i;
for (i = 0; i < http_response->count; i++) for (i = 0; i < http_response->count; i++)
{ {
DEBUG_WARN( "%s\n", http_response->lines[i]); DEBUG_WARN("%s\n", http_response->lines[i]);
} }
DEBUG_WARN( "\n");
DEBUG_WARN("\n");
} }
HttpResponse* http_response_recv(rdpTls* tls) HttpResponse *http_response_recv(rdpTls *tls)
{ {
BYTE* p; BYTE *p;
int nbytes; int nbytes;
int length; int length;
int status; int status;
BYTE* buffer; BYTE *buffer;
char* content; char *content;
char* header_end; char *header_end;
HttpResponse* http_response; HttpResponse *http_response;
nbytes = 0; nbytes = 0;
length = 10000; length = 10000;
content = NULL; content = NULL;
buffer = calloc(length, 1); buffer = calloc(length, 1);
if (!buffer) if (!buffer)
return NULL; return NULL;
http_response = http_response_new(); http_response = http_response_new();
if (!http_response) if (!http_response)
goto out_free; goto out_free;
@ -494,7 +513,7 @@ HttpResponse* http_response_recv(rdpTls* tls)
while (nbytes < 5) while (nbytes < 5)
{ {
status = BIO_read(tls->bio, p, length - nbytes); status = BIO_read(tls->bio, p, length - nbytes);
if (status <= 0) if (status <= 0)
{ {
if (!BIO_should_retry(tls->bio)) if (!BIO_should_retry(tls->bio))
@ -508,15 +527,15 @@ HttpResponse* http_response_recv(rdpTls* tls)
VALGRIND_MAKE_MEM_DEFINED(p, status); VALGRIND_MAKE_MEM_DEFINED(p, status);
#endif #endif
nbytes += status; nbytes += status;
p = (BYTE*) &buffer[nbytes]; p = (BYTE *) &buffer[nbytes];
} }
header_end = strstr((char*) buffer, "\r\n\r\n"); header_end = strstr((char *) buffer, "\r\n\r\n");
if (!header_end) if (!header_end)
{ {
DEBUG_WARN( "%s: invalid response:\n", __FUNCTION__); DEBUG_WARN("%s: invalid response:\n", __FUNCTION__);
winpr_HexDump(buffer, status); winpr_HexDump(TAG, WLOG_ERROR, buffer, status);
goto out_error; goto out_error;
} }
@ -525,14 +544,12 @@ HttpResponse* http_response_recv(rdpTls* tls)
if (header_end != NULL) if (header_end != NULL)
{ {
int count; int count;
char* line; char *line;
header_end[0] = '\0'; header_end[0] = '\0';
header_end[1] = '\0'; header_end[1] = '\0';
content = header_end + 2; content = header_end + 2;
count = 0; count = 0;
line = (char*) buffer; line = (char *) buffer;
while ((line = strstr(line, "\r\n")) != NULL) while ((line = strstr(line, "\r\n")) != NULL)
{ {
@ -541,19 +558,22 @@ HttpResponse* http_response_recv(rdpTls* tls)
} }
http_response->count = count; http_response->count = count;
if (count) if (count)
{ {
http_response->lines = (char **)calloc(http_response->count, sizeof(char *)); http_response->lines = (char **)calloc(http_response->count, sizeof(char *));
if (!http_response->lines) if (!http_response->lines)
goto out_error; goto out_error;
} }
count = 0; count = 0;
line = strtok((char*) buffer, "\r\n"); line = strtok((char *) buffer, "\r\n");
while (line != NULL) while (line != NULL)
{ {
http_response->lines[count] = _strdup(line); http_response->lines[count] = _strdup(line);
if (!http_response->lines[count]) if (!http_response->lines[count])
goto out_error; goto out_error;
@ -565,9 +585,11 @@ HttpResponse* http_response_recv(rdpTls* tls)
goto out_error; goto out_error;
http_response->bodyLen = nbytes - (content - (char *)buffer); http_response->bodyLen = nbytes - (content - (char *)buffer);
if (http_response->bodyLen > 0) if (http_response->bodyLen > 0)
{ {
http_response->BodyContent = (BYTE *)malloc(http_response->bodyLen); http_response->BodyContent = (BYTE *)malloc(http_response->bodyLen);
if (!http_response->BodyContent) if (!http_response->BodyContent)
goto out_error; goto out_error;
@ -581,14 +603,12 @@ HttpResponse* http_response_recv(rdpTls* tls)
{ {
length *= 2; length *= 2;
buffer = realloc(buffer, length); buffer = realloc(buffer, length);
p = (BYTE*) &buffer[nbytes]; p = (BYTE *) &buffer[nbytes];
} }
} }
free(buffer); free(buffer);
return http_response; return http_response;
out_error: out_error:
http_response_free(http_response); http_response_free(http_response);
out_free: out_free:
@ -608,12 +628,14 @@ static void string_free(void *obj1)
{ {
if (!obj1) if (!obj1)
return; return;
free(obj1); free(obj1);
} }
HttpResponse* http_response_new() HttpResponse *http_response_new()
{ {
HttpResponse *ret = (HttpResponse *)calloc(1, sizeof(HttpResponse)); HttpResponse *ret = (HttpResponse *)calloc(1, sizeof(HttpResponse));
if (!ret) if (!ret)
return NULL; return NULL;
@ -625,7 +647,7 @@ HttpResponse* http_response_new()
return ret; return ret;
} }
void http_response_free(HttpResponse* http_response) void http_response_free(HttpResponse *http_response)
{ {
int i; int i;
@ -636,9 +658,7 @@ void http_response_free(HttpResponse* http_response)
free(http_response->lines[i]); free(http_response->lines[i]);
free(http_response->lines); free(http_response->lines);
free(http_response->ReasonPhrase); free(http_response->ReasonPhrase);
ListDictionary_Free(http_response->Authenticates); ListDictionary_Free(http_response->Authenticates);
if (http_response->ContentLength > 0) if (http_response->ContentLength > 0)

View File

@ -37,11 +37,12 @@
#include "rpc_client.h" #include "rpc_client.h"
#include "../rdp.h" #include "../rdp.h"
#define TAG "gateway"
#define SYNCHRONOUS_TIMEOUT 5000 #define SYNCHRONOUS_TIMEOUT 5000
wStream* rpc_client_fragment_pool_take(rdpRpc* rpc) wStream *rpc_client_fragment_pool_take(rdpRpc *rpc)
{ {
wStream* fragment = NULL; wStream *fragment = NULL;
if (WaitForSingleObject(Queue_Event(rpc->client->FragmentPool), 0) == WAIT_OBJECT_0) if (WaitForSingleObject(Queue_Event(rpc->client->FragmentPool), 0) == WAIT_OBJECT_0)
fragment = Queue_Dequeue(rpc->client->FragmentPool); fragment = Queue_Dequeue(rpc->client->FragmentPool);
@ -52,15 +53,15 @@ wStream* rpc_client_fragment_pool_take(rdpRpc* rpc)
return fragment; return fragment;
} }
int rpc_client_fragment_pool_return(rdpRpc* rpc, wStream* fragment) int rpc_client_fragment_pool_return(rdpRpc *rpc, wStream *fragment)
{ {
Queue_Enqueue(rpc->client->FragmentPool, fragment); Queue_Enqueue(rpc->client->FragmentPool, fragment);
return 0; return 0;
} }
RPC_PDU* rpc_client_receive_pool_take(rdpRpc* rpc) RPC_PDU *rpc_client_receive_pool_take(rdpRpc *rpc)
{ {
RPC_PDU* pdu = NULL; RPC_PDU *pdu = NULL;
if (WaitForSingleObject(Queue_Event(rpc->client->ReceivePool), 0) == WAIT_OBJECT_0) if (WaitForSingleObject(Queue_Event(rpc->client->ReceivePool), 0) == WAIT_OBJECT_0)
pdu = Queue_Dequeue(rpc->client->ReceivePool); pdu = Queue_Dequeue(rpc->client->ReceivePool);
@ -68,9 +69,12 @@ RPC_PDU* rpc_client_receive_pool_take(rdpRpc* rpc)
if (!pdu) if (!pdu)
{ {
pdu = (RPC_PDU *)malloc(sizeof(RPC_PDU)); pdu = (RPC_PDU *)malloc(sizeof(RPC_PDU));
if (!pdu) if (!pdu)
return NULL; return NULL;
pdu->s = Stream_New(NULL, rpc->max_recv_frag); pdu->s = Stream_New(NULL, rpc->max_recv_frag);
if (!pdu->s) if (!pdu->s)
{ {
free(pdu); free(pdu);
@ -80,68 +84,61 @@ RPC_PDU* rpc_client_receive_pool_take(rdpRpc* rpc)
pdu->CallId = 0; pdu->CallId = 0;
pdu->Flags = 0; pdu->Flags = 0;
Stream_Length(pdu->s) = 0; Stream_Length(pdu->s) = 0;
Stream_SetPosition(pdu->s, 0); Stream_SetPosition(pdu->s, 0);
return pdu; return pdu;
} }
int rpc_client_receive_pool_return(rdpRpc* rpc, RPC_PDU* pdu) int rpc_client_receive_pool_return(rdpRpc *rpc, RPC_PDU *pdu)
{ {
return Queue_Enqueue(rpc->client->ReceivePool, pdu) == TRUE ? 0 : -1; return Queue_Enqueue(rpc->client->ReceivePool, pdu) == TRUE ? 0 : -1;
} }
int rpc_client_on_fragment_received_event(rdpRpc* rpc) int rpc_client_on_fragment_received_event(rdpRpc *rpc)
{ {
BYTE* buffer; BYTE *buffer;
UINT32 StubOffset; UINT32 StubOffset;
UINT32 StubLength; UINT32 StubLength;
wStream* fragment; wStream *fragment;
rpcconn_hdr_t* header; rpcconn_hdr_t *header;
freerdp* instance; freerdp *instance;
instance = (freerdp *)rpc->transport->settings->instance; instance = (freerdp *)rpc->transport->settings->instance;
if (!rpc->client->pdu) if (!rpc->client->pdu)
rpc->client->pdu = rpc_client_receive_pool_take(rpc); rpc->client->pdu = rpc_client_receive_pool_take(rpc);
fragment = Queue_Dequeue(rpc->client->FragmentQueue); fragment = Queue_Dequeue(rpc->client->FragmentQueue);
buffer = (BYTE *) Stream_Buffer(fragment);
buffer = (BYTE*) Stream_Buffer(fragment); header = (rpcconn_hdr_t *) Stream_Buffer(fragment);
header = (rpcconn_hdr_t*) Stream_Buffer(fragment);
if (rpc->State < RPC_CLIENT_STATE_CONTEXT_NEGOTIATED) if (rpc->State < RPC_CLIENT_STATE_CONTEXT_NEGOTIATED)
{ {
rpc->client->pdu->Flags = 0; rpc->client->pdu->Flags = 0;
rpc->client->pdu->CallId = header->common.call_id; rpc->client->pdu->CallId = header->common.call_id;
Stream_EnsureCapacity(rpc->client->pdu->s, Stream_Length(fragment)); Stream_EnsureCapacity(rpc->client->pdu->s, Stream_Length(fragment));
Stream_Write(rpc->client->pdu->s, buffer, Stream_Length(fragment)); Stream_Write(rpc->client->pdu->s, buffer, Stream_Length(fragment));
Stream_Length(rpc->client->pdu->s) = Stream_GetPosition(rpc->client->pdu->s); Stream_Length(rpc->client->pdu->s) = Stream_GetPosition(rpc->client->pdu->s);
rpc_client_fragment_pool_return(rpc, fragment); rpc_client_fragment_pool_return(rpc, fragment);
Queue_Enqueue(rpc->client->ReceiveQueue, rpc->client->pdu); Queue_Enqueue(rpc->client->ReceiveQueue, rpc->client->pdu);
SetEvent(rpc->transport->ReceiveEvent); SetEvent(rpc->transport->ReceiveEvent);
rpc->client->pdu = NULL; rpc->client->pdu = NULL;
return 0; return 0;
} }
switch (header->common.ptype) switch (header->common.ptype)
{ {
case PTYPE_RTS: case PTYPE_RTS:
if (rpc->VirtualConnection->State < VIRTUAL_CONNECTION_STATE_OPENED) if (rpc->VirtualConnection->State < VIRTUAL_CONNECTION_STATE_OPENED)
{ {
DEBUG_WARN( "%s: warning: unhandled RTS PDU\n", __FUNCTION__); DEBUG_WARN("%s: warning: unhandled RTS PDU\n", __FUNCTION__);
return 0; return 0;
} }
DEBUG_WARN( "%s: Receiving Out-of-Sequence RTS PDU\n", __FUNCTION__);
DEBUG_WARN("%s: Receiving Out-of-Sequence RTS PDU\n", __FUNCTION__);
rts_recv_out_of_sequence_pdu(rpc, buffer, header->common.frag_length); rts_recv_out_of_sequence_pdu(rpc, buffer, header->common.frag_length);
rpc_client_fragment_pool_return(rpc, fragment); rpc_client_fragment_pool_return(rpc, fragment);
return 0; return 0;
case PTYPE_FAULT: case PTYPE_FAULT:
rpc_recv_fault_pdu(header); rpc_recv_fault_pdu(header);
Queue_Enqueue(rpc->client->ReceiveQueue, NULL); Queue_Enqueue(rpc->client->ReceiveQueue, NULL);
@ -149,7 +146,7 @@ int rpc_client_on_fragment_received_event(rdpRpc* rpc)
case PTYPE_RESPONSE: case PTYPE_RESPONSE:
break; break;
default: default:
DEBUG_WARN( "%s: unexpected RPC PDU type %d\n", __FUNCTION__, header->common.ptype); DEBUG_WARN("%s: unexpected RPC PDU type %d\n", __FUNCTION__, header->common.ptype);
Queue_Enqueue(rpc->client->ReceiveQueue, NULL); Queue_Enqueue(rpc->client->ReceiveQueue, NULL);
return -1; return -1;
} }
@ -159,7 +156,7 @@ int rpc_client_on_fragment_received_event(rdpRpc* rpc)
if (!rpc_get_stub_data_info(rpc, buffer, &StubOffset, &StubLength)) if (!rpc_get_stub_data_info(rpc, buffer, &StubOffset, &StubLength))
{ {
DEBUG_WARN( "%s: expected stub\n", __FUNCTION__); DEBUG_WARN("%s: expected stub\n", __FUNCTION__);
Queue_Enqueue(rpc->client->ReceiveQueue, NULL); Queue_Enqueue(rpc->client->ReceiveQueue, NULL);
return -1; return -1;
} }
@ -173,10 +170,8 @@ int rpc_client_on_fragment_received_event(rdpRpc* rpc)
if ((header->common.call_id == rpc->PipeCallId) && (header->common.pfc_flags & PFC_LAST_FRAG)) if ((header->common.call_id == rpc->PipeCallId) && (header->common.pfc_flags & PFC_LAST_FRAG))
{ {
TerminateEventArgs e; TerminateEventArgs e;
instance->context->rdp->disconnect = TRUE; instance->context->rdp->disconnect = TRUE;
rpc->transport->tsg->state = TSG_STATE_TUNNEL_CLOSE_PENDING; rpc->transport->tsg->state = TSG_STATE_TUNNEL_CLOSE_PENDING;
EventArgsInit(&e, "freerdp"); EventArgsInit(&e, "freerdp");
e.code = 0; e.code = 0;
PubSub_OnTerminate(instance->context->pubSub, instance->context, &e); PubSub_OnTerminate(instance->context->pubSub, instance->context, &e);
@ -187,21 +182,20 @@ int rpc_client_on_fragment_received_event(rdpRpc* rpc)
} }
Stream_EnsureCapacity(rpc->client->pdu->s, header->response.alloc_hint); Stream_EnsureCapacity(rpc->client->pdu->s, header->response.alloc_hint);
buffer = (BYTE*) Stream_Buffer(fragment); buffer = (BYTE *) Stream_Buffer(fragment);
header = (rpcconn_hdr_t*) Stream_Buffer(fragment); header = (rpcconn_hdr_t *) Stream_Buffer(fragment);
if (rpc->StubFragCount == 0) if (rpc->StubFragCount == 0)
rpc->StubCallId = header->common.call_id; rpc->StubCallId = header->common.call_id;
if (rpc->StubCallId != header->common.call_id) if (rpc->StubCallId != header->common.call_id)
{ {
DEBUG_WARN( "%s: invalid call_id: actual: %d, expected: %d, frag_count: %d\n", __FUNCTION__, DEBUG_WARN("%s: invalid call_id: actual: %d, expected: %d, frag_count: %d\n", __FUNCTION__,
rpc->StubCallId, header->common.call_id, rpc->StubFragCount); rpc->StubCallId, header->common.call_id, rpc->StubFragCount);
} }
Stream_Write(rpc->client->pdu->s, &buffer[StubOffset], StubLength); Stream_Write(rpc->client->pdu->s, &buffer[StubOffset], StubLength);
rpc->StubFragCount++; rpc->StubFragCount++;
rpc_client_fragment_pool_return(rpc, fragment); rpc_client_fragment_pool_return(rpc, fragment);
if (rpc->VirtualConnection->DefaultOutChannel->ReceiverAvailableWindow < (rpc->ReceiveWindow / 2)) if (rpc->VirtualConnection->DefaultOutChannel->ReceiverAvailableWindow < (rpc->ReceiveWindow / 2))
@ -220,27 +214,22 @@ int rpc_client_on_fragment_received_event(rdpRpc* rpc)
{ {
rpc->client->pdu->Flags = RPC_PDU_FLAG_STUB; rpc->client->pdu->Flags = RPC_PDU_FLAG_STUB;
rpc->client->pdu->CallId = rpc->StubCallId; rpc->client->pdu->CallId = rpc->StubCallId;
Stream_Length(rpc->client->pdu->s) = Stream_GetPosition(rpc->client->pdu->s); Stream_Length(rpc->client->pdu->s) = Stream_GetPosition(rpc->client->pdu->s);
rpc->StubFragCount = 0; rpc->StubFragCount = 0;
rpc->StubCallId = 0; rpc->StubCallId = 0;
Queue_Enqueue(rpc->client->ReceiveQueue, rpc->client->pdu); Queue_Enqueue(rpc->client->ReceiveQueue, rpc->client->pdu);
rpc->client->pdu = NULL; rpc->client->pdu = NULL;
return 0; return 0;
} }
return 0; return 0;
} }
int rpc_client_on_read_event(rdpRpc* rpc) int rpc_client_on_read_event(rdpRpc *rpc)
{ {
int position; int position;
int status = -1; int status = -1;
rpcconn_common_hdr_t* header; rpcconn_common_hdr_t *header;
while (1) while (1)
{ {
@ -252,11 +241,11 @@ int rpc_client_on_read_event(rdpRpc* rpc)
while (Stream_GetPosition(rpc->client->RecvFrag) < RPC_COMMON_FIELDS_LENGTH) while (Stream_GetPosition(rpc->client->RecvFrag) < RPC_COMMON_FIELDS_LENGTH)
{ {
status = rpc_out_read(rpc, Stream_Pointer(rpc->client->RecvFrag), status = rpc_out_read(rpc, Stream_Pointer(rpc->client->RecvFrag),
RPC_COMMON_FIELDS_LENGTH - Stream_GetPosition(rpc->client->RecvFrag)); RPC_COMMON_FIELDS_LENGTH - Stream_GetPosition(rpc->client->RecvFrag));
if (status < 0) if (status < 0)
{ {
DEBUG_WARN( "rpc_client_frag_read: error reading header\n"); DEBUG_WARN("rpc_client_frag_read: error reading header\n");
return -1; return -1;
} }
@ -269,25 +258,24 @@ int rpc_client_on_read_event(rdpRpc* rpc)
if (Stream_GetPosition(rpc->client->RecvFrag) < RPC_COMMON_FIELDS_LENGTH) if (Stream_GetPosition(rpc->client->RecvFrag) < RPC_COMMON_FIELDS_LENGTH)
return status; return status;
header = (rpcconn_common_hdr_t *) Stream_Buffer(rpc->client->RecvFrag);
header = (rpcconn_common_hdr_t*) Stream_Buffer(rpc->client->RecvFrag);
if (header->frag_length > rpc->max_recv_frag) if (header->frag_length > rpc->max_recv_frag)
{ {
DEBUG_WARN( "rpc_client_frag_read: invalid fragment size: %d (max: %d)\n", DEBUG_WARN("rpc_client_frag_read: invalid fragment size: %d (max: %d)\n",
header->frag_length, rpc->max_recv_frag); header->frag_length, rpc->max_recv_frag);
winpr_HexDump(Stream_Buffer(rpc->client->RecvFrag), Stream_GetPosition(rpc->client->RecvFrag)); winpr_HexDump(TAG, WLOG_ERROR, Stream_Buffer(rpc->client->RecvFrag), Stream_GetPosition(rpc->client->RecvFrag));
return -1; return -1;
} }
while (Stream_GetPosition(rpc->client->RecvFrag) < header->frag_length) while (Stream_GetPosition(rpc->client->RecvFrag) < header->frag_length)
{ {
status = rpc_out_read(rpc, Stream_Pointer(rpc->client->RecvFrag), status = rpc_out_read(rpc, Stream_Pointer(rpc->client->RecvFrag),
header->frag_length - Stream_GetPosition(rpc->client->RecvFrag)); header->frag_length - Stream_GetPosition(rpc->client->RecvFrag));
if (status < 0) if (status < 0)
{ {
DEBUG_WARN( "%s: error reading fragment body\n", __FUNCTION__); DEBUG_WARN("%s: error reading fragment body\n", __FUNCTION__);
return -1; return -1;
} }
@ -305,10 +293,8 @@ int rpc_client_on_read_event(rdpRpc* rpc)
if (Stream_GetPosition(rpc->client->RecvFrag) >= header->frag_length) if (Stream_GetPosition(rpc->client->RecvFrag) >= header->frag_length)
{ {
/* complete fragment received */ /* complete fragment received */
Stream_Length(rpc->client->RecvFrag) = Stream_GetPosition(rpc->client->RecvFrag); Stream_Length(rpc->client->RecvFrag) = Stream_GetPosition(rpc->client->RecvFrag);
Stream_SetPosition(rpc->client->RecvFrag, 0); Stream_SetPosition(rpc->client->RecvFrag, 0);
Queue_Enqueue(rpc->client->FragmentQueue, rpc->client->RecvFrag); Queue_Enqueue(rpc->client->FragmentQueue, rpc->client->RecvFrag);
rpc->client->RecvFrag = NULL; rpc->client->RecvFrag = NULL;
@ -325,60 +311,57 @@ int rpc_client_on_read_event(rdpRpc* rpc)
* http://msdn.microsoft.com/en-us/library/gg593159/ * http://msdn.microsoft.com/en-us/library/gg593159/
*/ */
RpcClientCall* rpc_client_call_find_by_id(rdpRpc* rpc, UINT32 CallId) RpcClientCall *rpc_client_call_find_by_id(rdpRpc *rpc, UINT32 CallId)
{ {
int index; int index;
int count; int count;
RpcClientCall* clientCall; RpcClientCall *clientCall;
ArrayList_Lock(rpc->client->ClientCallList); ArrayList_Lock(rpc->client->ClientCallList);
clientCall = NULL; clientCall = NULL;
count = ArrayList_Count(rpc->client->ClientCallList); count = ArrayList_Count(rpc->client->ClientCallList);
for (index = 0; index < count; index++) for (index = 0; index < count; index++)
{ {
clientCall = (RpcClientCall*) ArrayList_GetItem(rpc->client->ClientCallList, index); clientCall = (RpcClientCall *) ArrayList_GetItem(rpc->client->ClientCallList, index);
if (clientCall->CallId == CallId) if (clientCall->CallId == CallId)
break; break;
} }
ArrayList_Unlock(rpc->client->ClientCallList); ArrayList_Unlock(rpc->client->ClientCallList);
return clientCall; return clientCall;
} }
RpcClientCall* rpc_client_call_new(UINT32 CallId, UINT32 OpNum) RpcClientCall *rpc_client_call_new(UINT32 CallId, UINT32 OpNum)
{ {
RpcClientCall* clientCall; RpcClientCall *clientCall;
clientCall = (RpcClientCall *) malloc(sizeof(RpcClientCall));
clientCall = (RpcClientCall*) malloc(sizeof(RpcClientCall));
if (!clientCall) if (!clientCall)
return NULL; return NULL;
clientCall->CallId = CallId; clientCall->CallId = CallId;
clientCall->OpNum = OpNum; clientCall->OpNum = OpNum;
clientCall->State = RPC_CLIENT_CALL_STATE_SEND_PDUS; clientCall->State = RPC_CLIENT_CALL_STATE_SEND_PDUS;
return clientCall; return clientCall;
} }
void rpc_client_call_free(RpcClientCall* clientCall) void rpc_client_call_free(RpcClientCall *clientCall)
{ {
free(clientCall); free(clientCall);
} }
int rpc_send_enqueue_pdu(rdpRpc* rpc, BYTE* buffer, UINT32 length) int rpc_send_enqueue_pdu(rdpRpc *rpc, BYTE *buffer, UINT32 length)
{ {
RPC_PDU* pdu; RPC_PDU *pdu;
int status; int status;
pdu = (RPC_PDU *) malloc(sizeof(RPC_PDU));
pdu = (RPC_PDU*) malloc(sizeof(RPC_PDU));
if (!pdu) if (!pdu)
return -1; return -1;
pdu->s = Stream_New(buffer, length); pdu->s = Stream_New(buffer, length);
if (!pdu->s) if (!pdu->s)
goto out_free; goto out_free;
@ -388,9 +371,10 @@ int rpc_send_enqueue_pdu(rdpRpc* rpc, BYTE* buffer, UINT32 length)
if (rpc->client->SynchronousSend) if (rpc->client->SynchronousSend)
{ {
status = WaitForSingleObject(rpc->client->PduSentEvent, SYNCHRONOUS_TIMEOUT); status = WaitForSingleObject(rpc->client->PduSentEvent, SYNCHRONOUS_TIMEOUT);
if (status == WAIT_TIMEOUT) if (status == WAIT_TIMEOUT)
{ {
DEBUG_WARN( "%s: timed out waiting for pdu sent event %p\n", __FUNCTION__, rpc->client->PduSentEvent); DEBUG_WARN("%s: timed out waiting for pdu sent event %p\n", __FUNCTION__, rpc->client->PduSentEvent);
return -1; return -1;
} }
@ -398,7 +382,6 @@ int rpc_send_enqueue_pdu(rdpRpc* rpc, BYTE* buffer, UINT32 length)
} }
return 0; return 0;
out_free_stream: out_free_stream:
Stream_Free(pdu->s, TRUE); Stream_Free(pdu->s, TRUE);
out_free: out_free:
@ -406,27 +389,24 @@ out_free:
return -1; return -1;
} }
int rpc_send_dequeue_pdu(rdpRpc* rpc) int rpc_send_dequeue_pdu(rdpRpc *rpc)
{ {
int status; int status;
RPC_PDU* pdu; RPC_PDU *pdu;
RpcClientCall* clientCall; RpcClientCall *clientCall;
rpcconn_common_hdr_t* header; rpcconn_common_hdr_t *header;
RpcInChannel *inChannel; RpcInChannel *inChannel;
pdu = (RPC_PDU *) Queue_Dequeue(rpc->client->SendQueue);
pdu = (RPC_PDU*) Queue_Dequeue(rpc->client->SendQueue);
if (!pdu) if (!pdu)
return 0; return 0;
inChannel = rpc->VirtualConnection->DefaultInChannel; inChannel = rpc->VirtualConnection->DefaultInChannel;
WaitForSingleObject(inChannel->Mutex, INFINITE); WaitForSingleObject(inChannel->Mutex, INFINITE);
status = rpc_in_write(rpc, Stream_Buffer(pdu->s), Stream_Length(pdu->s)); status = rpc_in_write(rpc, Stream_Buffer(pdu->s), Stream_Length(pdu->s));
header = (rpcconn_common_hdr_t *) Stream_Buffer(pdu->s);
header = (rpcconn_common_hdr_t*) Stream_Buffer(pdu->s);
clientCall = rpc_client_call_find_by_id(rpc, header->call_id); clientCall = rpc_client_call_find_by_id(rpc, header->call_id);
clientCall->State = RPC_CLIENT_CALL_STATE_DISPATCHED; clientCall->State = RPC_CLIENT_CALL_STATE_DISPATCHED;
ReleaseMutex(inChannel->Mutex); ReleaseMutex(inChannel->Mutex);
/* /*
@ -451,18 +431,17 @@ int rpc_send_dequeue_pdu(rdpRpc* rpc)
return status; return status;
} }
RPC_PDU* rpc_recv_dequeue_pdu(rdpRpc* rpc) RPC_PDU *rpc_recv_dequeue_pdu(rdpRpc *rpc)
{ {
RPC_PDU* pdu; RPC_PDU *pdu;
DWORD dwMilliseconds; DWORD dwMilliseconds;
DWORD result; DWORD result;
dwMilliseconds = rpc->client->SynchronousReceive ? SYNCHRONOUS_TIMEOUT * 4 : 0; dwMilliseconds = rpc->client->SynchronousReceive ? SYNCHRONOUS_TIMEOUT * 4 : 0;
result = WaitForSingleObject(Queue_Event(rpc->client->ReceiveQueue), dwMilliseconds); result = WaitForSingleObject(Queue_Event(rpc->client->ReceiveQueue), dwMilliseconds);
if (result == WAIT_TIMEOUT) if (result == WAIT_TIMEOUT)
{ {
DEBUG_WARN( "%s: timed out waiting for receive event\n", __FUNCTION__); DEBUG_WARN("%s: timed out waiting for receive event\n", __FUNCTION__);
return NULL; return NULL;
} }
@ -470,51 +449,47 @@ RPC_PDU* rpc_recv_dequeue_pdu(rdpRpc* rpc)
return NULL; return NULL;
pdu = (RPC_PDU *)Queue_Dequeue(rpc->client->ReceiveQueue); pdu = (RPC_PDU *)Queue_Dequeue(rpc->client->ReceiveQueue);
#ifdef WITH_DEBUG_TSG #ifdef WITH_DEBUG_TSG
if (pdu) if (pdu)
{ {
DEBUG_WARN( "Receiving PDU (length: %d, CallId: %d)\n", pdu->s->length, pdu->CallId); DEBUG_WARN("Receiving PDU (length: %d, CallId: %d)\n", pdu->s->length, pdu->CallId);
winpr_HexDump(Stream_Buffer(pdu->s), Stream_Length(pdu->s)); winpr_HexDump(Stream_Buffer(pdu->s), Stream_Length(pdu->s));
DEBUG_WARN( "\n"); DEBUG_WARN("\n");
} }
else else
{ {
DEBUG_WARN( "Receiving a NULL PDU\n"); DEBUG_WARN("Receiving a NULL PDU\n");
} }
#endif
#endif
return pdu; return pdu;
} }
RPC_PDU* rpc_recv_peek_pdu(rdpRpc* rpc) RPC_PDU *rpc_recv_peek_pdu(rdpRpc *rpc)
{ {
DWORD dwMilliseconds; DWORD dwMilliseconds;
DWORD result; DWORD result;
dwMilliseconds = rpc->client->SynchronousReceive ? SYNCHRONOUS_TIMEOUT : 0; dwMilliseconds = rpc->client->SynchronousReceive ? SYNCHRONOUS_TIMEOUT : 0;
result = WaitForSingleObject(Queue_Event(rpc->client->ReceiveQueue), dwMilliseconds); result = WaitForSingleObject(Queue_Event(rpc->client->ReceiveQueue), dwMilliseconds);
if (result != WAIT_OBJECT_0) if (result != WAIT_OBJECT_0)
return NULL; return NULL;
return (RPC_PDU *)Queue_Peek(rpc->client->ReceiveQueue); return (RPC_PDU *)Queue_Peek(rpc->client->ReceiveQueue);
} }
static void* rpc_client_thread(void* arg) static void *rpc_client_thread(void *arg)
{ {
rdpRpc* rpc; rdpRpc *rpc;
DWORD status; DWORD status;
DWORD nCount; DWORD nCount;
HANDLE events[3]; HANDLE events[3];
HANDLE ReadEvent; HANDLE ReadEvent;
int fd; int fd;
rpc = (rdpRpc *) arg;
rpc = (rdpRpc*) arg;
fd = BIO_get_fd(rpc->TlsOut->bio, NULL); fd = BIO_get_fd(rpc->TlsOut->bio, NULL);
ReadEvent = CreateFileDescriptorEvent(NULL, TRUE, FALSE, fd); ReadEvent = CreateFileDescriptorEvent(NULL, TRUE, FALSE, fd);
nCount = 0; nCount = 0;
events[nCount++] = rpc->client->StopEvent; events[nCount++] = rpc->client->StopEvent;
events[nCount++] = Queue_Event(rpc->client->SendQueue); events[nCount++] = Queue_Event(rpc->client->SendQueue);
@ -527,7 +502,7 @@ static void* rpc_client_thread(void* arg)
*/ */
if (rpc_client_on_read_event(rpc) < 0) if (rpc_client_on_read_event(rpc) < 0)
{ {
DEBUG_WARN( "%s: an error occured when treating first packet\n", __FUNCTION__); DEBUG_WARN("%s: an error occured when treating first packet\n", __FUNCTION__);
goto out; goto out;
} }
@ -555,11 +530,10 @@ static void* rpc_client_thread(void* arg)
out: out:
CloseHandle(ReadEvent); CloseHandle(ReadEvent);
return NULL; return NULL;
} }
static void rpc_pdu_free(RPC_PDU* pdu) static void rpc_pdu_free(RPC_PDU *pdu)
{ {
if (!pdu) if (!pdu)
return; return;
@ -568,77 +542,87 @@ static void rpc_pdu_free(RPC_PDU* pdu)
free(pdu); free(pdu);
} }
static void rpc_fragment_free(wStream* fragment) static void rpc_fragment_free(wStream *fragment)
{ {
Stream_Free(fragment, TRUE); Stream_Free(fragment, TRUE);
} }
int rpc_client_new(rdpRpc* rpc) int rpc_client_new(rdpRpc *rpc)
{ {
RpcClient* client = NULL; RpcClient *client = NULL;
client = (RpcClient *)calloc(1, sizeof(RpcClient)); client = (RpcClient *)calloc(1, sizeof(RpcClient));
rpc->client = client; rpc->client = client;
if (!client) if (!client)
return -1; return -1;
client->Thread = CreateThread(NULL, 0, client->Thread = CreateThread(NULL, 0,
(LPTHREAD_START_ROUTINE) rpc_client_thread, (LPTHREAD_START_ROUTINE) rpc_client_thread,
rpc, CREATE_SUSPENDED, NULL); rpc, CREATE_SUSPENDED, NULL);
if (!client->Thread) if (!client->Thread)
return -1; return -1;
client->StopEvent = CreateEvent(NULL, TRUE, FALSE, NULL); client->StopEvent = CreateEvent(NULL, TRUE, FALSE, NULL);
if (!client->StopEvent) if (!client->StopEvent)
return -1; return -1;
client->PduSentEvent = CreateEvent(NULL, TRUE, FALSE, NULL); client->PduSentEvent = CreateEvent(NULL, TRUE, FALSE, NULL);
if (!client->PduSentEvent) if (!client->PduSentEvent)
return -1; return -1;
client->SendQueue = Queue_New(TRUE, -1, -1); client->SendQueue = Queue_New(TRUE, -1, -1);
if (!client->SendQueue) if (!client->SendQueue)
return -1; return -1;
Queue_Object(client->SendQueue)->fnObjectFree = (OBJECT_FREE_FN) rpc_pdu_free;
Queue_Object(client->SendQueue)->fnObjectFree = (OBJECT_FREE_FN) rpc_pdu_free;
client->pdu = NULL; client->pdu = NULL;
client->ReceivePool = Queue_New(TRUE, -1, -1); client->ReceivePool = Queue_New(TRUE, -1, -1);
if (!client->ReceivePool) if (!client->ReceivePool)
return -1; return -1;
Queue_Object(client->ReceivePool)->fnObjectFree = (OBJECT_FREE_FN) rpc_pdu_free;
Queue_Object(client->ReceivePool)->fnObjectFree = (OBJECT_FREE_FN) rpc_pdu_free;
client->ReceiveQueue = Queue_New(TRUE, -1, -1); client->ReceiveQueue = Queue_New(TRUE, -1, -1);
if (!client->ReceiveQueue) if (!client->ReceiveQueue)
return -1; return -1;
Queue_Object(client->ReceiveQueue)->fnObjectFree = (OBJECT_FREE_FN) rpc_pdu_free;
Queue_Object(client->ReceiveQueue)->fnObjectFree = (OBJECT_FREE_FN) rpc_pdu_free;
client->RecvFrag = NULL; client->RecvFrag = NULL;
client->FragmentPool = Queue_New(TRUE, -1, -1); client->FragmentPool = Queue_New(TRUE, -1, -1);
if (!client->FragmentPool) if (!client->FragmentPool)
return -1; return -1;
Queue_Object(client->FragmentPool)->fnObjectFree = (OBJECT_FREE_FN) rpc_fragment_free;
Queue_Object(client->FragmentPool)->fnObjectFree = (OBJECT_FREE_FN) rpc_fragment_free;
client->FragmentQueue = Queue_New(TRUE, -1, -1); client->FragmentQueue = Queue_New(TRUE, -1, -1);
if (!client->FragmentQueue) if (!client->FragmentQueue)
return -1; return -1;
Queue_Object(client->FragmentQueue)->fnObjectFree = (OBJECT_FREE_FN) rpc_fragment_free;
Queue_Object(client->FragmentQueue)->fnObjectFree = (OBJECT_FREE_FN) rpc_fragment_free;
client->ClientCallList = ArrayList_New(TRUE); client->ClientCallList = ArrayList_New(TRUE);
if (!client->ClientCallList) if (!client->ClientCallList)
return -1; return -1;
ArrayList_Object(client->ClientCallList)->fnObjectFree = (OBJECT_FREE_FN) rpc_client_call_free; ArrayList_Object(client->ClientCallList)->fnObjectFree = (OBJECT_FREE_FN) rpc_client_call_free;
return 0; return 0;
} }
int rpc_client_start(rdpRpc* rpc) int rpc_client_start(rdpRpc *rpc)
{ {
rpc->client->Thread = CreateThread(NULL, 0, rpc->client->Thread = CreateThread(NULL, 0,
(LPTHREAD_START_ROUTINE) rpc_client_thread, (LPTHREAD_START_ROUTINE) rpc_client_thread,
rpc, 0, NULL); rpc, 0, NULL);
return 0; return 0;
} }
int rpc_client_stop(rdpRpc* rpc) int rpc_client_stop(rdpRpc *rpc)
{ {
if (rpc->client->Thread) if (rpc->client->Thread)
{ {
@ -650,10 +634,9 @@ int rpc_client_stop(rdpRpc* rpc)
return rpc_client_free(rpc); return rpc_client_free(rpc);
} }
int rpc_client_free(rdpRpc* rpc) int rpc_client_free(rdpRpc *rpc)
{ {
RpcClient* client; RpcClient *client;
client = rpc->client; client = rpc->client;
if (!client) if (!client)
@ -667,6 +650,7 @@ int rpc_client_free(rdpRpc* rpc)
if (client->FragmentPool) if (client->FragmentPool)
Queue_Free(client->FragmentPool); Queue_Free(client->FragmentPool);
if (client->FragmentQueue) if (client->FragmentQueue)
Queue_Free(client->FragmentQueue); Queue_Free(client->FragmentQueue);
@ -675,6 +659,7 @@ int rpc_client_free(rdpRpc* rpc)
if (client->ReceivePool) if (client->ReceivePool)
Queue_Free(client->ReceivePool); Queue_Free(client->ReceivePool);
if (client->ReceiveQueue) if (client->ReceiveQueue)
Queue_Free(client->ReceiveQueue); Queue_Free(client->ReceiveQueue);
@ -683,6 +668,7 @@ int rpc_client_free(rdpRpc* rpc)
if (client->StopEvent) if (client->StopEvent)
CloseHandle(client->StopEvent); CloseHandle(client->StopEvent);
if (client->PduSentEvent) if (client->PduSentEvent)
CloseHandle(client->PduSentEvent); CloseHandle(client->PduSentEvent);

File diff suppressed because it is too large Load Diff