NetBSD/usr.sbin/sup/source/scmio.c
christos 524dd7b3c3 - don't allow multiple active connections from the same host
- prefix all messages with the local hostname
- more error checking
2013-03-08 20:56:44 +00:00

761 lines
18 KiB
C

/* $NetBSD: scmio.c,v 1.22 2013/03/08 20:56:44 christos Exp $ */
/*
* Copyright (c) 1992 Carnegie Mellon University
* All Rights Reserved.
*
* Permission to use, copy, modify and distribute this software and its
* documentation is hereby granted, provided that both the copyright
* notice and this permission notice appear in all copies of the
* software, derivative works or modified versions, and any portions
* thereof, and that both notices appear in supporting documentation.
*
* CARNEGIE MELLON ALLOWS FREE USE OF THIS SOFTWARE IN ITS "AS IS"
* CONDITION. CARNEGIE MELLON DISCLAIMS ANY LIABILITY OF ANY KIND FOR
* ANY DAMAGES WHATSOEVER RESULTING FROM THE USE OF THIS SOFTWARE.
*
* Carnegie Mellon requests users of this software to return to
*
* Software Distribution Coordinator or Software.Distribution@CS.CMU.EDU
* School of Computer Science
* Carnegie Mellon University
* Pittsburgh PA 15213-3890
*
* any improvements or extensions that they make and grant Carnegie Mellon
* the rights to redistribute these changes.
*/
/*
* SUP Communication Module for 4.3 BSD
*
* SUP COMMUNICATION MODULE SPECIFICATIONS:
*
* IN THIS MODULE:
*
* OUTPUT TO NETWORK
*
* MESSAGE START/END
* writemsg (msg) start message
* int msg; message type
* writemend () end message and flush data to network
*
* MESSAGE DATA
* writeint (i) write int
* int i; integer to write
* writestring (p) write string
* char *p; string pointer
* writefile (f) write open file
* int f; open file descriptor
*
* COMPLETE MESSAGE (start, one data block, end)
* writemnull (msg) write message with no data
* int msg; message type
* writemint (msg,i) write int message
* int msg; message type
* int i; integer to write
* writemstr (msg,p) write string message
* int msg; message type
* char *p; string pointer
*
* INPUT FROM NETWORK
* MESSAGE START/END
* readflush () flush any unread data (close)
* readmsg (msg) read specified message type
* int msg; message type
* readmend () read message end
*
* MESSAGE DATA
* readskip () skip over one input data block
* readint (i) read int
* int *i; pointer to integer
* readstring (p) read string
* char **p; pointer to string pointer
* readfile (f) read into open file
* int f; open file descriptor
*
* COMPLETE MESSAGE (start, one data block, end)
* readmnull (msg) read message with no data
* int msg; message type
* readmint (msg,i) read int message
* int msg; message type
* int *i; pointer to integer
* readmstr (msg,p) read string message
* int msg; message type
* char **p; pointer to string pointer
*
* RETURN CODES
* All routines normally return SCMOK. SCMERR may be returned
* by any routine on abnormal (usually fatal) errors. An
* unexpected MSGGOAWAY will result in SCMEOF being returned.
*
* COMMUNICATION PROTOCOL
* Messages always alternate, with the first message after
* connecting being sent by the client process.
*
* At the end of the conversation, the client process will
* send a message to the server telling it to go away. Then,
* both processes will close the network connection.
*
* Any time a process receives a message it does not expect,
* the "readmsg" routine will send a MSGGOAWAY message and
* return SCMEOF.
*
* Each message has this format:
* ---------- ------------ ------------ ----------
* |msg type| |count|data| |count|data| ... |ENDCOUNT|
* ---------- ------------ ------------ ----------
* size: int int var. int var. int
*
* All ints are assumed to be 32-bit quantities. A message
* with no data simply has a message type followed by ENDCOUNT.
*
**********************************************************************
* HISTORY
* Revision 1.7 92/09/09 22:04:41 mrt
* Removed the data encryption routines from here to netcrypt.c
* [92/09/09 mrt]
*
* Revision 1.6 92/08/11 12:05:57 mrt
* Brad's changes: Delinted,Added forward declarations of
* static functions. Added copyright.
* [92/07/24 mrt]
*
* 27-Dec-87 Glenn Marcy (gm0w) at Carnegie-Mellon University
* Added crosspatch support.
*
* 28-Jun-87 Glenn Marcy (gm0w) at Carnegie-Mellon University
* Found error in debugging code for readint().
*
* 01-Apr-87 Glenn Marcy (gm0w) at Carnegie-Mellon University
* Added code to readdata to "push" data back into the data buffer.
* Added prereadcount() to return the message count size after
* reading it and then pushing it back into the buffer. Clear
* any encryption when a GOAWAY message is detected before reading
* the reason string. [V5.19]
*
* 02-Oct-86 Rudy Nedved (ern) at Carnegie-Mellon University
* Put a timeout on reading from the network.
*
* 25-May-86 Jonathan J. Chew (jjc) at Carnegie-Mellon University
* Renamed the "howmany" parameter of routines "encode" and "decode"
* to "count" to avoid conflict with the 4.3BSD macro.
*
* 15-Feb-86 Glenn Marcy (gm0w) at Carnegie-Mellon University
* Added readflush() to flush any unread data from the input
* buffer. Called by requestend() in scm.c module.
*
* 19-Jan-86 Glenn Marcy (gm0w) at Carnegie-Mellon University
* Added register variables to decode() for speedup. Added I/O
* buffering to reduce the number of read/write system calls.
* Removed readmfil/writemfil routines which were not used and were
* not compatible with the other similarly defined routines anyway.
*
* 19-Dec-85 Glenn Marcy (gm0w) at Carnegie-Mellon University
* Created from scm.c I/O and crypt routines.
*
**********************************************************************
*/
#include "libc.h"
#include <errno.h>
#include <stdbool.h>
#include <sys/types.h>
#include <sys/stat.h>
#include <sys/file.h>
#include <sys/time.h>
#include <sys/poll.h>
#include "supcdefs.h"
#include "supextern.h"
#include "supmsg.h"
#ifndef INFTIM
#define INFTIM -1
#endif
/*************************
*** M A C R O S ***
*************************/
/* end of message */
#define ENDCOUNT (-1)		/* count value that marks the end of a message */
#define NULLCOUNT (-2)		/* count value sent in place of a NULL string */
#define RETRIES 15		/* # of times to retry io */
#define FILEXFER 2048		/* block transfer size, in bytes */
/*
 * Clamp a byte count to at most one FILEXFER-sized block.
 * The argument is parenthesized so that any expression may be passed; it
 * is evaluated twice, so it must not have side effects.
 */
