wfreerdp-server: add PeerLogon callback for server logon

This commit is contained in:
Marc-André Moreau 2012-09-16 15:30:11 -04:00
parent efe82e6ede
commit a11615aebd
11 changed files with 76 additions and 37 deletions

View File

@ -26,6 +26,8 @@
#include <freerdp/input.h> #include <freerdp/input.h>
#include <freerdp/update.h> #include <freerdp/update.h>
#include <winpr/sspi.h>
typedef void (*psPeerContextNew)(freerdp_peer* client, rdpContext* context); typedef void (*psPeerContextNew)(freerdp_peer* client, rdpContext* context);
typedef void (*psPeerContextFree)(freerdp_peer* client, rdpContext* context); typedef void (*psPeerContextFree)(freerdp_peer* client, rdpContext* context);
@ -37,6 +39,7 @@ typedef void (*psPeerDisconnect)(freerdp_peer* client);
typedef boolean (*psPeerCapabilities)(freerdp_peer* client); typedef boolean (*psPeerCapabilities)(freerdp_peer* client);
typedef boolean (*psPeerPostConnect)(freerdp_peer* client); typedef boolean (*psPeerPostConnect)(freerdp_peer* client);
typedef boolean (*psPeerActivate)(freerdp_peer* client); typedef boolean (*psPeerActivate)(freerdp_peer* client);
typedef boolean (*psPeerLogon)(freerdp_peer* client, SEC_WINNT_AUTH_IDENTITY* identity, boolean automatic);
typedef int (*psPeerSendChannelData)(freerdp_peer* client, int channelId, uint8* data, int size); typedef int (*psPeerSendChannelData)(freerdp_peer* client, int channelId, uint8* data, int size);
typedef int (*psPeerReceiveChannelData)(freerdp_peer* client, int channelId, uint8* data, int size, int flags, int total_size); typedef int (*psPeerReceiveChannelData)(freerdp_peer* client, int channelId, uint8* data, int size, int flags, int total_size);
@ -64,13 +67,17 @@ struct rdp_freerdp_peer
psPeerCapabilities Capabilities; psPeerCapabilities Capabilities;
psPeerPostConnect PostConnect; psPeerPostConnect PostConnect;
psPeerActivate Activate; psPeerActivate Activate;
psPeerLogon Logon;
psPeerSendChannelData SendChannelData; psPeerSendChannelData SendChannelData;
psPeerReceiveChannelData ReceiveChannelData; psPeerReceiveChannelData ReceiveChannelData;
uint32 ack_frame_id; uint32 ack_frame_id;
boolean local; boolean local;
boolean connected;
boolean activated; boolean activated;
boolean authenticated;
SEC_WINNT_AUTH_IDENTITY identity;
}; };
FREERDP_API void freerdp_peer_context_new(freerdp_peer* client); FREERDP_API void freerdp_peer_context_new(freerdp_peer* client);

View File

