[client,common] use non blocking IO

when reading from stdin use non blocking IO so that we can check if the
session terminated in between.
This commit is contained in:
Armin Novak 2023-02-25 11:08:26 +01:00 committed by akallabeth
parent 4398126dde
commit acc5e2d301
4 changed files with 212 additions and 58 deletions

View File

@ -441,10 +441,12 @@ static BOOL client_cli_authenticate_raw(freerdp* instance, rdp_auth_reason reaso
{ {
size_t username_size = 0; size_t username_size = 0;
printf("%s", prompt[0]); printf("%s", prompt[0]);
fflush(stdout);
if (GetLine(username, &username_size, stdin) < 0) if (freerdp_interruptible_get_line(instance->context, username, &username_size, stdin) < 0)
{ {
WLog_ERR(TAG, "GetLine returned %s [%d]", strerror(errno), errno); WLog_ERR(TAG, "freerdp_interruptible_get_line returned %s [%d]", strerror(errno),
errno);
goto fail; goto fail;
} }
@ -459,10 +461,12 @@ static BOOL client_cli_authenticate_raw(freerdp* instance, rdp_auth_reason reaso
{ {
size_t domain_size = 0; size_t domain_size = 0;
printf("%s", prompt[1]); printf("%s", prompt[1]);
fflush(stdout);
if (GetLine(domain, &domain_size, stdin) < 0) if (freerdp_interruptible_get_line(instance->context, domain, &domain_size, stdin) < 0)
{ {
WLog_ERR(TAG, "GetLine returned %s [%d]", strerror(errno), errno); WLog_ERR(TAG, "freerdp_interruptible_get_line returned %s [%d]", strerror(errno),
errno);
goto fail; goto fail;
} }
@ -480,7 +484,7 @@ static BOOL client_cli_authenticate_raw(freerdp* instance, rdp_auth_reason reaso
if (!*password) if (!*password)
goto fail; goto fail;
if (freerdp_passphrase_read(prompt[2], *password, password_size, if (freerdp_passphrase_read(instance->context, prompt[2], *password, password_size,
instance->context->settings->CredentialsFromStdin) == NULL) instance->context->settings->CredentialsFromStdin) == NULL)
goto fail; goto fail;
} }
@ -588,10 +592,16 @@ BOOL client_cli_gw_authenticate(freerdp* instance, char** username, char** passw
} }
#endif #endif
static DWORD client_cli_accept_certificate(rdpSettings* settings) static DWORD client_cli_accept_certificate(freerdp* instance)
{ {
int answer; int answer;
WINPR_ASSERT(instance);
WINPR_ASSERT(instance->context);
const rdpSettings* settings = instance->context->settings;
WINPR_ASSERT(settings);
if (settings->CredentialsFromStdin) if (settings->CredentialsFromStdin)
return 0; return 0;
@ -599,9 +609,9 @@ static DWORD client_cli_accept_certificate(rdpSettings* settings)
{ {
printf("Do you trust the above certificate? (Y/T/N) "); printf("Do you trust the above certificate? (Y/T/N) ");
fflush(stdout); fflush(stdout);
answer = fgetc(stdin); answer = freerdp_interruptible_getc(instance->context, stdin);
if (feof(stdin)) if ((answer == EOF) || feof(stdin))
{ {
printf("\nError: Could not read answer from stdin."); printf("\nError: Could not read answer from stdin.");
@ -616,17 +626,23 @@ static DWORD client_cli_accept_certificate(rdpSettings* settings)
{ {
case 'y': case 'y':
case 'Y': case 'Y':
fgetc(stdin); answer = freerdp_interruptible_getc(instance->context, stdin);
if (answer == EOF)
return 0;
return 1; return 1;
case 't': case 't':
case 'T': case 'T':
fgetc(stdin); answer = freerdp_interruptible_getc(instance->context, stdin);
if (answer == EOF)
return 0;
return 2; return 2;
case 'n': case 'n':
case 'N': case 'N':
fgetc(stdin); answer = freerdp_interruptible_getc(instance->context, stdin);
if (answer == EOF)
return 0;
return 0; return 0;
default: default:
@ -665,7 +681,7 @@ DWORD client_cli_verify_certificate(freerdp* instance, const char* common_name,
printf("The above X.509 certificate could not be verified, possibly because you do not have\n" printf("The above X.509 certificate could not be verified, possibly because you do not have\n"
"the CA certificate in your certificate store, or the certificate has expired.\n" "the CA certificate in your certificate store, or the certificate has expired.\n"
"Please look at the OpenSSL documentation on how to add a private CA to the store.\n"); "Please look at the OpenSSL documentation on how to add a private CA to the store.\n");
return client_cli_accept_certificate(instance->settings); return client_cli_accept_certificate(instance);
} }
#endif #endif
@ -719,7 +735,7 @@ DWORD client_cli_verify_certificate_ex(freerdp* instance, const char* host, UINT
printf("The above X.509 certificate could not be verified, possibly because you do not have\n" printf("The above X.509 certificate could not be verified, possibly because you do not have\n"
"the CA certificate in your certificate store, or the certificate has expired.\n" "the CA certificate in your certificate store, or the certificate has expired.\n"
"Please look at the OpenSSL documentation on how to add a private CA to the store.\n"); "Please look at the OpenSSL documentation on how to add a private CA to the store.\n");
return client_cli_accept_certificate(instance->context->settings); return client_cli_accept_certificate(instance);
} }
/** Callback set in the rdp_freerdp structure, and used to make a certificate validation /** Callback set in the rdp_freerdp structure, and used to make a certificate validation
@ -763,7 +779,7 @@ DWORD client_cli_verify_changed_certificate(freerdp* instance, const char* commo
"connections.\n" "connections.\n"
"This may indicate that the certificate has been tampered with.\n" "This may indicate that the certificate has been tampered with.\n"
"Please contact the administrator of the RDP server and clarify.\n"); "Please contact the administrator of the RDP server and clarify.\n");
return client_cli_accept_certificate(instance->settings); return client_cli_accept_certificate(instance);
} }
#endif #endif
@ -849,7 +865,7 @@ DWORD client_cli_verify_changed_certificate_ex(freerdp* instance, const char* ho
"connections.\n" "connections.\n"
"This may indicate that the certificate has been tampered with.\n" "This may indicate that the certificate has been tampered with.\n"
"Please contact the administrator of the RDP server and clarify.\n"); "Please contact the administrator of the RDP server and clarify.\n");
return client_cli_accept_certificate(instance->context->settings); return client_cli_accept_certificate(instance);
} }
BOOL client_cli_present_gateway_message(freerdp* instance, UINT32 type, BOOL isDisplayMandatory, BOOL client_cli_present_gateway_message(freerdp* instance, UINT32 type, BOOL isDisplayMandatory,
@ -886,9 +902,9 @@ BOOL client_cli_present_gateway_message(freerdp* instance, UINT32 type, BOOL isD
{ {
printf("I understand and agree to the terms of this policy (Y/N) \n"); printf("I understand and agree to the terms of this policy (Y/N) \n");
fflush(stdout); fflush(stdout);
answer = fgetc(stdin); answer = freerdp_interruptible_getc(instance->context, stdin);
if (feof(stdin)) if ((answer == EOF) || feof(stdin))
{ {
printf("\nError: Could not read answer from stdin.\n"); printf("\nError: Could not read answer from stdin.\n");
return FALSE; return FALSE;
@ -898,12 +914,14 @@ BOOL client_cli_present_gateway_message(freerdp* instance, UINT32 type, BOOL isD
{ {
case 'y': case 'y':
case 'Y': case 'Y':
fgetc(stdin); answer = freerdp_interruptible_getc(instance->context, stdin);
if (answer == EOF)
return FALSE;
return TRUE; return TRUE;
case 'n': case 'n':
case 'N': case 'N':
fgetc(stdin); freerdp_interruptible_getc(instance->context, stdin);
return FALSE; return FALSE;
default: default:

View File

@ -4297,7 +4297,9 @@ int freerdp_client_settings_parse_command_line_arguments(rdpSettings* settings,
if (!settings->Password) if (!settings->Password)
return COMMAND_LINE_ERROR; return COMMAND_LINE_ERROR;
if (!freerdp_passphrase_read("Password: ", settings->Password, size, 1)) freerdp* instance = settings->instance;
if (!freerdp_passphrase_read(instance->context, "Password: ", settings->Password, size,
1))
return COMMAND_LINE_ERROR; return COMMAND_LINE_ERROR;
} }
@ -4310,8 +4312,9 @@ int freerdp_client_settings_parse_command_line_arguments(rdpSettings* settings,
if (!settings->GatewayPassword) if (!settings->GatewayPassword)
return COMMAND_LINE_ERROR; return COMMAND_LINE_ERROR;
if (!freerdp_passphrase_read("Gateway Password: ", settings->GatewayPassword, size, freerdp* instance = settings->instance;
1)) if (!freerdp_passphrase_read(instance->context, "Gateway Password: ",
settings->GatewayPassword, size, 1))
return COMMAND_LINE_ERROR; return COMMAND_LINE_ERROR;
} }
} }

View File

@ -21,15 +21,21 @@
#define FREERDP_UTILS_PASSPHRASE_H #define FREERDP_UTILS_PASSPHRASE_H
#include <stdlib.h> #include <stdlib.h>
#include <stdio.h>
#include <freerdp/api.h> #include <freerdp/api.h>
#include <freerdp/freerdp.h>
#ifdef __cplusplus #ifdef __cplusplus
extern "C" extern "C"
{ {
#endif #endif
FREERDP_API char* freerdp_passphrase_read(const char* prompt, char* buf, size_t bufsiz, FREERDP_API int freerdp_interruptible_getc(rdpContext* context, FILE* file);
int from_stdin); FREERDP_API SSIZE_T freerdp_interruptible_get_line(rdpContext* context, char** lineptr,
size_t* size, FILE* stream);
FREERDP_API char* freerdp_passphrase_read(rdpContext* context, const char* prompt, char* buf,
size_t bufsiz, int from_stdin);
#ifdef __cplusplus #ifdef __cplusplus
} }

View File

@ -106,16 +106,70 @@ fail:
#include <sys/stat.h> #include <sys/stat.h>
#include <termios.h> #include <termios.h>
#include <unistd.h> #include <unistd.h>
#include <termios.h>
#include <freerdp/utils/signal.h> #include <freerdp/utils/signal.h>
char* freerdp_passphrase_read(const char* prompt, char* buf, size_t bufsiz, int from_stdin) #ifdef WINPR_HAVE_POLL_H
#include <poll.h>
#else
#include <time.h>
#include <sys/select.h>
#endif
static int wait_for_fd(int fd, int timeout)
{ {
char read_char; int status;
char* buf_iter; #ifdef WINPR_HAVE_POLL_H
char term_name[L_ctermid]; struct pollfd pollset = { 0 };
int term_file, write_file; pollset.fd = fd;
ssize_t nbytes; pollset.events = POLLIN;
size_t read_bytes = 0; pollset.revents = 0;
do
{
status = poll(&pollset, 1, timeout);
} while ((status < 0) && (errno == EINTR));
#else
fd_set rset = { 0 };
struct timeval tv = { 0 };
FD_ZERO(&rset);
FD_SET(sockfd, &rset);
if (timeout)
{
tv.tv_sec = timeout / 1000;
tv.tv_usec = (timeout % 1000) * 1000;
}
do
{
status = select(sockfd + 1, &rset, NULL, NULL, timeout ? &tv : NULL);
} while ((status < 0) && (errno == EINTR));
#endif
return status;
}
static void replace_char(char* buffer, size_t buffer_len, const char* toreplace)
{
while (*toreplace != '\0')
{
char* ptr;
while ((ptr = strrchr(buffer, *toreplace)) != NULL)
*ptr = '\0';
toreplace++;
}
}
char* freerdp_passphrase_read(rdpContext* context, const char* prompt, char* buf, size_t bufsiz,
int from_stdin)
{
BOOL terminal_needs_reset = FALSE;
char term_name[L_ctermid] = { 0 };
int term_file;
FILE* fout = NULL;
if (bufsiz == 0) if (bufsiz == 0)
{ {
@ -124,58 +178,56 @@ char* freerdp_passphrase_read(const char* prompt, char* buf, size_t bufsiz, int
} }
ctermid(term_name); ctermid(term_name);
int terminal_fildes;
if (from_stdin || strcmp(term_name, "") == 0 || (term_file = open(term_name, O_RDWR)) == -1) if (from_stdin || strcmp(term_name, "") == 0 || (term_file = open(term_name, O_RDWR)) == -1)
{ {
write_file = STDERR_FILENO; fout = stdout;
terminal_fildes = STDIN_FILENO; terminal_fildes = STDIN_FILENO;
} }
else else
{ {
write_file = term_file; fout = fdopen(term_file, "w");
terminal_fildes = term_file; terminal_fildes = term_file;
} }
struct termios orig_flags = { 0 };
if (tcgetattr(terminal_fildes, &orig_flags) != -1) if (tcgetattr(terminal_fildes, &orig_flags) != -1)
{ {
struct termios new_flags = { 0 };
new_flags = orig_flags; new_flags = orig_flags;
new_flags.c_lflag &= ~ECHO; new_flags.c_lflag &= ~ECHO;
new_flags.c_lflag |= ECHONL; new_flags.c_lflag |= ECHONL;
terminal_needs_reset = 1; terminal_needs_reset = TRUE;
if (tcsetattr(terminal_fildes, TCSAFLUSH, &new_flags) == -1) if (tcsetattr(terminal_fildes, TCSAFLUSH, &new_flags) == -1)
terminal_needs_reset = 0; terminal_needs_reset = FALSE;
} }
if (write(write_file, prompt, strlen(prompt)) == (ssize_t)-1) FILE* fp = fdopen(terminal_fildes, "r");
if (!fp)
goto error; goto error;
buf_iter = buf; fprintf(fout, "%s", prompt);
while ((nbytes = read(terminal_fildes, &read_char, sizeof read_char)) == (sizeof read_char)) fflush(fout);
{
if (read_char == '\n')
break;
if (read_bytes < (bufsiz - (size_t)1))
{
read_bytes++;
*buf_iter = read_char;
buf_iter++;
}
}
*buf_iter = '\0';
buf_iter = NULL;
read_char = '\0';
if (nbytes == (ssize_t)-1)
goto error;
char* ptr = NULL;
size_t ptr_len = 0;
const SSIZE_T res = freerdp_interruptible_get_line(context, &ptr, &ptr_len, fp);
if (res < 0)
goto error;
replace_char(ptr, ptr_len, "\r\n");
strncpy(buf, ptr, MIN(bufsiz, ptr_len));
free(ptr);
if (terminal_needs_reset) if (terminal_needs_reset)
{ {
if (tcsetattr(terminal_fildes, TCSAFLUSH, &orig_flags) == -1) if (tcsetattr(terminal_fildes, TCSAFLUSH, &orig_flags) == -1)
goto error; goto error;
terminal_needs_reset = 0;
} }
if (terminal_fildes != STDIN_FILENO) if (terminal_fildes != STDIN_FILENO)
{ {
if (close(terminal_fildes) == -1) if (fclose(fp) == -1)
goto error; goto error;
} }
@ -184,17 +236,41 @@ char* freerdp_passphrase_read(const char* prompt, char* buf, size_t bufsiz, int
error: error:
{ {
int saved_errno = errno; int saved_errno = errno;
buf_iter = NULL;
read_char = '\0';
if (terminal_needs_reset) if (terminal_needs_reset)
tcsetattr(terminal_fildes, TCSAFLUSH, &orig_flags); tcsetattr(terminal_fildes, TCSAFLUSH, &orig_flags);
if (terminal_fildes != STDIN_FILENO) if (terminal_fildes != STDIN_FILENO)
close(terminal_fildes); {
fclose(fp);
}
errno = saved_errno; errno = saved_errno;
return NULL; return NULL;
} }
} }
int freerdp_interruptible_getc(rdpContext* context, FILE* f)
{
int rc = EOF;
const int fd = fileno(f);
const int orig = fcntl(fd, F_GETFL);
fcntl(fd, F_SETFL, orig | O_NONBLOCK);
do
{
const int res = wait_for_fd(fd, 10);
if (res != 0)
{
char c = 0;
const ssize_t rd = read(fd, &c, 1);
if (rd == 1)
rc = c;
break;
}
} while (!freerdp_shall_disconnect_context(context));
fcntl(fd, F_SETFL, orig);
return rc;
}
#else #else
char* freerdp_passphrase_read(const char* prompt, char* buf, size_t bufsiz, int from_stdin) char* freerdp_passphrase_read(const char* prompt, char* buf, size_t bufsiz, int from_stdin)
@ -202,4 +278,55 @@ char* freerdp_passphrase_read(const char* prompt, char* buf, size_t bufsiz, int
return NULL; return NULL;
} }
int freerdp_interruptible_getc(rdpContext* context, FILE* f)
{
return EOF;
}
#endif #endif
SSIZE_T freerdp_interruptible_get_line(rdpContext* context, char** plineptr, size_t* psize,
FILE* stream)
{
char c;
char* n;
size_t step = 32;
size_t used = 0;
char* ptr = NULL;
size_t len = 0;
if (!plineptr || !psize)
{
errno = EINVAL;
return -1;
}
do
{
if (used + 2 >= len)
{
len += step;
n = realloc(ptr, len);
if (!n)
{
return -1;
}
ptr = n;
}
c = freerdp_interruptible_getc(context, stream);
if (c != EOF)
ptr[used++] = c;
} while ((c != '\n') && (c != '\r') && (c != EOF));
ptr[used] = '\0';
if (c == EOF)
{
free(ptr);
return EOF;
}
*plineptr = ptr;
*psize = used;
return used;
}