#define XFERSIZE(count) (((count) > FILEXFER) ? FILEXFER : (count))
/*********************************************
*** G L O B A L V A R I A B L E S ***
*********************************************/
extern int netfile; /* network file descriptor (defined outside this file) */
int scmdebug; /* scm debug flag; larger values enable more logging here */
int cryptflag; /* whether to encrypt/decrypt data */
char *cryptbuf; /* buffer for data encryption/decryption */
extern char *goawayreason; /* reason for goaway message; filled in by readmsg() */
/*
 * Output staging buffer.  Two of them exist so that writedata() can start
 * filling one while it writes the other, full one to the network.
 */
struct buf {
char b_data[FILEXFER]; /* buffered data */
char *b_ptr; /* next free byte in b_data (b_data + b_cnt) */
int b_cnt; /* number of bytes in buffer */
} buffers[2];
/*
 * Buffer currently being filled.  writemsg() points it at buffers[0] and
 * writemend() resets it to NULL, so it is non-NULL only while a message
 * is being assembled.
 */
struct buf *gblbufptr; /* buffer pointer */
/* Forward declarations of the file-local raw I/O helpers. */
static int writedata(size_t, void *);
static int writeblock(size_t, void *);
static int readdata(size_t, void *, bool);
static int readcount(int *);
/***********************************************
*** O U T P U T T O N E T W O R K ***
***********************************************/
/*
 * Send `count' bytes at `data' straight to the network descriptor.
 *
 * A write() that returns 0 without setting errno is retried up to RETRIES
 * times.  A short write is reported as an error; it is not resumed.
 * Returns SCMOK on success, otherwise the value returned by scmerr().
 */
