From acc5e2d301c36bc05933fffef9f8adc4820e3115 Mon Sep 17 00:00:00 2001 From: Armin Novak Date: Sat, 25 Feb 2023 11:08:26 +0100 Subject: [PATCH] [client,common] use non blocking IO when reading from stdin use non blocking IO so that we can check if the session terminated in between. --- client/common/client.c | 56 ++++++--- client/common/cmdline.c | 9 +- include/freerdp/utils/passphrase.h | 10 +- libfreerdp/utils/passphrase.c | 195 ++++++++++++++++++++++++----- 4 files changed, 212 insertions(+), 58 deletions(-) diff --git a/client/common/client.c b/client/common/client.c index 027c92927..121309f46 100644 --- a/client/common/client.c +++ b/client/common/client.c @@ -441,10 +441,12 @@ static BOOL client_cli_authenticate_raw(freerdp* instance, rdp_auth_reason reaso { size_t username_size = 0; printf("%s", prompt[0]); + fflush(stdout); - if (GetLine(username, &username_size, stdin) < 0) + if (freerdp_interruptible_get_line(instance->context, username, &username_size, stdin) < 0) { - WLog_ERR(TAG, "GetLine returned %s [%d]", strerror(errno), errno); + WLog_ERR(TAG, "freerdp_interruptible_get_line returned %s [%d]", strerror(errno), + errno); goto fail; } @@ -459,10 +461,12 @@ static BOOL client_cli_authenticate_raw(freerdp* instance, rdp_auth_reason reaso { size_t domain_size = 0; printf("%s", prompt[1]); + fflush(stdout); - if (GetLine(domain, &domain_size, stdin) < 0) + if (freerdp_interruptible_get_line(instance->context, domain, &domain_size, stdin) < 0) { - WLog_ERR(TAG, "GetLine returned %s [%d]", strerror(errno), errno); + WLog_ERR(TAG, "freerdp_interruptible_get_line returned %s [%d]", strerror(errno), + errno); goto fail; } @@ -480,7 +484,7 @@ static BOOL client_cli_authenticate_raw(freerdp* instance, rdp_auth_reason reaso if (!*password) goto fail; - if (freerdp_passphrase_read(prompt[2], *password, password_size, + if (freerdp_passphrase_read(instance->context, prompt[2], *password, password_size, instance->context->settings->CredentialsFromStdin) == NULL) goto fail; } @@ -588,10 +592,16 @@ BOOL client_cli_gw_authenticate(freerdp* instance, char** username, char** passw } #endif -static DWORD client_cli_accept_certificate(rdpSettings* settings) +static DWORD client_cli_accept_certificate(freerdp* instance) { int answer; + WINPR_ASSERT(instance); + WINPR_ASSERT(instance->context); + + const rdpSettings* settings = instance->context->settings; + WINPR_ASSERT(settings); + if (settings->CredentialsFromStdin) return 0; @@ -599,9 +609,9 @@ static DWORD client_cli_accept_certificate(rdpSettings* settings) { printf("Do you trust the above certificate? (Y/T/N) "); fflush(stdout); - answer = fgetc(stdin); + answer = freerdp_interruptible_getc(instance->context, stdin); - if (feof(stdin)) + if ((answer == EOF) || feof(stdin)) { printf("\nError: Could not read answer from stdin."); @@ -616,17 +626,23 @@ static DWORD client_cli_accept_certificate(rdpSettings* settings) { case 'y': case 'Y': - fgetc(stdin); + answer = freerdp_interruptible_getc(instance->context, stdin); + if (answer == EOF) + return 0; return 1; case 't': case 'T': - fgetc(stdin); + answer = freerdp_interruptible_getc(instance->context, stdin); + if (answer == EOF) + return 0; return 2; case 'n': case 'N': - fgetc(stdin); + answer = freerdp_interruptible_getc(instance->context, stdin); + if (answer == EOF) + return 0; return 0; default: @@ -665,7 +681,7 @@ DWORD client_cli_verify_certificate(freerdp* instance, const char* common_name, printf("The above X.509 certificate could not be verified, possibly because you do not have\n" "the CA certificate in your certificate store, or the certificate has expired.\n" "Please look at the OpenSSL documentation on how to add a private CA to the store.\n"); - return client_cli_accept_certificate(instance->settings); + return client_cli_accept_certificate(instance); } #endif @@ -719,7 +735,7 @@ DWORD client_cli_verify_certificate_ex(freerdp* instance, const char* host, UINT printf("The above X.509 certificate could not be verified, possibly because you do not have\n" "the CA certificate in your certificate store, or the certificate has expired.\n" "Please look at the OpenSSL documentation on how to add a private CA to the store.\n"); - return client_cli_accept_certificate(instance->context->settings); + return client_cli_accept_certificate(instance); } /** Callback set in the rdp_freerdp structure, and used to make a certificate validation @@ -763,7 +779,7 @@ DWORD client_cli_verify_changed_certificate(freerdp* instance, const char* commo "connections.\n" "This may indicate that the certificate has been tampered with.\n" "Please contact the administrator of the RDP server and clarify.\n"); - return client_cli_accept_certificate(instance->settings); + return client_cli_accept_certificate(instance); } #endif @@ -849,7 +865,7 @@ DWORD client_cli_verify_changed_certificate_ex(freerdp* instance, const char* ho "connections.\n" "This may indicate that the certificate has been tampered with.\n" "Please contact the administrator of the RDP server and clarify.\n"); - return client_cli_accept_certificate(instance->context->settings); + return client_cli_accept_certificate(instance); } BOOL client_cli_present_gateway_message(freerdp* instance, UINT32 type, BOOL isDisplayMandatory, @@ -886,9 +902,9 @@ BOOL client_cli_present_gateway_message(freerdp* instance, UINT32 type, BOOL isD { printf("I understand and agree to the terms of this policy (Y/N) \n"); fflush(stdout); - answer = fgetc(stdin); + answer = freerdp_interruptible_getc(instance->context, stdin); - if (feof(stdin)) + if ((answer == EOF) || feof(stdin)) { printf("\nError: Could not read answer from stdin.\n"); return FALSE; @@ -898,12 +914,14 @@ BOOL client_cli_present_gateway_message(freerdp* instance, UINT32 type, BOOL isD { case 'y': case 'Y': - fgetc(stdin); + answer = freerdp_interruptible_getc(instance->context, stdin); + if (answer == EOF) + return FALSE; return TRUE; case 'n': case 'N': - fgetc(stdin); + freerdp_interruptible_getc(instance->context, stdin); return FALSE; default: diff --git a/client/common/cmdline.c b/client/common/cmdline.c index 2d06b54a7..4ecb28c62 100644 --- a/client/common/cmdline.c +++ b/client/common/cmdline.c @@ -4297,7 +4297,9 @@ int freerdp_client_settings_parse_command_line_arguments(rdpSettings* settings, if (!settings->Password) return COMMAND_LINE_ERROR; - if (!freerdp_passphrase_read("Password: ", settings->Password, size, 1)) + freerdp* instance = settings->instance; + if (!freerdp_passphrase_read(instance->context, "Password: ", settings->Password, size, + 1)) return COMMAND_LINE_ERROR; } @@ -4310,8 +4312,9 @@ int freerdp_client_settings_parse_command_line_arguments(rdpSettings* settings, if (!settings->GatewayPassword) return COMMAND_LINE_ERROR; - if (!freerdp_passphrase_read("Gateway Password: ", settings->GatewayPassword, size, - 1)) + freerdp* instance = settings->instance; + if (!freerdp_passphrase_read(instance->context, "Gateway Password: ", + settings->GatewayPassword, size, 1)) return COMMAND_LINE_ERROR; } } diff --git a/include/freerdp/utils/passphrase.h b/include/freerdp/utils/passphrase.h index 045f06437..0ca48d117 100644 --- a/include/freerdp/utils/passphrase.h +++ b/include/freerdp/utils/passphrase.h @@ -21,15 +21,21 @@ #define FREERDP_UTILS_PASSPHRASE_H #include +#include + #include +#include #ifdef __cplusplus extern "C" { #endif - FREERDP_API char* freerdp_passphrase_read(const char* prompt, char* buf, size_t bufsiz, - int from_stdin); + FREERDP_API int freerdp_interruptible_getc(rdpContext* context, FILE* file); + FREERDP_API SSIZE_T freerdp_interruptible_get_line(rdpContext* context, char** lineptr, + size_t* size, FILE* stream); + FREERDP_API char* freerdp_passphrase_read(rdpContext* context, const char* prompt, char* buf, + size_t bufsiz, int from_stdin); #ifdef __cplusplus } diff --git a/libfreerdp/utils/passphrase.c b/libfreerdp/utils/passphrase.c index 8c5ef9390..cbe38ef13 100644 --- a/libfreerdp/utils/passphrase.c +++ b/libfreerdp/utils/passphrase.c @@ -106,16 +106,70 @@ fail: #include #include #include +#include #include -char* freerdp_passphrase_read(const char* prompt, char* buf, size_t bufsiz, int from_stdin) +#ifdef WINPR_HAVE_POLL_H +#include +#else +#include +#include +#endif + +static int wait_for_fd(int fd, int timeout) { - char read_char; - char* buf_iter; - char term_name[L_ctermid]; - int term_file, write_file; - ssize_t nbytes; - size_t read_bytes = 0; + int status; +#ifdef WINPR_HAVE_POLL_H + struct pollfd pollset = { 0 }; + pollset.fd = fd; + pollset.events = POLLIN; + pollset.revents = 0; + + do + { + status = poll(&pollset, 1, timeout); + } while ((status < 0) && (errno == EINTR)); + +#else + fd_set rset = { 0 }; + struct timeval tv = { 0 }; + FD_ZERO(&rset); + FD_SET(sockfd, &rset); + + if (timeout) + { + tv.tv_sec = timeout / 1000; + tv.tv_usec = (timeout % 1000) * 1000; + } + + do + { + status = select(sockfd + 1, &rset, NULL, NULL, timeout ? &tv : NULL); + } while ((status < 0) && (errno == EINTR)); + +#endif + return status; +} + +static void replace_char(char* buffer, size_t buffer_len, const char* toreplace) +{ + while (*toreplace != '\0') + { + char* ptr; + while ((ptr = strrchr(buffer, *toreplace)) != NULL) + *ptr = '\0'; + toreplace++; + } +} + +char* freerdp_passphrase_read(rdpContext* context, const char* prompt, char* buf, size_t bufsiz, + int from_stdin) +{ + BOOL terminal_needs_reset = FALSE; + char term_name[L_ctermid] = { 0 }; + int term_file; + + FILE* fout = NULL; if (bufsiz == 0) { @@ -124,58 +178,56 @@ char* freerdp_passphrase_read(const char* prompt, char* buf, size_t bufsiz, int } ctermid(term_name); + int terminal_fildes; if (from_stdin || strcmp(term_name, "") == 0 || (term_file = open(term_name, O_RDWR)) == -1) { - write_file = STDERR_FILENO; + fout = stdout; terminal_fildes = STDIN_FILENO; } else { - write_file = term_file; + fout = fdopen(term_file, "w"); terminal_fildes = term_file; } + struct termios orig_flags = { 0 }; if (tcgetattr(terminal_fildes, &orig_flags) != -1) { + struct termios new_flags = { 0 }; new_flags = orig_flags; new_flags.c_lflag &= ~ECHO; new_flags.c_lflag |= ECHONL; - terminal_needs_reset = 1; + terminal_needs_reset = TRUE; if (tcsetattr(terminal_fildes, TCSAFLUSH, &new_flags) == -1) - terminal_needs_reset = 0; + terminal_needs_reset = FALSE; } - if (write(write_file, prompt, strlen(prompt)) == (ssize_t)-1) + FILE* fp = fdopen(terminal_fildes, "r"); + if (!fp) goto error; - buf_iter = buf; - while ((nbytes = read(terminal_fildes, &read_char, sizeof read_char)) == (sizeof read_char)) - { - if (read_char == '\n') - break; - if (read_bytes < (bufsiz - (size_t)1)) - { - read_bytes++; - *buf_iter = read_char; - buf_iter++; - } - } - *buf_iter = '\0'; - buf_iter = NULL; - read_char = '\0'; - if (nbytes == (ssize_t)-1) - goto error; + fprintf(fout, "%s", prompt); + fflush(fout); + char* ptr = NULL; + size_t ptr_len = 0; + + const SSIZE_T res = freerdp_interruptible_get_line(context, &ptr, &ptr_len, fp); + if (res < 0) + goto error; + replace_char(ptr, ptr_len, "\r\n"); + + strncpy(buf, ptr, MIN(bufsiz, ptr_len)); + free(ptr); if (terminal_needs_reset) { if (tcsetattr(terminal_fildes, TCSAFLUSH, &orig_flags) == -1) goto error; - terminal_needs_reset = 0; } if (terminal_fildes != STDIN_FILENO) { - if (close(terminal_fildes) == -1) + if (fclose(fp) == -1) goto error; } @@ -184,17 +236,41 @@ char* freerdp_passphrase_read(const char* prompt, char* buf, size_t bufsiz, int error: { int saved_errno = errno; - buf_iter = NULL; - read_char = '\0'; if (terminal_needs_reset) tcsetattr(terminal_fildes, TCSAFLUSH, &orig_flags); if (terminal_fildes != STDIN_FILENO) - close(terminal_fildes); + { + fclose(fp); + } errno = saved_errno; return NULL; } } +int freerdp_interruptible_getc(rdpContext* context, FILE* f) +{ + int rc = EOF; + const int fd = fileno(f); + + const int orig = fcntl(fd, F_GETFL); + fcntl(fd, F_SETFL, orig | O_NONBLOCK); + do + { + const int res = wait_for_fd(fd, 10); + if (res != 0) + { + char c = 0; + const ssize_t rd = read(fd, &c, 1); + if (rd == 1) + rc = c; + break; + } + } while (!freerdp_shall_disconnect_context(context)); + + fcntl(fd, F_SETFL, orig); + return rc; +} + #else char* freerdp_passphrase_read(const char* prompt, char* buf, size_t bufsiz, int from_stdin) @@ -202,4 +278,55 @@ char* freerdp_passphrase_read(const char* prompt, char* buf, size_t bufsiz, int return NULL; } +int freerdp_interruptible_getc(rdpContext* context, FILE* f) +{ + return EOF; +} #endif + +SSIZE_T freerdp_interruptible_get_line(rdpContext* context, char** plineptr, size_t* psize, + FILE* stream) +{ + char c; + char* n; + size_t step = 32; + size_t used = 0; + char* ptr = NULL; + size_t len = 0; + + if (!plineptr || !psize) + { + errno = EINVAL; + return -1; + } + + do + { + if (used + 2 >= len) + { + len += step; + n = realloc(ptr, len); + + if (!n) + { + return -1; + } + + ptr = n; + } + + c = freerdp_interruptible_getc(context, stream); + if (c != EOF) + ptr[used++] = c; + } while ((c != '\n') && (c != '\r') && (c != EOF)); + + ptr[used] = '\0'; + if (c == EOF) + { + free(ptr); + return EOF; + } + *plineptr = ptr; + *psize = used; + return used; +}