redo the connection waiting handling to make it more clear.

This commit is contained in:
christos 2019-04-07 00:44:54 +00:00
parent c5cc429c4c
commit e35c1a2b0b

View File

@ -1,4 +1,4 @@
/* $NetBSD: ssl.c,v 1.7 2019/04/04 00:36:09 christos Exp $ */
/* $NetBSD: ssl.c,v 1.8 2019/04/07 00:44:54 christos Exp $ */
/*-
* Copyright (c) 1998-2004 Dag-Erling Coïdan Smørgrav
@ -34,7 +34,7 @@
#include <sys/cdefs.h>
#ifndef lint
__RCSID("$NetBSD: ssl.c,v 1.7 2019/04/04 00:36:09 christos Exp $");
__RCSID("$NetBSD: ssl.c,v 1.8 2019/04/07 00:44:54 christos Exp $");
#endif
#include <time.h>
@ -88,6 +88,7 @@ fetch_writev(struct fetch_connect *conn, struct iovec *iov, int iovcnt)
struct timeval now, timeout, delta;
fd_set writefds;
ssize_t len, total;
int fd = conn->sd;
int r;
if (quit_time > 0) {
@ -98,8 +99,8 @@ fetch_writev(struct fetch_connect *conn, struct iovec *iov, int iovcnt)
total = 0;
while (iovcnt > 0) {
while (quit_time > 0 && !FD_ISSET(conn->sd, &writefds)) {
FD_SET(conn->sd, &writefds);
while (quit_time > 0 && !FD_ISSET(fd, &writefds)) {
FD_SET(fd, &writefds);
gettimeofday(&now, NULL);
delta.tv_sec = timeout.tv_sec - now.tv_sec;
delta.tv_usec = timeout.tv_usec - now.tv_usec;
@ -112,7 +113,7 @@ fetch_writev(struct fetch_connect *conn, struct iovec *iov, int iovcnt)
return -1;
}
errno = 0;
r = select(conn->sd + 1, NULL, &writefds, NULL, &delta);
r = select(fd + 1, NULL, &writefds, NULL, &delta);
if (r == -1) {
if (errno == EINTR)
continue;
@ -123,7 +124,7 @@ fetch_writev(struct fetch_connect *conn, struct iovec *iov, int iovcnt)
if (conn->ssl != NULL)
len = SSL_write(conn->ssl, iov->iov_base, iov->iov_len);
else
len = writev(conn->sd, iov, iovcnt);
len = writev(fd, iov, iovcnt);
if (len == 0) {
/* we consider a short write a failure */
/* XXX perhaps we shouldn't in the SSL case */
@ -131,7 +132,7 @@ fetch_writev(struct fetch_connect *conn, struct iovec *iov, int iovcnt)
return -1;
}
if (len < 0) {
if (errno == EINTR)
if (errno == EINTR || errno == EAGAIN)
continue;
return -1;
}
@ -149,11 +150,8 @@ fetch_writev(struct fetch_connect *conn, struct iovec *iov, int iovcnt)
return total;
}
/*
* Write to a connection w/ timeout
*/
static int
fetch_write(struct fetch_connect *conn, const char *str, size_t len)
static ssize_t
fetch_write(const void *str, size_t len, struct fetch_connect *conn)
{
struct iovec iov[1];
@ -182,7 +180,7 @@ fetch_printf(struct fetch_connect *conn, const char *fmt, ...)
return -1;
}
r = fetch_write(conn, msg, len);
r = fetch_write(msg, len, conn);
free(msg);
return r;
}
@ -211,15 +209,16 @@ fetch_clearerr(struct fetch_connect *conn)
int
fetch_flush(struct fetch_connect *conn)
{
int v;
if (conn->issock) {
int fd = conn->sd;
int v;
#ifdef TCP_NOPUSH
v = 0;
setsockopt(conn->sd, IPPROTO_TCP, TCP_NOPUSH, &v, sizeof(v));
setsockopt(fd, IPPROTO_TCP, TCP_NOPUSH, &v, sizeof(v));
#endif
v = 1;
setsockopt(conn->sd, IPPROTO_TCP, TCP_NODELAY, &v, sizeof(v));
setsockopt(fd, IPPROTO_TCP, TCP_NODELAY, &v, sizeof(v));
}
return 0;
}
@ -272,23 +271,19 @@ fetch_fdopen(int sd, const char *fmode)
int
fetch_close(struct fetch_connect *conn)
{
int rv = 0;
if (conn == NULL)
return 0;
if (conn != NULL) {
fetch_flush(conn);
SSL_free(conn->ssl);
rv = close(conn->sd);
if (rv < 0) {
errno = rv;
rv = EOF;
}
free(conn->cache.buf);
free(conn->buf);
free(conn);
}
return rv;
fetch_flush(conn);
SSL_free(conn->ssl);
close(conn->sd);
free(conn->cache.buf);
free(conn->buf);
free(conn);
return 0;
}
#define FETCH_WRITE_WAIT -3
#define FETCH_READ_WAIT -2
#define FETCH_READ_ERROR -1
@ -296,19 +291,19 @@ static ssize_t
fetch_ssl_read(SSL *ssl, void *buf, size_t len)
{
ssize_t rlen;
int ssl_err;
rlen = SSL_read(ssl, buf, len);
if (rlen < 0) {
ssl_err = SSL_get_error(ssl, rlen);
if (ssl_err == SSL_ERROR_WANT_READ ||
ssl_err == SSL_ERROR_WANT_WRITE) {
return FETCH_READ_WAIT;
}
if (rlen >= 0)
return rlen;
switch (SSL_get_error(ssl, rlen)) {
case SSL_ERROR_WANT_READ:
return FETCH_READ_WAIT;
case SSL_ERROR_WANT_WRITE:
return FETCH_WRITE_WAIT;
default:
ERR_print_errors_fp(ttyout);
return FETCH_READ_ERROR;
}
return rlen;
}
static ssize_t
@ -317,7 +312,7 @@ fetch_nonssl_read(int sd, void *buf, size_t len)
ssize_t rlen;
rlen = read(sd, buf, len);
if (rlen < 0) {
if (rlen == -1) {
if (errno == EAGAIN || errno == EINTR)
return FETCH_READ_WAIT;
return FETCH_READ_ERROR;
@ -348,14 +343,45 @@ fetch_cache_data(struct fetch_connect *conn, char *src, size_t nbytes)
return 0;
}
static int
fetch_wait(struct fetch_connect *conn, ssize_t rlen, struct timeval *timeout)
{
struct timeval now, delta;
int fd = conn->sd;
fd_set fds;
FD_ZERO(&fds);
while (!FD_ISSET(fd, &fds)) {
FD_SET(fd, &fds);
if (quit_time > 0) {
gettimeofday(&now, NULL);
if (!timercmp(timeout, &now, >)) {
conn->iserr = ETIMEDOUT;
return -1;
}
timersub(timeout, &now, &delta);
}
errno = 0;
if (select(fd + 1,
rlen == FETCH_READ_WAIT ? &fds : NULL,
rlen == FETCH_WRITE_WAIT ? &fds : NULL,
NULL, quit_time > 0 ? &delta : NULL) < 0) {
if (errno == EINTR)
continue;
conn->iserr = errno;
return -1;
}
}
return 0;
}
size_t
fetch_read(void *ptr, size_t size, size_t nmemb, struct fetch_connect *conn)
{
struct timeval now, timeout, delta;
fd_set readfds;
ssize_t rlen, total;
size_t len;
char *start, *buf;
struct timeval timeout;
if (quit_time > 0) {
gettimeofday(&timeout, NULL);
@ -407,39 +433,25 @@ fetch_read(void *ptr, size_t size, size_t nmemb, struct fetch_connect *conn)
rlen = fetch_ssl_read(conn->ssl, buf, len);
else
rlen = fetch_nonssl_read(conn->sd, buf, len);
if (rlen == 0) {
switch (rlen) {
case 0:
conn->iseof = 1;
break;
} else if (rlen > 0) {
len -= rlen;
buf += rlen;
total += rlen;
continue;
} else if (rlen == FETCH_READ_ERROR) {
return total;
case FETCH_READ_ERROR:
conn->iserr = errno;
if (errno == EINTR)
fetch_cache_data(conn, start, total);
return 0;
}
FD_ZERO(&readfds);
while (!FD_ISSET(conn->sd, &readfds)) {
FD_SET(conn->sd, &readfds);
if (quit_time > 0) {
gettimeofday(&now, NULL);
if (!timercmp(&timeout, &now, >)) {
conn->iserr = ETIMEDOUT;
return 0;
}
timersub(&timeout, &now, &delta);
}
errno = 0;
if (select(conn->sd + 1, &readfds, NULL, NULL,
quit_time > 0 ? &delta : NULL) < 0) {
if (errno == EINTR)
continue;
conn->iserr = errno;
case FETCH_READ_WAIT:
case FETCH_WRITE_WAIT:
if (fetch_wait(conn, rlen, &timeout) == -1)
return 0;
}
break;
default:
len -= rlen;
buf += rlen;
total += rlen;
break;
}
}
return total;