static int
scmio_netwrite(size_t count, void *data)
{
	int x, tries;

	tries = 0;
	for (;;) {
		errno = 0;
		x = write(netfile, data, count);
		if (x > 0)
			break;
		if (errno)
			break;
		if (++tries > RETRIES)
			break;
		if (scmdebug > 0)
			logerr("SCM Retrying failed network write");
	}
	if (x <= 0) {
		if (errno == EPIPE)
			return (scmerr(-1, "Network write timed out"));
		if (errno)
			return (scmerr(errno, "Write error on network"));
		return (scmerr(-1, "Write retries failed"));
	}
	if ((size_t)x != count)
		return (scmerr(-1, "Write error on network returned %d "
		    "on write of %zu", x, count));
	return (SCMOK);
}

/*
 * Write raw data to the network, going through the message buffer when
 * one is active.
 *
 * count - number of bytes to send
 * data  - bytes to send
 *
 * With no active buffer (gblbufptr == NULL) the data is written directly.
 * Otherwise the data is appended to the current buffer when it fits.  When
 * it does not fit, the current buffer is written out and the new data
 * becomes the contents of the other buffer.  A block larger than FILEXFER
 * can never fit in a buffer, so the pending bytes are flushed first and
 * the block is then written directly; byte order on the wire is unchanged.
 *
 * Returns SCMOK on success or the scmerr() result on failure.
 */
static int
writedata(size_t count, void *data)
{				/* write raw data to network */
	struct buf *bp, *full;
	size_t pending;
	int x;

	if (gblbufptr == NULL)
		return (scmio_netwrite(count, data));

	if (count > FILEXFER) {
		/* too big for either buffer: never copy it into b_data */
		pending = (size_t)gblbufptr->b_cnt;
		gblbufptr->b_cnt = 0;
		gblbufptr->b_ptr = gblbufptr->b_data;
		if (pending > 0) {
			x = scmio_netwrite(pending, gblbufptr->b_data);
			if (x != SCMOK)
				return (x);
		}
		return (scmio_netwrite(count, data));
	}

	if ((size_t)gblbufptr->b_cnt + count <= FILEXFER) {
		memcpy(gblbufptr->b_ptr, data, count);
		gblbufptr->b_cnt += count;
		gblbufptr->b_ptr += count;
		return (SCMOK);
	}

	/* start the other buffer with the new data, then flush the full one */
	full = gblbufptr;
	bp = (full == buffers) ? &buffers[1] : buffers;
	memcpy(bp->b_data, data, count);
	bp->b_cnt = count;
	bp->b_ptr = bp->b_data + count;
	pending = (size_t)full->b_cnt;
	full->b_cnt = 0;
	full->b_ptr = full->b_data;
	gblbufptr = bp;
	return (scmio_netwrite(pending, full->b_data));
}
/*
 * Write one data block: the byte-swapped count followed by `count' bytes
 * of `data'.  The payload is only sent if the count was written
 * successfully.  Returns the status of the last writedata() call made.
 */
static int
writeblock(size_t count, void *data)
{				/* write data block */
	int netcount = byteswap((int)count);
	int status = writedata(sizeof(int), &netcount);

	if (status != SCMOK)
		return (status);
	return (writedata(count, data));
}
/*
 * Start an outgoing message of type `msg'.
 *
 * Enables output buffering by pointing gblbufptr at the first (emptied)
 * buffer, then queues the byte-swapped message type.  Fails through
 * scmerr() if a message is already in progress.
 */