@ -102,25 +102,25 @@ static boolean peer_recv_data_pdu(freerdp_peer* client, STREAM* s)
if (!rdp_server_accept_client_font_list_pdu(client->context->rdp, s)) if (!rdp_server_accept_client_font_list_pdu(client->context->rdp, s))
return false; return false;
if (client->PostConnect) if (!client->connected)
{ {
if (!client->PostConnect(client))
return false;
/** /**
* PostConnect should only be called once and should not be called * PostConnect should only be called once and should not be called
* after a reactivation sequence. * after a reactivation sequence.
*/ */
client->PostConnect = NULL;
}
if (client->Activate) IFCALLRET(client->PostConnect, client->connected, client);
{
/* Activate will be called everytime after the client is activated/reactivated. */ if (!client->connected)
if (!client->Activate(client))
return false; return false;
} }
client->activated = true; /* Activate will be called everytime after the client is activated/reactivated. */
IFCALLRET(client->Activate, client->activated, client);
if (!client->activated)
return false;
break; break;
@ -244,55 +244,69 @@ static boolean peer_recv_pdu(freerdp_peer* client, STREAM* s)
static boolean peer_recv_callback(rdpTransport* transport, STREAM* s, void* extra) static boolean peer_recv_callback(rdpTransport* transport, STREAM* s, void* extra)
{ {
freerdp_peer* client = (freerdp_peer*) extra; freerdp_peer* client = (freerdp_peer*) extra;
rdpRdp* rdp = client->context->rdp;
switch (client->context->rdp->state) switch (rdp->state)
{ {
case CONNECTION_STATE_INITIAL: case CONNECTION_STATE_INITIAL:
if (!rdp_server_accept_nego(client->context->rdp, s)) if (!rdp_server_accept_nego(rdp, s))
return false; return false;
if (rdp->nego->selected_protocol & PROTOCOL_NLA)
{
sspi_CopyAuthIdentity(&client->identity, &(rdp->nego->transport->credssp->identity));
IFCALLRET(client->Logon, client->authenticated, client, &client->identity, true);
credssp_free(rdp->nego->transport->credssp);
}
else
{
IFCALLRET(client->Logon, client->authenticated, client, &client->identity, false);
}
break; break;
case CONNECTION_STATE_NEGO: case CONNECTION_STATE_NEGO:
if (!rdp_server_accept_mcs_connect_initial(client->context->rdp, s)) if (!rdp_server_accept_mcs_connect_initial(rdp, s))
return false; return false;
break; break;
case CONNECTION_STATE_MCS_CONNECT: case CONNECTION_STATE_MCS_CONNECT:
if (!rdp_server_accept_mcs_erect_domain_request(client->context->rdp, s)) if (!rdp_server_accept_mcs_erect_domain_request(rdp, s))
return false; return false;
break; break;
case CONNECTION_STATE_MCS_ERECT_DOMAIN: case CONNECTION_STATE_MCS_ERECT_DOMAIN:
if (!rdp_server_accept_mcs_attach_user_request(client->context->rdp, s)) if (!rdp_server_accept_mcs_attach_user_request(rdp, s))
return false; return false;
break; break;
case CONNECTION_STATE_MCS_ATTACH_USER: case CONNECTION_STATE_MCS_ATTACH_USER:
if (!rdp_server_accept_mcs_channel_join_request(client->context->rdp, s)) if (!rdp_server_accept_mcs_channel_join_request(rdp, s))
return false; return false;
break; break;
case CONNECTION_STATE_MCS_CHANNEL_JOIN: case CONNECTION_STATE_MCS_CHANNEL_JOIN:
if (client->context->rdp->settings->encryption) { if (rdp->settings->encryption)
if (!rdp_server_accept_client_keys(client->context->rdp, s)) {
if (!rdp_server_accept_client_keys(rdp, s))
return false; return false;
break; break;
} }
client->context->rdp->state = CONNECTION_STATE_ESTABLISH_KEYS; rdp->state = CONNECTION_STATE_ESTABLISH_KEYS;
/* FALLTHROUGH */ /* FALLTHROUGH */
case CONNECTION_STATE_ESTABLISH_KEYS: case CONNECTION_STATE_ESTABLISH_KEYS:
if (!rdp_server_accept_client_info(client->context->rdp, s)) if (!rdp_server_accept_client_info(rdp, s))
return false; return false;
IFCALL(client->Capabilities, client); IFCALL(client->Capabilities, client);
if (!rdp_send_demand_active(client->context->rdp)) if (!rdp_send_demand_active(rdp))
return false; return false;
break; break;
case CONNECTION_STATE_LICENSE: case CONNECTION_STATE_LICENSE:
if (!rdp_server_accept_confirm_active(client->context->rdp, s)) if (!rdp_server_accept_confirm_active(rdp, s))
{ {
/** /**
* During reactivation sequence the client might sent some input or channel data * During reactivation sequence the client might sent some input or channel data
@ -309,7 +323,7 @@ static boolean peer_recv_callback(rdpTransport* transport, STREAM* s, void* extr
break; break;
default: default:
printf("Invalid state %d\n", client->context->rdp->state); printf("Invalid state %d\n", rdp->state);
return false; return false;
} }

View File

@ -263,7 +263,7 @@ boolean transport_accept_nla(rdpTransport* transport)
return false; return false;
} }
credssp_free(transport->credssp); /* don't free credssp module yet, we need to copy the credentials from it first */
return true; return true;
} }

View File

@ -35,6 +35,7 @@ typedef struct rdp_transport rdpTransport;
#include <winpr/sspi.h> #include <winpr/sspi.h>
#include <freerdp/crypto/tls.h> #include <freerdp/crypto/tls.h>
#include <freerdp/crypto/nla.h>
#include <time.h> #include <time.h>
#include <freerdp/types.h> #include <freerdp/types.h>

View File

@ -547,7 +547,7 @@ int credssp_server_authenticate(rdpCredssp* credssp)
return -1; return -1;
} }
//sspi_SecBufferFree(&credssp->negoToken); sspi_SecBufferFree(&credssp->negoToken);
credssp->negoToken.pvBuffer = NULL; credssp->negoToken.pvBuffer = NULL;
credssp->negoToken.cbBuffer = 0; credssp->negoToken.cbBuffer = 0;
@ -568,7 +568,7 @@ int credssp_server_authenticate(rdpCredssp* credssp)
#endif #endif
credssp_send(credssp); credssp_send(credssp);
//credssp_buffer_free(credssp); credssp_buffer_free(credssp);
if (status != SEC_I_CONTINUE_NEEDED) if (status != SEC_I_CONTINUE_NEEDED)
break; break;
@ -581,7 +581,11 @@ int credssp_server_authenticate(rdpCredssp* credssp)
if (credssp_recv(credssp) < 0) if (credssp_recv(credssp) < 0)
return -1; return -1;
credssp_decrypt_ts_credentials(credssp); if (credssp_decrypt_ts_credentials(credssp) != SEC_E_OK)
{
printf("Could not decrypt TSCredentials status: 0x%08X\n", status);
return 0;
}
if (status != SEC_E_OK) if (status != SEC_E_OK)
{ {
@ -833,6 +837,8 @@ void credssp_read_ts_password_creds(rdpCredssp* credssp, STREAM* s)
CopyMemory(credssp->identity.Password, s->p, credssp->identity.PasswordLength); CopyMemory(credssp->identity.Password, s->p, credssp->identity.PasswordLength);
stream_seek(s, credssp->identity.PasswordLength); stream_seek(s, credssp->identity.PasswordLength);
credssp->identity.PasswordLength /= 2; credssp->identity.PasswordLength /= 2;
credssp->identity.Flags = SEC_WINNT_AUTH_IDENTITY_UNICODE;
} }
void credssp_write_ts_password_creds(rdpCredssp* credssp, STREAM* s) void credssp_write_ts_password_creds(rdpCredssp* credssp, STREAM* s)
@ -1300,6 +1306,6 @@ void credssp_free(rdpCredssp* credssp)
free(credssp->identity.User); free(credssp->identity.User);
free(credssp->identity.Domain); free(credssp->identity.Domain);
free(credssp->identity.Password); free(credssp->identity.Password);
//free(credssp); free(credssp);
} }
} }

