Merge pull request #2497 from matt335672/use_poll

Use poll() instead of select()
This commit is contained in:
matt335672 2023-02-13 14:37:00 +00:00 committed by GitHub
commit 54db636e76
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 251 additions and 245 deletions

View File

@ -44,6 +44,7 @@
#if defined(XRDP_ENABLE_VSOCK)
#include <linux/vm_sockets.h>
#endif
#include <poll.h>
#include <sys/un.h>
#include <sys/time.h>
#include <sys/times.h>
@ -1594,26 +1595,24 @@ g_sck_socket_ok(int sck)
int
g_sck_can_send(int sck, int millis)
{
fd_set wfds;
struct timeval time;
int rv;
time.tv_sec = millis / 1000;
time.tv_usec = (millis * 1000) % 1000000;
FD_ZERO(&wfds);
int rv = 0;
if (sck > 0)
{
FD_SET(((unsigned int)sck), &wfds);
rv = select(sck + 1, 0, &wfds, 0, &time);
struct pollfd pollfd;
if (rv > 0)
pollfd.fd = sck;
pollfd.events = POLLOUT;
pollfd.revents = 0;
if (poll(&pollfd, 1, millis) > 0)
{
return 1;
if ((pollfd.revents & POLLOUT) != 0)
{
rv = 1;
}
}
}
return 0;
return rv;
}
/*****************************************************************************/
@ -1622,78 +1621,65 @@ g_sck_can_send(int sck, int millis)
int
g_sck_can_recv(int sck, int millis)
{
fd_set rfds;
struct timeval time;
int rv;
g_memset(&time, 0, sizeof(time));
time.tv_sec = millis / 1000;
time.tv_usec = (millis * 1000) % 1000000;
FD_ZERO(&rfds);
int rv = 0;
if (sck > 0)
{
FD_SET(((unsigned int)sck), &rfds);
rv = select(sck + 1, &rfds, 0, 0, &time);
struct pollfd pollfd;
if (rv > 0)
pollfd.fd = sck;
pollfd.events = POLLIN;
pollfd.revents = 0;
if (poll(&pollfd, 1, millis) > 0)
{
return 1;
if ((pollfd.revents & (POLLIN | POLLHUP)) != 0)
{
rv = 1;
}
}
}
return 0;
return rv;
}
/*****************************************************************************/
int
g_sck_select(int sck1, int sck2)
{
fd_set rfds;
struct timeval time;
int max;
int rv;
struct pollfd pollfd[2] = {0};
int rvmask[2] = {0}; /* Output masks corresponding to fds in pollfd */
g_memset(&time, 0, sizeof(struct timeval));
FD_ZERO(&rfds);
unsigned int i = 0;
int rv = 0;
if (sck1 > 0)
{
FD_SET(((unsigned int)sck1), &rfds);
pollfd[i].fd = sck1;
pollfd[i].events = POLLIN;
rvmask[i] = 1;
++i;
}
if (sck2 > 0)
{
FD_SET(((unsigned int)sck2), &rfds);
pollfd[i].fd = sck2;
pollfd[i].events = POLLIN;
rvmask[i] = 2;
++i;
}
max = sck1;
if (sck2 > max)
if (poll(pollfd, i, 0) > 0)
{
max = sck2;
}
rv = select(max + 1, &rfds, 0, 0, &time);
if (rv > 0)
{
rv = 0;
if (FD_ISSET(((unsigned int)sck1), &rfds))
if ((pollfd[0].revents & (POLLIN | POLLHUP)) != 0)
{
rv = rv | 1;
rv |= rvmask[0];
}
if (FD_ISSET(((unsigned int)sck2), &rfds))
if ((pollfd[1].revents & (POLLIN | POLLHUP)) != 0)
{
rv = rv | 2;
rv |= rvmask[1];
}
}
else
{
rv = 0;
}
return rv;
}
@ -1703,19 +1689,24 @@ g_sck_select(int sck1, int sck2)
static int
g_fd_can_read(int fd)
{
fd_set rfds;
struct timeval time;
int rv;
g_memset(&time, 0, sizeof(time));
FD_ZERO(&rfds);
FD_SET(((unsigned int)fd), &rfds);
rv = select(fd + 1, &rfds, 0, 0, &time);
if (rv == 1)
int rv = 0;
if (fd > 0)
{
return 1;
struct pollfd pollfd;
pollfd.fd = fd;
pollfd.events = POLLIN;
pollfd.revents = 0;
if (poll(&pollfd, 1, 0) > 0)
{
if ((pollfd.revents & (POLLIN | POLLHUP)) != 0)
{
rv = 1;
}
}
}
return 0;
return rv;
}
/*****************************************************************************/
@ -1982,8 +1973,9 @@ int
g_obj_wait(tintptr *read_objs, int rcount, tintptr *write_objs, int wcount,
int mstimeout)
{
#define MAX_HANDLES 256
#ifdef _WIN32
HANDLE handles[256];
HANDLE handles[MAX_HANDLES];
DWORD count;
DWORD error;
int j;
@ -2016,96 +2008,62 @@ g_obj_wait(tintptr *read_objs, int rcount, tintptr *write_objs, int wcount,
return 0;
#else
fd_set rfds;
fd_set wfds;
struct timeval time;
struct timeval *ptime;
struct pollfd pollfd[MAX_HANDLES];
int sck;
int i = 0;
int res = 0;
int max = 0;
int sck = 0;
unsigned int j = 0;
int rv = 1;
max = 0;
if (mstimeout < 1)
if (read_objs == NULL && rcount != 0)
{
ptime = 0;
LOG(LOG_LEVEL_ERROR, "Programming error read_objs is null");
}
else if (write_objs == NULL && wcount != 0)
{
LOG(LOG_LEVEL_ERROR, "Programming error write_objs is null");
}
/* Check carefully for int overflow in passed-in counts */
else if ((unsigned int)rcount > MAX_HANDLES ||
(unsigned int)wcount > MAX_HANDLES ||
((unsigned int)rcount + (unsigned int)wcount) > MAX_HANDLES)
{
LOG(LOG_LEVEL_ERROR, "Programming error too many handles");
}
else
{
g_memset(&time, 0, sizeof(struct timeval));
time.tv_sec = mstimeout / 1000;
time.tv_usec = (mstimeout % 1000) * 1000;
ptime = &time;
}
if (mstimeout < 1)
{
mstimeout = -1;
}
FD_ZERO(&rfds);
FD_ZERO(&wfds);
/* Find the highest descriptor number in read_obj */
if (read_objs != NULL)
{
for (i = 0; i < rcount; i++)
for (i = 0; i < rcount ; ++i)
{
sck = read_objs[i] & 0xffff;
if (sck > 0)
{
FD_SET(sck, &rfds);
if (sck > max)
{
max = sck; /* max holds the highest socket/descriptor number */
}
pollfd[j].fd = sck;
pollfd[j].events = POLLIN;
++j;
}
}
}
else if (rcount > 0)
{
LOG(LOG_LEVEL_ERROR, "Programming error read_objs is null");
return 1; /* error */
}
if (write_objs != NULL)
{
for (i = 0; i < wcount; i++)
for (i = 0; i < wcount; ++i)
{
sck = (int)(write_objs[i]);
sck = write_objs[i];
if (sck > 0)
{
FD_SET(sck, &wfds);
if (sck > max)
{
max = sck; /* max holds the highest socket/descriptor number */
}
pollfd[j].fd = sck;
pollfd[j].events = POLLOUT;
++j;
}
}
}
else if (wcount > 0)
{
LOG(LOG_LEVEL_ERROR, "Programming error write_objs is null");
return 1; /* error */
rv = (poll(pollfd, j, mstimeout) < 0);
}
res = select(max + 1, &rfds, &wfds, 0, ptime);
if (res < 0)
{
/* these are not really errors */
if ((errno == EAGAIN) ||
(errno == EWOULDBLOCK) ||
(errno == EINPROGRESS) ||
(errno == EINTR)) /* signal occurred */
{
return 0;
}
return 1; /* error */
}
return 0;
return rv;
#endif
#undef MAX_HANDLES
}
/*****************************************************************************/

View File

@ -126,7 +126,23 @@ int g_sck_send_fd_set(int sck, const void *ptr, unsigned int len,
int fds[], unsigned int fdcount);
int g_sck_last_error_would_block(int sck);
int g_sck_socket_ok(int sck);
/**
* Checks socket writeability with an optional wait
*
* @param sck - Socket to check
* @param millis - Maximum milliseconds to wait for writeability to be true
*
* @note The wait time may not be reached in the event of an incoming signal
* so do not use this call to impose a hard timeout */
int g_sck_can_send(int sck, int millis);
/**
* Checks socket readability with an optional wait
*
* @param sck - Socket to check
* @param millis - Maximum milliseconds to wait for readability to be true
*
* @note The wait time may not be reached in the event of an incoming signal
* so do not use this call to impose a hard timeout */
int g_sck_can_recv(int sck, int millis);
int g_sck_select(int sck1, int sck2);
@ -167,6 +183,21 @@ int g_set_wait_obj(tintptr obj);
int g_reset_wait_obj(tintptr obj);
int g_is_wait_obj_set(tintptr obj);
int g_delete_wait_obj(tintptr obj);
/**
* Wait for the specified readable and writeable objs
*
* The wait finishes when at least one of the objects becomes
* readable or writeable
*
* @param read_objs Array of read objects
* @param rcount Number of elements in read_objs
* @param write_objs Array of write objects
* @param rcount Number of elements in write_objs
* @param mstimeout Timeout in milliseconds. <= 0 means an infinite timeout.
*
* @return 0 for success. The objects will need to be polled to
* find out what is readable or writeable.
*/
int g_obj_wait(tintptr *read_objs, int rcount, tintptr *write_objs,
int wcount, int mstimeout);
void g_random(char *data, int len);

View File

@ -675,7 +675,7 @@ int xfuse_check_wait_objs(void)
return 0;
}
if (g_tcp_select(g_fd, 0) & 1)
if (g_sck_can_recv(g_fd, 0))
{
tmpch = g_ch;

View File

@ -10,6 +10,7 @@
#include <sys/socket.h>
#include <sys/un.h>
#include <sys/stat.h>
#include <poll.h>
#include "string_calls.h"
@ -263,29 +264,24 @@ get_message(int *code, char *data, int *bytes)
int max_bytes;
int error;
int recv_rv;
int max;
int lcode;
struct timeval time;
fd_set rd_set;
struct pollfd pollfd;
LLOGLN(10, ("get_message:"));
max = g_sck + 1;
while (1)
{
LLOGLN(10, ("get_message: loop"));
time.tv_sec = 1;
time.tv_usec = 0;
FD_ZERO(&rd_set);
FD_SET(((unsigned int)g_sck), &rd_set);
error = select(max, &rd_set, 0, 0, &time);
pollfd.fd = g_sck;
pollfd.events = POLLIN;
pollfd.revents = 0;
error = poll(&pollfd, 1, 1000);
if (error == 1)
{
pthread_mutex_lock(&g_mutex);
time.tv_sec = 0;
time.tv_usec = 0;
FD_ZERO(&rd_set);
FD_SET(((unsigned int)g_sck), &rd_set);
error = select(max, &rd_set, 0, 0, &time);
pollfd.fd = g_sck;
pollfd.events = POLLIN;
pollfd.revents = 0;
error = poll(&pollfd, 1, 0);
if (error == 1)
{
/* just take a look at the next message */

View File

@ -24,6 +24,8 @@
#include <stdlib.h>
#include <string.h>
#include <unistd.h>
#include <errno.h>
#include <poll.h>
#include <X11/Xlib.h>
#include <sys/select.h>
@ -32,7 +34,6 @@ int g_x_socket = 0;
int main(int argc, char **argv)
{
fd_set rfds;
int i1;
XEvent xevent;
@ -48,9 +49,15 @@ int main(int argc, char **argv)
while (1)
{
FD_ZERO(&rfds);
FD_SET(g_x_socket, &rfds);
i1 = select(g_x_socket + 1, &rfds, 0, 0, 0);
struct pollfd pollfd;
pollfd.fd = g_x_socket;
pollfd.events = POLLIN;
pollfd.revents = 0;
do
{
i1 = poll(&pollfd, 1, -1);
}
while (i1 < 0 && errno == EINTR);
if (i1 < 0)
{

View File

@ -267,26 +267,24 @@ int tcp_connect(int skt, const char *hostname, const char *port)
int tcp_can_send(int skt, int millis)
{
fd_set wfds;
struct timeval time;
int rv;
time.tv_sec = millis / 1000;
time.tv_usec = (millis * 1000) % 1000000;
FD_ZERO(&wfds);
int rv = 0;
if (skt > 0)
{
FD_SET(((unsigned int) skt), &wfds);
rv = select(skt + 1, 0, &wfds, 0, &time);
struct pollfd pollfd;
if (rv > 0)
pollfd.fd = skt;
pollfd.events = POLLOUT;
pollfd.revents = 0;
if (poll(&pollfd, 1, millis) > 0)
{
return tcp_socket_ok(skt);
if ((pollfd.revents & POLLOUT) != 0)
{
rv = 1;
}
}
}
return 0;
return rv;
}
/**
@ -317,56 +315,40 @@ int tcp_socket_ok(int skt)
int tcp_select(int sck1, int sck2)
{
fd_set rfds;
struct timeval time;
struct pollfd pollfd[2] = {0};
int rvmask[2] = {0}; /* Output masks corresponding to fds in pollfd */
int max = 0;
int rv = 0;
memset(&rfds, 0, sizeof(fd_set));
memset(&time, 0, sizeof(struct timeval));
time.tv_sec = 0;
time.tv_usec = 0;
FD_ZERO(&rfds);
unsigned int i = 0;
int rv = 0;
if (sck1 > 0)
{
FD_SET(((unsigned int) sck1), &rfds);
pollfd[i].fd = sck1;
pollfd[i].events = POLLIN;
rvmask[i] = 1;
++i;
}
if (sck2 > 0)
{
FD_SET(((unsigned int) sck2), &rfds);
pollfd[i].fd = sck2;
pollfd[i].events = POLLIN;
rvmask[i] = 2;
++i;
}
max = sck1;
if (sck2 > max)
if (poll(pollfd, i, 0) > 0)
{
max = sck2;
}
rv = select(max + 1, &rfds, 0, 0, &time);
if (rv > 0)
{
rv = 0;
if (FD_ISSET(((unsigned int) sck1), &rfds))
if ((pollfd[0].revents & (POLLIN | POLLHUP)) != 0)
{
rv = rv | 1;
rv |= rvmask[0];
}
if (FD_ISSET(((unsigned int)sck2), &rfds))
if ((pollfd[1].revents & (POLLIN | POLLHUP)) != 0)
{
rv = rv | 2;
rv |= rvmask[1];
}
}
else
{
rv = 0;
}
return rv;
}

View File

@ -105,6 +105,7 @@
#include <fcntl.h>
#include <sys/time.h>
#include <sys/resource.h>
#include <poll.h>
#define _PATH_DEVNULL "/dev/null"
@ -275,14 +276,24 @@ handle_connection(int client_fd)
int client_going = 1;
while (client_going)
{
/* Wait for data from RDP or the client */
fd_set readfds;
FD_ZERO(&readfds);
FD_SET(client_fd, &readfds);
FD_SET(rdp_fd, &readfds);
select(FD_SETSIZE, &readfds, NULL, NULL, NULL);
struct pollfd pollfd[2];
enum
{
RDP_FD = 0,
CLIENT_FD
};
if (FD_ISSET(rdp_fd, &readfds))
/* Wait for data from RDP or the client */
pollfd[RDP_FD].fd = rdp_fd;
pollfd[RDP_FD].events = POLLIN;
pollfd[RDP_FD].revents = 0;
pollfd[CLIENT_FD].fd = client_fd;
pollfd[CLIENT_FD].events = POLLIN;
pollfd[CLIENT_FD].revents = 0;
poll(pollfd, 2, -1);
if ((pollfd[RDP_FD].revents & (POLLIN | POLLHUP)) != 0)
{
/* Read from RDP and write to the client */
char buffer[4096];
@ -325,7 +336,7 @@ handle_connection(int client_fd)
}
}
if (FD_ISSET(client_fd, &readfds))
if ((pollfd[CLIENT_FD].revents & (POLLIN | POLLHUP)) != 0)
{
/* Read from the client and write to RDP */
char buffer[4096];
@ -380,14 +391,18 @@ main(int argc, char **argv)
/* Wait for a client to connect to the socket */
while (is_going)
{
fd_set readfds;
FD_ZERO(&readfds);
FD_SET(sa_uds_fd, &readfds);
select(FD_SETSIZE, &readfds, NULL, NULL, NULL);
struct pollfd pollfd;
pollfd.fd = sa_uds_fd;
pollfd.events = POLLIN;
pollfd.revents = 0;
poll(pollfd, 1, -1);
/* If something connected then get it...
* (You can test this using "socat - UNIX-CONNECT:<udspath>".) */
if (FD_ISSET(sa_uds_fd, &readfds))
if ((pollfd.revents & (POLLIN | POLLHUP)) != 0)
{
socklen_t addrsize = sizeof(addr);
int client_fd = accept(sa_uds_fd,

View File

@ -32,6 +32,7 @@
#include <sys/stat.h>
#include <sys/socket.h>
#include <sys/un.h>
#include <poll.h>
#include "log.h"
#include "xrdp_sockets.h"
@ -46,9 +47,9 @@ struct wts_obj
/* helper functions used by WTSxxx API - do not invoke directly */
static int
can_send(int sck, int millis);
can_send(int sck, int millis, int restart);
static int
can_recv(int sck, int millis);
can_recv(int sck, int millis, int restart);
static int
mysend(int sck, const void *adata, int bytes);
static int
@ -160,7 +161,7 @@ WTSVirtualChannelOpenEx(unsigned int SessionId, const char *pVirtualName,
}
/* wait for connection to complete */
if (!can_send(wts->fd, 500))
if (!can_send(wts->fd, 500, 1))
{
LOG(LOG_LEVEL_ERROR, "WTSVirtualChannelOpenEx: can_send failed");
free(wts);
@ -212,7 +213,7 @@ WTSVirtualChannelOpenEx(unsigned int SessionId, const char *pVirtualName,
}
LOG_DEVEL(LOG_LEVEL_DEBUG, "WTSVirtualChannelOpenEx: sent ok");
if (!can_recv(wts->fd, 500))
if (!can_recv(wts->fd, 500, 1))
{
LOG(LOG_LEVEL_ERROR, "WTSVirtualChannelOpenEx: can_recv failed");
free(wts);
@ -263,7 +264,7 @@ mysend(int sck, const void *adata, int bytes)
sent = 0;
while (sent < bytes)
{
if (can_send(sck, 100))
if (can_send(sck, 100, 0))
{
error = send(sck, data + sent, bytes - sent, MSG_NOSIGNAL);
if (error < 1)
@ -293,7 +294,7 @@ myrecv(int sck, void *adata, int bytes)
recd = 0;
while (recd < bytes)
{
if (can_recv(sck, 100))
if (can_recv(sck, 100, 0))
{
error = recv(sck, data + recd, bytes - recd, MSG_NOSIGNAL);
if (error < 1)
@ -328,7 +329,7 @@ WTSVirtualChannelWrite(void *hChannelHandle, const char *Buffer,
return 0;
}
if (!can_send(wts->fd, 0))
if (!can_send(wts->fd, 0, 0))
{
return 1; /* can't write now, ok to try again */
}
@ -369,7 +370,7 @@ WTSVirtualChannelRead(void *hChannelHandle, unsigned int TimeOut,
return 0;
}
if (can_recv(wts->fd, TimeOut))
if (can_recv(wts->fd, TimeOut, 0))
{
rv = recv(wts->fd, Buffer, BufferSize, 0);
@ -474,47 +475,63 @@ WTSFreeMemory(void *pMemory)
*
* @param sck socket to check
* @param millis timeout value in milliseconds
* @param restart Try again if interrupted, even if this exceeds the timeout
*
* @return 0 if write will block
* @return 1 if write will not block
******************************************************************************/
static int
can_send(int sck, int millis)
can_send(int sck, int millis, int restart)
{
struct timeval time;
fd_set wfds;
int select_rv;
int rv = 0;
struct pollfd pollfd;
int status;
/* setup for a select call */
FD_ZERO(&wfds);
FD_SET(sck, &wfds);
time.tv_sec = millis / 1000;
time.tv_usec = (millis * 1000) % 1000000;
pollfd.fd = sck;
pollfd.events = POLLOUT;
pollfd.revents = 0;
/* check if it is ok to write to specified socket */
select_rv = select(sck + 1, 0, &wfds, 0, &time);
do
{
status = poll(&pollfd, 1, millis);
}
while (status < 0 && errno == EINTR && restart);
return (select_rv > 0) ? 1 : 0;
if (status > 0)
{
if ((pollfd.revents & POLLOUT) != 0)
{
rv = 1;
}
}
return rv;
}
/*****************************************************************************/
static int
can_recv(int sck, int millis)
can_recv(int sck, int millis, int restart)
{
struct timeval time;
fd_set rfds;
int select_rv;
int rv = 0;
struct pollfd pollfd;
int status;
FD_ZERO(&rfds);
FD_SET(sck, &rfds);
time.tv_sec = millis / 1000;
time.tv_usec = (millis * 1000) % 1000000;
select_rv = select(sck + 1, &rfds, 0, 0, &time);
if (select_rv > 0)
pollfd.fd = sck;
pollfd.events = POLLIN;
pollfd.revents = 0;
do
{
return 1;
status = poll(&pollfd, 1, millis);
}
while (status < 0 && errno == EINTR && restart);
if (status > 0)
{
if ((pollfd.revents & (POLLIN | POLLHUP)) != 0)
{
rv = 1;
}
}
return 0;
return rv;
}