int
writemsg(int msg)
{				/* write start of message */
	int netmsg;

	if (scmdebug > 1)
		loginfo("SCM Writing message %d", msg);
	if (gblbufptr != NULL)
		return (scmerr(-1, "Buffering already enabled"));

	buffers[0].b_cnt = 0;
	buffers[0].b_ptr = buffers[0].b_data;
	gblbufptr = &buffers[0];

	netmsg = byteswap(msg);
	return (writedata(sizeof(int), &netmsg));
}
int
writemend(void)
{ /* write end of message */
size_t count;
char *data;
int x;
x = byteswap(ENDCOUNT);
x = writedata(sizeof(int), &x);
if (x != SCMOK)
return (x);
if (gblbufptr == NULL)
return (scmerr(-1, "Buffering already disabled"));
if (gblbufptr->b_cnt == 0) {
gblbufptr = NULL;
return (SCMOK);
}
data = gblbufptr->b_data;
count = gblbufptr->b_cnt;
gblbufptr = NULL;
return (writedata(count, data));
}
/*
 * Write the integer `i' as a data block: a count of sizeof(int) followed
 * by the byte-swapped value.  Returns the writeblock() status.
 */
int
writeint(int i)
{				/* write int as data block */
	int wire;

	if (scmdebug > 2)
		loginfo("SCM Writing integer %d", i);
	wire = byteswap(i);
	return (writeblock(sizeof(wire), &wire));
}
/*
 * Write the string `p' as a data block.
 *
 * A NULL pointer is sent as the bare NULLCOUNT marker with no payload.
 * Otherwise strlen(p) bytes are sent (the terminating NUL is not).  When
 * cryptflag is set the bytes are first run through encode() into cryptbuf
 * and sent from there.
 *
 * Returns SCMOK on success or the status of the step that failed.
 */
int
writestring(char *p)
{				/* write string as data block */
	char *payload = p;
	int length;
	int status;

	if (p == NULL) {
		int nullmark = byteswap(NULLCOUNT);

		if (scmdebug > 2)
			loginfo("SCM Writing string NULL");
		return (writedata(sizeof(int), &nullmark));
	}

	if (scmdebug > 2)
		loginfo("SCM Writing string %s", p);
	length = strlen(p);

	if (cryptflag) {
		status = getcryptbuf(length + 1);
		if (status != SCMOK)
			return (status);
		encode(p, cryptbuf, length);
		payload = cryptbuf;
	}
	return (writeblock((size_t)length, payload));
}
/*
 * Write the contents of the open file `f' as one data block: the file
 * size (from fstat) as the count, followed by the file's bytes read in
 * FILEXFER-sized pieces.  With cryptflag set each piece is passed through
 * encode() into cryptbuf before being sent.
 *
 * Errors are reported in the order they occur: a failed count write or
 * getcryptbuf() is returned as is (it is no longer overwritten), a write
 * failure stops the transfer immediately, a read() failure is reported
 * with its errno, and only then is the byte total compared to the size
 * announced in the count.  A file whose size does not fit the int count
 * field is rejected before anything is sent.
 *
 * Returns SCMOK on success, otherwise an error status.
 */
int
writefile(int f)
{				/* write open file as a data block */
	char buf[FILEXFER];
	int number, sum = 0, filesize, x;
	int y;
	struct stat statbuf;

	if (fstat(f, &statbuf) < 0)
		return (scmerr(errno, "Can't access open file for message"));
	filesize = (int)statbuf.st_size;
	if (filesize < 0 || (off_t)filesize != statbuf.st_size)
		return (scmerr(-1, "File too large for output message"));

	y = byteswap(filesize);
	x = writedata(sizeof(int), &y);
	if (x == SCMOK && cryptflag)
		x = getcryptbuf(FILEXFER);
	if (x != SCMOK)
		return (x);

	while ((number = read(f, buf, FILEXFER)) > 0) {
		if (cryptflag) {
			encode(buf, cryptbuf, number);
			x = writedata((size_t)number, cryptbuf);
		} else {
			x = writedata((size_t)number, buf);
		}
		if (x != SCMOK)
			return (x);
		sum += number;
	}
	if (number < 0)
		return (scmerr(errno, "Read error on file output message"));
	if (sum != filesize)
		return (scmerr(-1, "File size error on output message"));
	return (SCMOK);
}
/*
 * Write a complete message of type `msg' that carries no data: message
 * start immediately followed by message end.  Returns the status of the
 * first step that fails, or of writemend().
 */