View File

@ -57,11 +57,7 @@ void freerdp_thread_start(freerdp_thread* thread, void* func, void* arg)
#ifdef _WIN32 #ifdef _WIN32
{ {
# ifdef _MSC_VER
CloseHandle((HANDLE)_beginthreadex(NULL, 0, func, arg, 0, NULL));
#else
CloseHandle(CreateThread(NULL, 0, (LPTHREAD_START_ROUTINE) func, arg, 0, NULL)); CloseHandle(CreateThread(NULL, 0, (LPTHREAD_START_ROUTINE) func, arg, 0, NULL));
#endif
} }
#else #else
{ {

View File

@ -41,10 +41,14 @@ int main(int argc, char* argv[])
if (argc == 2) if (argc == 2)
server->port = (DWORD) atoi(argv[1]); server->port = (DWORD) atoi(argv[1]);
printf("Starting server\n");
wfreerdp_server_start(server); wfreerdp_server_start(server);
WaitForSingleObject(server->thread, INFINITE); WaitForSingleObject(server->thread, INFINITE);
printf("Stopping server\n");
wfreerdp_server_stop(server); wfreerdp_server_stop(server);
wfreerdp_server_free(server); wfreerdp_server_free(server);

View File

@ -95,9 +95,10 @@ BOOL wfreerdp_server_start(wfServer* server)
if (instance->Open(instance, NULL, (uint16) server->port)) if (instance->Open(instance, NULL, (uint16) server->port))
{ {
server->thread = CreateThread(NULL, 0, wf_server_main_loop, (void*) instance, 0, NULL); server->thread = CreateThread(NULL, 0, wf_server_main_loop, (void*) instance, 0, NULL);
return TRUE;
} }
return TRUE; return FALSE;
} }
BOOL wfreerdp_server_stop(wfServer* server) BOOL wfreerdp_server_stop(wfServer* server)

