diff --git a/channels/sshagent/client/sshagent_main.c b/channels/sshagent/client/sshagent_main.c index 1c95210c7..c0f915be9 100644 --- a/channels/sshagent/client/sshagent_main.c +++ b/channels/sshagent/client/sshagent_main.c @@ -67,7 +67,7 @@ struct _SSHAGENT_LISTENER_CALLBACK IWTSVirtualChannelManager* channel_mgr; rdpContext* rdpcontext; - const char *agent_uds_path; + const char* agent_uds_path; }; typedef struct _SSHAGENT_CHANNEL_CALLBACK SSHAGENT_CHANNEL_CALLBACK; @@ -80,7 +80,7 @@ struct _SSHAGENT_CHANNEL_CALLBACK IWTSVirtualChannel* channel; rdpContext* rdpcontext; - int agent_fd; + int agent_fd; HANDLE thread; CRITICAL_SECTION lock; }; @@ -101,29 +101,35 @@ struct _SSHAGENT_PLUGIN * * @return The fd on success, otherwise -1 */ -static int connect_to_sshagent(const char *udspath) +static int connect_to_sshagent(const char* udspath) { - int agent_fd = socket(AF_UNIX, SOCK_STREAM, 0); - if (agent_fd == -1) - { - WLog_ERR(TAG, "Can't open Unix domain socket!"); - return -1; - } + int agent_fd = socket(AF_UNIX, SOCK_STREAM, 0); - struct sockaddr_un addr; - memset(&addr, 0, sizeof(addr)); - addr.sun_family = AF_UNIX; - strncpy(addr.sun_path, udspath, sizeof(addr.sun_path) - 1); - int rc = connect(agent_fd, (struct sockaddr*)&addr, sizeof(addr)); - if (rc != 0) - { - WLog_ERR(TAG, "Can't connect to Unix domain socket \"%s\"!", - udspath); - close(agent_fd); - return -1; - } + if (agent_fd == -1) + { + WLog_ERR(TAG, "Can't open Unix domain socket!"); + return -1; + } - return agent_fd; + struct sockaddr_un addr; + + memset(&addr, 0, sizeof(addr)); + + addr.sun_family = AF_UNIX; + + strncpy(addr.sun_path, udspath, sizeof(addr.sun_path) - 1); + + int rc = connect(agent_fd, (struct sockaddr*)&addr, sizeof(addr)); + + if (rc != 0) + { + WLog_ERR(TAG, "Can't connect to Unix domain socket \"%s\"!", + udspath); + close(agent_fd); + return -1; + } + + return agent_fd; } @@ -133,57 +139,58 @@ static int connect_to_sshagent(const char *udspath) * * @return NULL */ -static void *sshagent_read_thread(void *data) +static void* sshagent_read_thread(void* data) { - SSHAGENT_CHANNEL_CALLBACK *callback = (SSHAGENT_CHANNEL_CALLBACK *)data; - BYTE buffer[4096]; - int going = 1; - UINT status = CHANNEL_RC_OK; + SSHAGENT_CHANNEL_CALLBACK* callback = (SSHAGENT_CHANNEL_CALLBACK*)data; + BYTE buffer[4096]; + int going = 1; + UINT status = CHANNEL_RC_OK; - while (going) - { - int bytes_read = read(callback->agent_fd, - buffer, - sizeof(buffer)); + while (going) + { + int bytes_read = read(callback->agent_fd, + buffer, + sizeof(buffer)); - if (bytes_read == 0) - { - /* Socket closed cleanly at other end */ - going = 0; - } - else if (bytes_read < 0) - { - if (errno != EINTR) - { - WLog_ERR(TAG, - "Error reading from sshagent, errno=%d", - errno); - status = ERROR_READ_FAULT; - going = 0; - } - } - else - { - /* Something read: forward to virtual channel */ - status = callback->channel->Write(callback->channel, - bytes_read, - buffer, - NULL); - if (status != CHANNEL_RC_OK) - { - going = 0; - } - } - } + if (bytes_read == 0) + { + /* Socket closed cleanly at other end */ + going = 0; + } + else if (bytes_read < 0) + { + if (errno != EINTR) + { + WLog_ERR(TAG, + "Error reading from sshagent, errno=%d", + errno); + status = ERROR_READ_FAULT; + going = 0; + } + } + else + { + /* Something read: forward to virtual channel */ + status = callback->channel->Write(callback->channel, + bytes_read, + buffer, + NULL); - close(callback->agent_fd); + if (status != CHANNEL_RC_OK) + { + going = 0; + } + } + } - if (status != CHANNEL_RC_OK) - setChannelError(callback->rdpcontext, status, - "sshagent_read_thread reported an error"); + close(callback->agent_fd); + + if (status != CHANNEL_RC_OK) + setChannelError(callback->rdpcontext, status, + "sshagent_read_thread reported an error"); ExitThread(0); - return NULL; + return NULL; } /** @@ -191,41 +198,41 @@ static void *sshagent_read_thread(void *data) * * @return 0 on success, otherwise a Win32 error code */ -static UINT sshagent_on_data_received(IWTSVirtualChannelCallback* pChannelCallback, wStream *data) +static UINT sshagent_on_data_received(IWTSVirtualChannelCallback* pChannelCallback, wStream* data) { SSHAGENT_CHANNEL_CALLBACK* callback = (SSHAGENT_CHANNEL_CALLBACK*) pChannelCallback; BYTE* pBuffer = Stream_Pointer(data); UINT32 cbSize = Stream_GetRemainingLength(data); - BYTE *pos = pBuffer; + BYTE* pos = pBuffer; + /* Forward what we have received to the ssh agent */ + UINT32 bytes_to_write = cbSize; + errno = 0; - /* Forward what we have received to the ssh agent */ - UINT32 bytes_to_write = cbSize; - errno = 0; - while (bytes_to_write > 0) - { - int bytes_written = write(callback->agent_fd, pos, - bytes_to_write); - if (bytes_written < 0) - { - if (errno != EINTR) - { - WLog_ERR(TAG, - "Error writing to sshagent, errno=%d", - errno); - return ERROR_WRITE_FAULT; - } - } - else - { - bytes_to_write -= bytes_written; - pos += bytes_written; - } - } + while (bytes_to_write > 0) + { + int bytes_written = write(callback->agent_fd, pos, + bytes_to_write); - /* Consume stream */ - Stream_Seek(data, cbSize); + if (bytes_written < 0) + { + if (errno != EINTR) + { + WLog_ERR(TAG, + "Error writing to sshagent, errno=%d", + errno); + return ERROR_WRITE_FAULT; + } + } + else + { + bytes_to_write -= bytes_written; + pos += bytes_written; + } + } - return CHANNEL_RC_OK; + /* Consume stream */ + Stream_Seek(data, cbSize); + return CHANNEL_RC_OK; } /** @@ -236,11 +243,10 @@ static UINT sshagent_on_data_received(IWTSVirtualChannelCallback* pChannelCallba static UINT sshagent_on_close(IWTSVirtualChannelCallback* pChannelCallback) { SSHAGENT_CHANNEL_CALLBACK* callback = (SSHAGENT_CHANNEL_CALLBACK*) pChannelCallback; - - /* Call shutdown() to wake up the read() in sshagent_read_thread(). */ - shutdown(callback->agent_fd, SHUT_RDWR); - + /* Call shutdown() to wake up the read() in sshagent_read_thread(). */ + shutdown(callback->agent_fd, SHUT_RDWR); EnterCriticalSection(&callback->lock); + if (WaitForSingleObject(callback->thread, INFINITE) == WAIT_FAILED) { UINT error = GetLastError(); @@ -248,11 +254,10 @@ static UINT sshagent_on_close(IWTSVirtualChannelCallback* pChannelCallback) return error; } - CloseHandle(callback->thread); + CloseHandle(callback->thread); + LeaveCriticalSection(&callback->lock); DeleteCriticalSection(&callback->lock); - free(callback); - return CHANNEL_RC_OK; } @@ -263,12 +268,11 @@ static UINT sshagent_on_close(IWTSVirtualChannelCallback* pChannelCallback) * @return 0 on success, otherwise a Win32 error code */ static UINT sshagent_on_new_channel_connection(IWTSListenerCallback* pListenerCallback, - IWTSVirtualChannel* pChannel, BYTE* Data, BOOL* pbAccept, - IWTSVirtualChannelCallback** ppCallback) + IWTSVirtualChannel* pChannel, BYTE* Data, BOOL* pbAccept, + IWTSVirtualChannelCallback** ppCallback) { SSHAGENT_CHANNEL_CALLBACK* callback; SSHAGENT_LISTENER_CALLBACK* listener_callback = (SSHAGENT_LISTENER_CALLBACK*) pListenerCallback; - callback = (SSHAGENT_CHANNEL_CALLBACK*) calloc(1, sizeof(SSHAGENT_CHANNEL_CALLBACK)); if (!callback) @@ -277,39 +281,41 @@ static UINT sshagent_on_new_channel_connection(IWTSListenerCallback* pListenerCa return CHANNEL_RC_NO_MEMORY; } - /* Now open a connection to the local ssh-agent. Do this for each - * connection to the plugin in case we mess up the agent session. */ - callback->agent_fd - = connect_to_sshagent(listener_callback->agent_uds_path); - if (callback->agent_fd == -1) - { - return CHANNEL_RC_INITIALIZATION_ERROR; - } + /* Now open a connection to the local ssh-agent. Do this for each + * connection to the plugin in case we mess up the agent session. */ + callback->agent_fd + = connect_to_sshagent(listener_callback->agent_uds_path); + + if (callback->agent_fd == -1) + { + free(callback); + return CHANNEL_RC_INITIALIZATION_ERROR; + } InitializeCriticalSection(&callback->lock); - callback->iface.OnDataReceived = sshagent_on_data_received; callback->iface.OnClose = sshagent_on_close; callback->plugin = listener_callback->plugin; callback->channel_mgr = listener_callback->channel_mgr; callback->channel = pChannel; callback->rdpcontext = listener_callback->rdpcontext; - callback->thread - = CreateThread(NULL, - 0, - (LPTHREAD_START_ROUTINE) sshagent_read_thread, - (void*) callback, - 0, - NULL); + = CreateThread(NULL, + 0, + (LPTHREAD_START_ROUTINE) sshagent_read_thread, + (void*) callback, + 0, + NULL); + if (!callback->thread) { WLog_ERR(TAG, "CreateThread failed!"); - return CHANNEL_RC_INITIALIZATION_ERROR; + DeleteCriticalSection(&callback->lock); + free(callback); + return CHANNEL_RC_INITIALIZATION_ERROR; } *ppCallback = (IWTSVirtualChannelCallback*) callback; - return CHANNEL_RC_OK; } @@ -321,8 +327,8 @@ static UINT sshagent_on_new_channel_connection(IWTSListenerCallback* pListenerCa static UINT sshagent_plugin_initialize(IWTSPlugin* pPlugin, IWTSVirtualChannelManager* pChannelMgr) { SSHAGENT_PLUGIN* sshagent = (SSHAGENT_PLUGIN*) pPlugin; - - sshagent->listener_callback = (SSHAGENT_LISTENER_CALLBACK*) calloc(1, sizeof(SSHAGENT_LISTENER_CALLBACK)); + sshagent->listener_callback = (SSHAGENT_LISTENER_CALLBACK*) calloc(1, + sizeof(SSHAGENT_LISTENER_CALLBACK)); if (!sshagent->listener_callback) { @@ -334,16 +340,18 @@ static UINT sshagent_plugin_initialize(IWTSPlugin* pPlugin, IWTSVirtualChannelMa sshagent->listener_callback->iface.OnNewChannelConnection = sshagent_on_new_channel_connection; sshagent->listener_callback->plugin = pPlugin; sshagent->listener_callback->channel_mgr = pChannelMgr; + sshagent->listener_callback->agent_uds_path = getenv("SSH_AUTH_SOCK"); - sshagent->listener_callback->agent_uds_path = getenv("SSH_AUTH_SOCK"); - if (sshagent->listener_callback->agent_uds_path == NULL) - { + if (sshagent->listener_callback->agent_uds_path == NULL) + { WLog_ERR(TAG, "Environment variable $SSH_AUTH_SOCK undefined!"); - return CHANNEL_RC_INITIALIZATION_ERROR; - } + free(sshagent->listener_callback); + sshagent->listener_callback = NULL; + return CHANNEL_RC_INITIALIZATION_ERROR; + } return pChannelMgr->CreateListener(pChannelMgr, "SSHAGENT", 0, - (IWTSListenerCallback*) sshagent->listener_callback, NULL); + (IWTSListenerCallback*) sshagent->listener_callback, NULL); } /** @@ -354,9 +362,7 @@ static UINT sshagent_plugin_initialize(IWTSPlugin* pPlugin, IWTSVirtualChannelMa static UINT sshagent_plugin_terminated(IWTSPlugin* pPlugin) { SSHAGENT_PLUGIN* sshagent = (SSHAGENT_PLUGIN*) pPlugin; - free(sshagent); - return CHANNEL_RC_OK; } @@ -375,7 +381,6 @@ UINT DVCPluginEntry(IDRDYNVC_ENTRY_POINTS* pEntryPoints) { UINT status = CHANNEL_RC_OK; SSHAGENT_PLUGIN* sshagent; - sshagent = (SSHAGENT_PLUGIN*) pEntryPoints->GetPlugin(pEntryPoints, "sshagent"); if (!sshagent) @@ -392,9 +397,8 @@ UINT DVCPluginEntry(IDRDYNVC_ENTRY_POINTS* pEntryPoints) sshagent->iface.Connected = NULL; sshagent->iface.Disconnected = NULL; sshagent->iface.Terminated = sshagent_plugin_terminated; - sshagent->rdpcontext = ((freerdp*)((rdpSettings*) pEntryPoints->GetRdpSettings( - pEntryPoints))->instance)->context; - + sshagent->rdpcontext = ((freerdp*)((rdpSettings*) pEntryPoints->GetRdpSettings( + pEntryPoints))->instance)->context; status = pEntryPoints->RegisterPlugin(pEntryPoints, "sshagent", (IWTSPlugin*) sshagent); } diff --git a/channels/sshagent/server/sshagent_main.c b/channels/sshagent/server/sshagent_main.c index 2cf3a9108..222b9caef 100644 --- a/channels/sshagent/server/sshagent_main.c +++ b/channels/sshagent/server/sshagent_main.c @@ -19,7 +19,7 @@ * limitations under the License. */ -/* +/* * Portions are from OpenSSH, under the following license: * * Author: Tatu Ylonen @@ -60,7 +60,7 @@ * xrdp-ssh-agent.c: program to forward ssh-agent protocol from xrdp session * * This performs the equivalent function of ssh-agent on a server you connect - * to via ssh, but the ssh-agent protocol is over an RDP dynamic virtual + * to via ssh, but the ssh-agent protocol is over an RDP dynamic virtual * channel and not an SSH channel. * * This will print out variables to set in your environment (specifically, @@ -120,283 +120,304 @@ static int is_going = 1; /* Make a template filename for mk[sd]temp() */ /* This is from mktemp_proto() in misc.c from openssh */ void -mktemp_proto(char *s, size_t len) +mktemp_proto(char* s, size_t len) { - const char *tmpdir; + const char* tmpdir; int r; - if ((tmpdir = getenv("TMPDIR")) != NULL) { + if ((tmpdir = getenv("TMPDIR")) != NULL) + { r = snprintf(s, len, "%s/ssh-XXXXXXXXXXXX", tmpdir); + if (r > 0 && (size_t)r < len) return; } + r = snprintf(s, len, "/tmp/ssh-XXXXXXXXXXXX"); + if (r < 0 || (size_t)r >= len) - { - fprintf(stderr, "%s: template string too short", __func__); - exit(1); - } + { + fprintf(stderr, "%s: template string too short", __func__); + exit(1); + } } /* This uses parts of main() in ssh-agent.c from openssh */ static void -setup_ssh_agent(struct sockaddr_un *addr) +setup_ssh_agent(struct sockaddr_un* addr) { - int rc; + int rc; + /* Create private directory for agent socket */ + mktemp_proto(socket_dir, sizeof(socket_dir)); - /* Create private directory for agent socket */ - mktemp_proto(socket_dir, sizeof(socket_dir)); - if (mkdtemp(socket_dir) == NULL) { - perror("mkdtemp: private socket dir"); - exit(1); - } - snprintf(socket_name, sizeof socket_name, "%s/agent.%ld", socket_dir, - (long)getpid()); + if (mkdtemp(socket_dir) == NULL) + { + perror("mkdtemp: private socket dir"); + exit(1); + } - /* Create unix domain socket */ - unlink(socket_name); + snprintf(socket_name, sizeof socket_name, "%s/agent.%ld", socket_dir, + (long)getpid()); + /* Create unix domain socket */ + unlink(socket_name); + sa_uds_fd = socket(AF_UNIX, SOCK_STREAM, 0); - sa_uds_fd = socket(AF_UNIX, SOCK_STREAM, 0); - if (sa_uds_fd == -1) - { - fprintf(stderr, "sshagent: socket creation failed"); - exit(2); - } + if (sa_uds_fd == -1) + { + fprintf(stderr, "sshagent: socket creation failed"); + exit(2); + } - memset(addr, 0, sizeof(struct sockaddr_un)); - addr->sun_family = AF_UNIX; - strncpy(addr->sun_path, socket_name, sizeof(addr->sun_path)); - addr->sun_path[sizeof(addr->sun_path) - 1] = 0; + memset(addr, 0, sizeof(struct sockaddr_un)); + addr->sun_family = AF_UNIX; + strncpy(addr->sun_path, socket_name, sizeof(addr->sun_path)); + addr->sun_path[sizeof(addr->sun_path) - 1] = 0; + /* Create with privileges rw------- so other users can't access the UDS */ + mode_t umask_sav = umask(0177); + rc = bind(sa_uds_fd, (struct sockaddr*)addr, sizeof(struct sockaddr_un)); - /* Create with privileges rw------- so other users can't access the UDS */ - mode_t umask_sav = umask(0177); - rc = bind(sa_uds_fd, (struct sockaddr *)addr, sizeof(struct sockaddr_un)); - if (rc != 0) - { - fprintf(stderr, "sshagent: bind failed"); - close(sa_uds_fd); - unlink(socket_name); - exit(3); - } - umask(umask_sav); + if (rc != 0) + { + fprintf(stderr, "sshagent: bind failed"); + close(sa_uds_fd); + unlink(socket_name); + exit(3); + } - rc = listen(sa_uds_fd, /* backlog = */ 5); - if (rc != 0) - { - fprintf(stderr, "listen failed\n"); - close(sa_uds_fd); - unlink(socket_name); - exit(1); - } + umask(umask_sav); + rc = listen(sa_uds_fd, /* backlog = */ 5); - /* Now fork: the child becomes the ssh-agent daemon and the parent prints - * out the pid and socket name. */ - pid_t pid = fork(); - if (pid == -1) - { - perror("fork"); - exit(1); - } - else if (pid != 0) - { - /* Parent */ - close(sa_uds_fd); - printf("SSH_AUTH_SOCK=%s; export SSH_AUTH_SOCK;\n", socket_name); - printf("SSH_AGENT_PID=%d; export SSH_AGENT_PID;\n", pid); - printf("echo Agent pid %d;\n", pid); - exit(0); - } + if (rc != 0) + { + fprintf(stderr, "listen failed\n"); + close(sa_uds_fd); + unlink(socket_name); + exit(1); + } - /* Child */ + /* Now fork: the child becomes the ssh-agent daemon and the parent prints + * out the pid and socket name. */ + pid_t pid = fork(); - if (setsid() == -1) - { - fprintf(stderr, "setsid failed"); - exit(1); - } + if (pid == -1) + { + perror("fork"); + exit(1); + } + else if (pid != 0) + { + /* Parent */ + close(sa_uds_fd); + printf("SSH_AUTH_SOCK=%s; export SSH_AUTH_SOCK;\n", socket_name); + printf("SSH_AGENT_PID=%d; export SSH_AGENT_PID;\n", pid); + printf("echo Agent pid %d;\n", pid); + exit(0); + } - (void)chdir("/"); - int devnullfd; - if ((devnullfd = open(_PATH_DEVNULL, O_RDWR, 0)) != -1) { - /* XXX might close listen socket */ - (void)dup2(devnullfd, STDIN_FILENO); - (void)dup2(devnullfd, STDOUT_FILENO); - (void)dup2(devnullfd, STDERR_FILENO); - if (devnullfd > 2) - close(devnullfd); - } + /* Child */ - /* deny core dumps, since memory contains unencrypted private keys */ - struct rlimit rlim; - rlim.rlim_cur = rlim.rlim_max = 0; - if (setrlimit(RLIMIT_CORE, &rlim) < 0) { - fprintf(stderr, "setrlimit RLIMIT_CORE: %s", strerror(errno)); - exit(1); - } + if (setsid() == -1) + { + fprintf(stderr, "setsid failed"); + exit(1); + } + + (void)chdir("/"); + int devnullfd; + + if ((devnullfd = open(_PATH_DEVNULL, O_RDWR, 0)) != -1) + { + /* XXX might close listen socket */ + (void)dup2(devnullfd, STDIN_FILENO); + (void)dup2(devnullfd, STDOUT_FILENO); + (void)dup2(devnullfd, STDERR_FILENO); + + if (devnullfd > 2) + close(devnullfd); + } + + /* deny core dumps, since memory contains unencrypted private keys */ + struct rlimit rlim; + rlim.rlim_cur = rlim.rlim_max = 0; + + if (setrlimit(RLIMIT_CORE, &rlim) < 0) + { + fprintf(stderr, "setrlimit RLIMIT_CORE: %s", strerror(errno)); + exit(1); + } } static void handle_connection(int client_fd) { - int rdp_fd = -1; - int rc; - void *channel = WTSVirtualChannelOpenEx(WTS_CURRENT_SESSION, - "SSHAGENT", - WTS_CHANNEL_OPTION_DYNAMIC_PRI_MED); - if (channel == NULL) - { - fprintf(stderr, "WTSVirtualChannelOpenEx() failed\n"); - } + int rdp_fd = -1; + int rc; + void* channel = WTSVirtualChannelOpenEx(WTS_CURRENT_SESSION, + "SSHAGENT", + WTS_CHANNEL_OPTION_DYNAMIC_PRI_MED); - unsigned int retlen; - int *retdata; - rc = WTSVirtualChannelQuery(channel, - WTSVirtualFileHandle, - (void **)&retdata, - &retlen); - if (!rc) - { - fprintf(stderr, "WTSVirtualChannelQuery() failed\n"); - } - if (retlen != sizeof(rdp_fd)) - { - fprintf(stderr, "WTSVirtualChannelQuery() returned wrong length %d\n", - retlen); - } - rdp_fd = *retdata; + if (channel == NULL) + { + fprintf(stderr, "WTSVirtualChannelOpenEx() failed\n"); + } - int client_going = 1; - while (client_going) - { - /* Wait for data from RDP or the client */ - fd_set readfds; - FD_ZERO(&readfds); - FD_SET(client_fd, &readfds); - FD_SET(rdp_fd, &readfds); - select(FD_SETSIZE, &readfds, NULL, NULL, NULL); + unsigned int retlen; + int* retdata; + rc = WTSVirtualChannelQuery(channel, + WTSVirtualFileHandle, + (void**)&retdata, + &retlen); - if (FD_ISSET(rdp_fd, &readfds)) - { - /* Read from RDP and write to the client */ - char buffer[4096]; - unsigned int bytes_to_write; - rc = WTSVirtualChannelRead(channel, - /* TimeOut = */ 5000, - buffer, - sizeof(buffer), - &bytes_to_write); - if (rc == 1) - { - char *pos = buffer; - errno = 0; - while (bytes_to_write > 0) - { - int bytes_written = send(client_fd, pos, bytes_to_write, 0); + if (!rc) + { + fprintf(stderr, "WTSVirtualChannelQuery() failed\n"); + } - if (bytes_written > 0) - { - bytes_to_write -= bytes_written; - pos += bytes_written; - } - else if (bytes_written == 0) - { - fprintf(stderr, "send() returned 0!\n"); - } - else if (errno != EINTR) - { - /* Error */ - fprintf(stderr, "Error %d on recv\n", errno); - client_going = 0; - } - } - } - else - { - /* Error */ - fprintf(stderr, "WTSVirtualChannelRead() failed: %d\n", errno); - client_going = 0; - } - } + if (retlen != sizeof(rdp_fd)) + { + fprintf(stderr, "WTSVirtualChannelQuery() returned wrong length %d\n", + retlen); + } - if (FD_ISSET(client_fd, &readfds)) - { - /* Read from the client and write to RDP */ - char buffer[4096]; - ssize_t bytes_to_write = recv(client_fd, buffer, sizeof(buffer), 0); - if (bytes_to_write > 0) - { - char *pos = buffer; - while (bytes_to_write > 0) - { - unsigned int bytes_written; - int rc = WTSVirtualChannelWrite(channel, - pos, - bytes_to_write, - &bytes_written); - if (rc == 0) - { - fprintf(stderr, "WTSVirtualChannelWrite() failed: %d\n", - errno); - client_going = 0; - } - else - { - bytes_to_write -= bytes_written; - pos += bytes_written; - } - } - } - else if (bytes_to_write == 0) - { - /* Client has closed connection */ - client_going = 0; - } - else - { - /* Error */ - fprintf(stderr, "Error %d on recv\n", errno); - client_going = 0; - } - } - } - WTSVirtualChannelClose(channel); + rdp_fd = *retdata; + int client_going = 1; + + while (client_going) + { + /* Wait for data from RDP or the client */ + fd_set readfds; + FD_ZERO(&readfds); + FD_SET(client_fd, &readfds); + FD_SET(rdp_fd, &readfds); + select(FD_SETSIZE, &readfds, NULL, NULL, NULL); + + if (FD_ISSET(rdp_fd, &readfds)) + { + /* Read from RDP and write to the client */ + char buffer[4096]; + unsigned int bytes_to_write; + rc = WTSVirtualChannelRead(channel, + /* TimeOut = */ 5000, + buffer, + sizeof(buffer), + &bytes_to_write); + + if (rc == 1) + { + char* pos = buffer; + errno = 0; + + while (bytes_to_write > 0) + { + int bytes_written = send(client_fd, pos, bytes_to_write, 0); + + if (bytes_written > 0) + { + bytes_to_write -= bytes_written; + pos += bytes_written; + } + else if (bytes_written == 0) + { + fprintf(stderr, "send() returned 0!\n"); + } + else if (errno != EINTR) + { + /* Error */ + fprintf(stderr, "Error %d on recv\n", errno); + client_going = 0; + } + } + } + else + { + /* Error */ + fprintf(stderr, "WTSVirtualChannelRead() failed: %d\n", errno); + client_going = 0; + } + } + + if (FD_ISSET(client_fd, &readfds)) + { + /* Read from the client and write to RDP */ + char buffer[4096]; + ssize_t bytes_to_write = recv(client_fd, buffer, sizeof(buffer), 0); + + if (bytes_to_write > 0) + { + char* pos = buffer; + + while (bytes_to_write > 0) + { + unsigned int bytes_written; + int rc = WTSVirtualChannelWrite(channel, + pos, + bytes_to_write, + &bytes_written); + + if (rc == 0) + { + fprintf(stderr, "WTSVirtualChannelWrite() failed: %d\n", + errno); + client_going = 0; + } + else + { + bytes_to_write -= bytes_written; + pos += bytes_written; + } + } + } + else if (bytes_to_write == 0) + { + /* Client has closed connection */ + client_going = 0; + } + else + { + /* Error */ + fprintf(stderr, "Error %d on recv\n", errno); + client_going = 0; + } + } + } + + WTSVirtualChannelClose(channel); } int -main(int argc, char **argv) +main(int argc, char** argv) { - /* Setup the Unix domain socket and daemon process */ - struct sockaddr_un addr; - setup_ssh_agent(&addr); + /* Setup the Unix domain socket and daemon process */ + struct sockaddr_un addr; + setup_ssh_agent(&addr); - /* Wait for a client to connect to the socket */ - while (is_going) - { - fd_set readfds; - FD_ZERO(&readfds); - FD_SET(sa_uds_fd, &readfds); - select(FD_SETSIZE, &readfds, NULL, NULL, NULL); + /* Wait for a client to connect to the socket */ + while (is_going) + { + fd_set readfds; + FD_ZERO(&readfds); + FD_SET(sa_uds_fd, &readfds); + select(FD_SETSIZE, &readfds, NULL, NULL, NULL); - /* If something connected then get it... - * (You can test this using "socat - UNIX-CONNECT:".) */ - if (FD_ISSET(sa_uds_fd, &readfds)) - { - socklen_t addrsize = sizeof(addr); - int client_fd = accept(sa_uds_fd, - (struct sockaddr*)&addr, - &addrsize); - handle_connection(client_fd); - close(client_fd); - } - } + /* If something connected then get it... + * (You can test this using "socat - UNIX-CONNECT:".) */ + if (FD_ISSET(sa_uds_fd, &readfds)) + { + socklen_t addrsize = sizeof(addr); + int client_fd = accept(sa_uds_fd, + (struct sockaddr*)&addr, + &addrsize); + handle_connection(client_fd); + close(client_fd); + } + } - close(sa_uds_fd); - unlink(socket_name); - - return 0; + close(sa_uds_fd); + unlink(socket_name); + return 0; } /* vim: set sw=4:ts=4:et: */