int
writemnull(int msg)
{				/* write message with no data */
	int status = writemsg(msg);

	if (status != SCMOK)
		return (status);
	return (writemend());
}
/*
 * Write a complete message of type `msg' whose only data block is the
 * integer `i'.  Stops at, and returns the status of, the first step that
 * does not return SCMOK.
 */
int
writemint(int msg, int i)
{				/* write message of one int */
	int status;

	status = writemsg(msg);
	if (status != SCMOK)
		return (status);
	status = writeint(i);
	if (status != SCMOK)
		return (status);
	return (writemend());
}
/*
 * Write a complete message of type `msg' whose only data block is the
 * string `p' (which may be NULL, see writestring()).  Stops at, and
 * returns the status of, the first step that does not return SCMOK.
 */
int
writemstr(int msg, char *p)
{				/* write message of one string */
	int status;

	status = writemsg(msg);
	if (status != SCMOK)
		return (status);
	status = writestring(p);
	if (status != SCMOK)
		return (status);
	return (writemend());
}
/*************************************************
*** I N P U T F R O M N E T W O R K ***
*************************************************/
/*
 * Read raw data from the network through a static FILEXFER-byte input
 * buffer.
 *
 * count - number of bytes wanted (or to push back)
 * vdata - destination for the bytes (source when pushing)
 * push  - true to put `count' bytes from `vdata' back in front of the
 *         unread buffered data instead of reading
 *
 * Modes:
 *   push == true                 give back bytes that were just read; this
 *                                fails if fewer than `count' bytes of room
 *                                remain in front of the unread data.
 *   count == 0 && vdata == NULL  discard all buffered input.
 *   otherwise                    copy `count' bytes to `vdata', reading
 *                                from the network (waiting up to two hours
 *                                per poll) when the buffer runs dry.
 *
 * Requests larger than the buffer are served in FILEXFER-sized pieces.
 * Returns SCMOK on success or the scmerr() result on failure.
 */
static int
readdata(size_t count, void *vdata, bool push)
{				/* read raw data from network */
	char *p;
	int c, n, m, x;
	static int bufcnt = 0;
	static char *bufptr;
	static char buffer[FILEXFER];
	struct pollfd set[1];
	char *data = vdata;

	if (push) {
		/* there must be `count' consumed bytes in front of bufptr */
		if (bufptr == NULL || (size_t)(bufptr - buffer) < count)
			return (scmerr(-1, "No space in buffer %zu", count));
		bufptr -= count;
		bufcnt += count;
		memcpy(bufptr, data, count);
		return (SCMOK);
	}
	if (count == 0 && data == NULL) {
		bufcnt = 0;
		return (SCMOK);
	}
	/*
	 * The refill loop below can hold at most FILEXFER bytes, so split
	 * bigger requests into buffer-sized reads.
	 */
	while (count > FILEXFER) {
		x = readdata(FILEXFER, data, false);
		if (x != SCMOK)
			return (x);
		data += FILEXFER;
		count -= FILEXFER;
	}
	if (count <= (size_t)bufcnt) {
		memcpy(data, bufptr, count);
		bufptr += count;
		bufcnt -= count;
		return (SCMOK);
	}
	/* hand over what is buffered, then refill from the network */
	if (bufcnt > 0) {
		memcpy(data, bufptr, (size_t)bufcnt);
		data += bufcnt;
		count -= bufcnt;
	}
	bufptr = buffer;
	bufcnt = 0;
	set[0].fd = netfile;
	set[0].events = POLLIN;
	p = buffer;
	n = FILEXFER;
	m = count;
	while (m > 0) {
		while ((c = poll(set, 1, 2 * 60 * 60 * 1000)) < 1) {
			if (c == 0)
				return (scmerr(-1, "Timeout on network input"));
			if (errno != EINTR)
				sleep(5);
		}
		/* read as much as fits; any surplus stays buffered */
		x = read(netfile, p, (size_t)n);
		if (x == 0)
			return (scmerr(-1, "Premature EOF on network input"));
		if (x < 0)
			return (scmerr(errno, "Read error on network"));
		p += x;
		n -= x;
		m -= x;
		bufcnt += x;
	}
	memcpy(data, bufptr, (size_t)count);
	bufptr += count;
	bufcnt -= count;
	return (SCMOK);
}
/*
 * Read the count that precedes a data block and store it, byte-swapped,
 * in `*count'.  `*count' is left untouched when the read fails.
 * Returns the readdata() status.
 */