View File

@ -40,7 +40,6 @@ struct wf_info
int bitsPerPixel; int bitsPerPixel;
HDC driverDC; HDC driverDC;
int peerCount; int peerCount;
BOOL activated;
void* changeBuffer; void* changeBuffer;
int framesPerSecond; int framesPerSecond;
LPTSTR deviceKey; LPTSTR deviceKey;

View File

@ -93,9 +93,16 @@ boolean wf_peer_post_connect(freerdp_peer* client)
boolean wf_peer_activate(freerdp_peer* client) boolean wf_peer_activate(freerdp_peer* client)
{ {
wfPeerContext* context = (wfPeerContext*) client->context; return true;
}
context->activated = true; boolean wf_peer_logon(freerdp_peer* client, SEC_WINNT_AUTH_IDENTITY* identity, boolean automatic)
{
if (automatic)
{
_tprintf(_T("Logon: User:%s Domain:%s Password:%s\n"),
identity->User, identity->Domain, identity->Password);
}
return true; return true;
} }
@ -190,6 +197,7 @@ DWORD WINAPI wf_peer_main_loop(LPVOID lpParam)
client->PostConnect = wf_peer_post_connect; client->PostConnect = wf_peer_post_connect;
client->Activate = wf_peer_activate; client->Activate = wf_peer_activate;
client->Logon = wf_peer_logon;
client->input->SynchronizeEvent = wf_peer_synchronize_event; client->input->SynchronizeEvent = wf_peer_synchronize_event;
client->input->KeyboardEvent = wf_peer_keyboard_event; client->input->KeyboardEvent = wf_peer_keyboard_event;

View File

@ -348,6 +348,7 @@ void sspi_CopyAuthIdentity(SEC_WINNT_AUTH_IDENTITY* identity, SEC_WINNT_AUTH_IDE
{ {
identity->User = (UINT16*) malloc((identity->UserLength + 1) * sizeof(WCHAR)); identity->User = (UINT16*) malloc((identity->UserLength + 1) * sizeof(WCHAR));
CopyMemory(identity->User, srcIdentity->User, identity->UserLength * sizeof(WCHAR)); CopyMemory(identity->User, srcIdentity->User, identity->UserLength * sizeof(WCHAR));
identity->User[identity->UserLength] = 0;
} }
identity->DomainLength = srcIdentity->DomainLength; identity->DomainLength = srcIdentity->DomainLength;
@ -356,6 +357,7 @@ void sspi_CopyAuthIdentity(SEC_WINNT_AUTH_IDENTITY* identity, SEC_WINNT_AUTH_IDE
{ {
identity->Domain = (UINT16*) malloc((identity->DomainLength + 1) * sizeof(WCHAR)); identity->Domain = (UINT16*) malloc((identity->DomainLength + 1) * sizeof(WCHAR));
CopyMemory(identity->Domain, srcIdentity->Domain, identity->DomainLength * sizeof(WCHAR)); CopyMemory(identity->Domain, srcIdentity->Domain, identity->DomainLength * sizeof(WCHAR));
identity->Domain[identity->DomainLength] = 0;
} }
identity->PasswordLength = srcIdentity->PasswordLength; identity->PasswordLength = srcIdentity->PasswordLength;
@ -364,6 +366,7 @@ void sspi_CopyAuthIdentity(SEC_WINNT_AUTH_IDENTITY* identity, SEC_WINNT_AUTH_IDE
{ {
identity->Password = (UINT16*) malloc((identity->PasswordLength + 1) * sizeof(WCHAR)); identity->Password = (UINT16*) malloc((identity->PasswordLength + 1) * sizeof(WCHAR));
CopyMemory(identity->Password, srcIdentity->Password, identity->PasswordLength * sizeof(WCHAR)); CopyMemory(identity->Password, srcIdentity->Password, identity->PasswordLength * sizeof(WCHAR));
identity->Password[identity->PasswordLength] = 0;
} }
} }