static int
readcount(int *count)
{				/* read count of data block */
	int raw;
	int status = readdata(sizeof(int), &raw, false);

	if (status == SCMOK)
		*count = byteswap(raw);
	return (status);
}
/*
 * Peek at the count of the next data block: read it, then hand the same
 * bytes back to readdata() in push mode so a later read sees them again.
 * `*count' receives the byte-swapped value only if both steps succeed.
 * Returns SCMOK or the status of the failing readdata() call.
 */
int
prereadcount(int *count)
{				/* preread count of data block */
	int raw;
	int status;

	status = readdata(sizeof(int), &raw, false);
	if (status == SCMOK)
		status = readdata(sizeof(int), &raw, true);
	if (status != SCMOK)
		return (status);
	*count = byteswap(raw);
	return (SCMOK);
}
/*
 * Flush any unread input by calling readdata() with a zero count and a
 * NULL destination.  Returns the readdata() status.
 */
int
readflush(void)
{				/* flush unread input data */
	int status = readdata(0, NULL, false);

	return (status);
}
/*
 * Read the header of the next message and check that its type is `msg'.
 *
 * Returns SCMOK when the expected type arrives and the readdata() status
 * when the header cannot be read.  Any other type except MSGGOAWAY is
 * reported through scmerr().  For MSGGOAWAY the reason string is read into
 * goawayreason, the message end is consumed, the reason (if non-NULL) is
 * logged, and SCMEOF is returned.
 */
int
readmsg(int msg)
{				/* read header for expected message */
	int status;
	int type;

	if (scmdebug > 1)
		loginfo("SCM Reading message %d", msg);

	status = readdata(sizeof(int), &type, false);	/* msg type */
	if (status != SCMOK)
		return (status);
	type = byteswap(type);
	if (type == msg)
		return (status);

	/* check for MSGGOAWAY in case he noticed problems first */
	if (type != MSGGOAWAY)
		return (scmerr(-1, "Received unexpected message %d", type));

	/* netcrypt(NULL) comes first so the reason is read unencrypted */
	(void) netcrypt(NULL);
	(void) readstring(&goawayreason);
	(void) readmend();
	if (goawayreason != NULL)
		logerr("SCM GOAWAY for %s %s", remotehost(), goawayreason);
	return (SCMEOF);
}
/*
 * Read the end-of-message marker and verify that it is ENDCOUNT.
 *
 * The value is only examined after a successful read, so no indeterminate
 * data is ever byte-swapped.  Returns SCMOK on success, the readdata()
 * status if the read fails, or the scmerr() result if some other count
 * was found where the end marker was expected.
 */
int
readmend(void)
{				/* read end of message */
	int x;
	int y;

	x = readdata(sizeof(int), &y, false);
	if (x != SCMOK)
		return (x);
	if (byteswap(y) != ENDCOUNT)
		return (scmerr(-1, "Error reading end of message"));
	return (SCMOK);
}
/*
 * Skip over one input data block: read its count, then read and throw
 * away that many bytes in pieces of at most FILEXFER.  A negative count
 * is rejected through scmerr().  Returns SCMOK or the status of the
 * read that failed.
 */
int
readskip(void)
{				/* skip over one input block */
	char scratch[FILEXFER];
	int remaining;
	int status;

	status = readcount(&remaining);
	if (status != SCMOK)
		return (status);
	if (remaining < 0)
		return (scmerr(-1, "Invalid message count %d", remaining));

	while (remaining > 0) {
		int chunk = XFERSIZE(remaining);

		status = readdata((size_t)chunk, scratch, false);
		if (status != SCMOK)
			break;
		remaining -= chunk;
	}
	return (status);
}
/*
 * Read an integer data block into `*buf'.
 *
 * The block's count must be exactly sizeof(int); a negative or different
 * count is rejected through scmerr() before any payload is read.  Once
 * the count is accepted, `*buf' is assigned (and logged at debug level 3)
 * whether or not the payload read succeeded, so it is meaningful only
 * when SCMOK is returned.  Returns the status of the last read.
 */
int
readint(int *buf)
{				/* read int data block */
	int status;
	int word;

	status = readcount(&word);
	if (status != SCMOK)
		return (status);
	if (word < 0)
		return (scmerr(-1, "Invalid message count %d", word));
	if (word != sizeof(int))
		return (scmerr(-1, "Size error for int message is %d", word));

	status = readdata(sizeof(int), &word, false);
	*buf = byteswap(word);
	if (scmdebug > 2)
		loginfo("SCM Reading integer %d", *buf);
	return (status);
}
/*
 * Read a string data block into newly allocated memory.
 *
 * buf - receives a malloc'ed, NUL-terminated copy of the string, or NULL
 *       when the sender transmitted NULLCOUNT.  It is written only on
 *       success; the pointer is the caller's to free.
 *
 * When cryptflag is set the bytes are read into cryptbuf and passed
 * through decode() into the result.  cryptbuf holds raw bytes with no
 * terminator, so the debug trace prints at most `count' of them.
 *
 * Returns SCMOK on success or an error status (negative count, failed
 * allocation, or a failed read).
 */
int
readstring(char **buf)
{				/* read string data block */
	int x;
	int count;
	char *p;

	x = readcount(&count);
	if (x != SCMOK)
		return (x);
	if (count == NULLCOUNT) {
		if (scmdebug > 2)
			loginfo("SCM Reading string NULL");
		*buf = NULL;
		return (SCMOK);
	}
	if (count < 0)
		return (scmerr(-1, "Invalid message count %d", count));
	if (scmdebug > 3)
		loginfo("SCM Reading string count %d", count);
	if ((p = malloc((size_t)count + 1)) == NULL)
		return (scmerr(-1, "Can't malloc %d bytes for string", count));
	if (cryptflag) {
		x = getcryptbuf(count + 1);
		if (x == SCMOK)
			x = readdata((size_t)count, cryptbuf, false);
		if (x != SCMOK) {
			free(p);
			return (x);
		}
		if (scmdebug > 3)
			printf("SCM Reading encrypted string %.*s\n",
			    count, cryptbuf);
		decode(cryptbuf, p, count);
	} else {
		x = readdata((size_t)count, p, false);
		if (x != SCMOK) {
			free(p);
			return (x);
		}
	}
	p[count] = 0;		/* NULL at end of string */
	*buf = p;
	if (scmdebug > 2)
		loginfo("SCM Reading string %s", *buf);
	return (SCMOK);
}
/*
 * Read one data block from the network and write it to the open file
 * descriptor `f'.
 *
 * The block is transferred in pieces of at most FILEXFER bytes; with
 * cryptflag set each piece is read into cryptbuf and passed through
 * decode() first.  Every piece is written completely: a short write()
 * is continued from where it stopped instead of silently dropping the
 * rest of the piece.
 *
 * Returns SCMOK on success, the status of a failed read, the scmerr()
 * result for a negative count, or SCMERR when writing to `f' fails.
 */
int
readfile(int f)
{				/* read data block into open file */
	int x;
	int count;
	char buf[FILEXFER];

	if (cryptflag) {
		x = getcryptbuf(FILEXFER);
		if (x != SCMOK)
			return (x);
	}
	x = readcount(&count);
	if (x != SCMOK)
		return (x);
	if (count < 0)
		return (scmerr(-1, "Invalid message count %d", count));

	while (count > 0) {
		int chunk = XFERSIZE(count);
		int done = 0;

		if (cryptflag) {
			x = readdata((size_t)chunk, cryptbuf, false);
			if (x == SCMOK)
				decode(cryptbuf, buf, chunk);
		} else {
			x = readdata((size_t)chunk, buf, false);
		}
		if (x != SCMOK)
			return (x);

		while (done < chunk) {
			ssize_t w = write(f, buf + done, (size_t)(chunk - done));

			/* zero progress is treated as failure to avoid spinning */
			if (w <= 0)
				return SCMERR;
			done += (int)w;
		}
		count -= chunk;
	}
	return (x);
}
/*
 * Read a complete message of type `msg' that carries no data: the header
 * followed directly by the end marker.  Returns the status of the first
 * step that fails, or of readmend().
 */
int
readmnull(int msg)
{				/* read null message */
	int status = readmsg(msg);

	if (status != SCMOK)
		return (status);
	return (readmend());
}
/*
 * Read a complete message of type `msg' whose only data block is an
 * integer, stored in `*buf'.  Stops at, and returns the status of, the
 * first step that does not return SCMOK.
 */
int
readmint(int msg, int *buf)
{				/* read int message */
	int status;

	status = readmsg(msg);
	if (status != SCMOK)
		return (status);
	status = readint(buf);
	if (status != SCMOK)
		return (status);
	return (readmend());
}
/*
 * Read a complete message of type `msg' whose only data block is a
 * string, returned through `*buf' (see readstring()).  Stops at, and
 * returns the status of, the first step that does not return SCMOK.
 */
int
readmstr(int msg, char **buf)
{				/* read string message */
	int status;

	status = readmsg(msg);
	if (status != SCMOK)
		return (status);
	status = readstring(buf);
	if (status != SCMOK)
		return (status);
	return (readmend());
}
/**********************************
*** C R O S S P A T C H ***
**********************************/
/*
 * Relay data between the terminal and the network until either side
 * closes or fails: bytes arriving on netfile are copied to standard
 * output and bytes arriving on standard input are copied to netfile.
 *
 * poll() reports POLLHUP, POLLERR and POLLNVAL whether or not they were
 * requested, and may do so without POLLIN.  Those conditions are treated
 * like readability so that the following read() observes the EOF or error
 * and ends the loop; otherwise poll() would keep returning at once with
 * nothing being read.
 */
void
crosspatch(void)
{
	const int wake = POLLIN | POLLERR | POLLHUP | POLLNVAL;
	struct pollfd set[2];
	int c;
	char buf[STRINGLENGTH];

	set[0].fd = STDIN_FILENO;
	set[0].events = POLLIN;
	set[1].fd = netfile;
	set[1].events = POLLIN;
	for (;;) {
		if ((c = poll(set, 2, INFTIM)) < 1) {
			/* interrupted: retry at once; other failures back off */
			if (c == -1 && errno == EINTR)
				continue;
			sleep(5);
			continue;
		}
		if (set[1].revents & wake) {
			c = read(netfile, buf, sizeof(buf));
			if (c < 0 && errno == EWOULDBLOCK)
				c = 0;
			else {
				if (c <= 0)
					break;
				if (write(STDOUT_FILENO, buf, (size_t)c) == -1)
					break;
			}
		}
		if (set[0].revents & wake) {
			c = read(STDIN_FILENO, buf, sizeof(buf));
			if (c < 0 && errno == EWOULDBLOCK)
				c = 0;
			else {
				if (c <= 0)
					break;
				if (write(netfile, buf, (size_t)c) == -1)
					break;
			}
		}
	}
}