assorted TCP fixes.

- fixed the locking for spawned connections and accept()s.
 - return EMSGSIZE if the user is trying to write more data than the send buffer can hold.
 - fixed a crash when receiving a RST while the connection is being closed.
 - don't wake up readers when the connection gets established.
 - endpoint managers lock must be recursive to properly work with spawn'ed sockets.


git-svn-id: file:///srv/svn/repos/haiku/haiku/trunk@20727 a95241bf-73f2-0310-859d-f6bbb57e9c96
This commit is contained in:
Hugo Santos 2007-04-16 18:32:49 +00:00
parent f56b0aa00d
commit 4ee088419f
5 changed files with 168 additions and 142 deletions

View File

@ -45,6 +45,7 @@ class BufferQueue {
size_t Used() const { return fNumBytes; }
size_t Free() const { return fMaxBytes - fNumBytes; }
size_t Size() const { return fMaxBytes; }
bool IsContiguous() const { return fNumBytes == fContiguousBytes; }

View File

@ -256,6 +256,7 @@ status_t
TCPEndpoint::Close()
{
TRACE("Close()");
RecursiveLocker lock(fLock);
if (fState == LISTEN)
@ -270,8 +271,6 @@ TCPEndpoint::Close()
if (status != B_OK)
return status;
TRACE("Close() after Shutdown()");
if (socket->options & SO_LINGER) {
TRACE("Close(): Lingering for %i secs", socket->linger);
@ -299,6 +298,8 @@ TCPEndpoint::Free()
{
TRACE("Free()");
RecursiveLocker _(fLock);
if (fState <= SYNCHRONIZE_SENT || fState == TIME_WAIT)
return B_OK;
@ -319,7 +320,7 @@ TCPEndpoint::Connect(const struct sockaddr *address)
TRACE("Connect() on address %s",
AddressString(Domain(), address, true).Data());
RecursiveLocker locker(&fLock);
RecursiveLocker locker(fLock);
// Can only call connect() from CLOSED or LISTEN states
// otherwise endpoint is considered already connected
@ -391,11 +392,8 @@ TCPEndpoint::Connect(const struct sockaddr *address)
return EINPROGRESS;
}
while (status == B_OK && fState != ESTABLISHED)
status = fSendList.Wait(locker, absolute_timeout(timeout));
status = _WaitForEstablished(locker, absolute_timeout(timeout));
TRACE(" Connect(): Connection complete: %s", strerror(status));
return posix_error(status);
}
@ -405,15 +403,20 @@ TCPEndpoint::Accept(struct net_socket **_acceptedSocket)
{
TRACE("Accept()");
RecursiveLocker locker(fLock);
status_t status;
bigtime_t timeout = absolute_timeout(socket->receive.timeout);
do {
status = acquire_sem_etc(fAcceptSemaphore, 1, B_ABSOLUTE_TIMEOUT |
B_CAN_INTERRUPT, timeout);
locker.Unlock();
status = acquire_sem_etc(fAcceptSemaphore, 1, B_ABSOLUTE_TIMEOUT
| B_CAN_INTERRUPT, timeout);
if (status < B_OK)
return status;
locker.Lock();
status = gSocketModule->dequeue_connected(socket, _acceptedSocket);
if (status == B_OK)
TRACE(" Accept() returning %p", (*_acceptedSocket)->first_protocol);
@ -429,11 +432,11 @@ TCPEndpoint::Bind(sockaddr *address)
if (address == NULL)
return B_BAD_VALUE;
RecursiveLocker lock(fLock);
TRACE("Bind() on address %s",
AddressString(Domain(), address, true).Data());
RecursiveLocker lock(fLock);
if (fState != CLOSED)
return EISCONN;
@ -478,6 +481,9 @@ TCPEndpoint::Listen(int count)
return EDESTADDRREQ;
fAcceptSemaphore = create_sem(0, "tcp accept");
if (fAcceptSemaphore < B_OK)
return ENOBUFS;
fState = LISTEN;
return B_OK;
}
@ -524,6 +530,9 @@ TCPEndpoint::SendData(net_buffer *buffer)
}
if (buffer->size > 0) {
if (buffer->size > fSendQueue.Size())
return EMSGSIZE;
bigtime_t timeout = absolute_timeout(socket->send.timeout);
while (fSendQueue.Free() < buffer->size) {
@ -579,12 +588,9 @@ TCPEndpoint::ReadData(size_t numBytes, uint32 flags, net_buffer** _buffer)
if (flags & MSG_DONTWAIT)
return B_WOULD_BLOCK;
while (fState != ESTABLISHED) {
// we need to wait until the connection becomes established
status_t status = fSendList.Wait(locker, timeout);
if (status < B_OK)
return posix_error(status);
}
status_t status = _WaitForEstablished(locker, timeout);
if (status < B_OK)
return posix_error(status);
}
size_t dataNeeded = socket->receive.low_water_mark;
@ -727,10 +733,9 @@ TCPEndpoint::UpdateTimeWait()
int32
TCPEndpoint::ListenReceive(tcp_segment_header &segment, net_buffer *buffer)
TCPEndpoint::_ListenReceive(tcp_segment_header &segment, net_buffer *buffer)
{
TRACE("ListenReceive(): packet %p (%lu bytes) with flags 0x%x, seq %lu, ack %lu!",
buffer, buffer->size, segment.flags, segment.sequence, segment.acknowledge);
TRACE("ListenReceive()");
// Essentially, we accept only TCP_FLAG_SYNCHRONIZE in this state,
// but the error behaviour differs
@ -753,60 +758,71 @@ TCPEndpoint::ListenReceive(tcp_segment_header &segment, net_buffer *buffer)
AddressModule()->set_to((sockaddr *)&newSocket->peer,
(sockaddr *)&buffer->source);
TCPEndpoint *endpoint = (TCPEndpoint *)newSocket->first_protocol;
return ((TCPEndpoint *)newSocket->first_protocol)->Spawn(this, segment, buffer);
}
endpoint->fSpawned = true;
int32
TCPEndpoint::Spawn(TCPEndpoint *parent, tcp_segment_header &segment,
net_buffer *buffer)
{
RecursiveLocker _(fLock);
fState = SYNCHRONIZE_RECEIVED;
fManager = parent->fManager;
TRACE("Spawn()");
fSpawned = true;
sockaddr *local = (sockaddr *)&socket->address;
sockaddr *peer = (sockaddr *)&socket->peer;
// TODO: proper error handling!
endpoint->fRoute = gDatalinkModule->get_route(Domain(),
(sockaddr *)&newSocket->peer);
if (endpoint->fRoute == NULL)
fRoute = gDatalinkModule->get_route(Domain(), peer);
if (fRoute == NULL)
return DROP;
if (fManager->SetConnection(endpoint, (sockaddr *)&buffer->destination,
(sockaddr *)&buffer->source, NULL) < B_OK)
if (fManager->SetConnection(this, local, peer, NULL) < B_OK)
return DROP;
endpoint->fInitialReceiveSequence = segment.sequence;
endpoint->fReceiveQueue.SetInitialSequence(segment.sequence + 1);
endpoint->fState = SYNCHRONIZE_RECEIVED;
endpoint->fAcceptSemaphore = fAcceptSemaphore;
endpoint->fReceiveMaxSegmentSize = _GetMSS((sockaddr *)&newSocket->peer);
fInitialReceiveSequence = segment.sequence;
fReceiveQueue.SetInitialSequence(segment.sequence + 1);
fAcceptSemaphore = parent->fAcceptSemaphore;
fReceiveMaxSegmentSize = _GetMSS(peer);
// 40 bytes for IP and TCP header without any options
// TODO: make this depending on the RTF_LOCAL flag?
endpoint->fReceiveNext = segment.sequence + 1;
fReceiveNext = segment.sequence + 1;
// account for the extra sequence number for the synchronization
endpoint->fInitialSendSequence = system_time() >> 4;
endpoint->fSendNext = endpoint->fInitialSendSequence;
endpoint->fSendUnacknowledged = endpoint->fSendNext;
endpoint->fSendMax = endpoint->fSendNext;
fInitialSendSequence = system_time() >> 4;
fSendNext = fInitialSendSequence;
fSendUnacknowledged = fSendNext;
fSendMax = fSendNext;
// set options
if ((fOptions & TCP_NOOPT) == 0) {
if ((parent->fOptions & TCP_NOOPT) == 0) {
if (segment.max_segment_size > 0)
endpoint->fSendMaxSegmentSize = segment.max_segment_size;
fSendMaxSegmentSize = segment.max_segment_size;
else
endpoint->fReceiveMaxSegmentSize = TCP_DEFAULT_MAX_SEGMENT_SIZE;
fReceiveMaxSegmentSize = TCP_DEFAULT_MAX_SEGMENT_SIZE;
if (segment.has_window_shift) {
endpoint->fFlags |= FLAG_OPTION_WINDOW_SHIFT;
endpoint->fSendWindowShift = segment.window_shift;
fFlags |= FLAG_OPTION_WINDOW_SHIFT;
fSendWindowShift = segment.window_shift;
} else {
endpoint->fFlags &= ~FLAG_OPTION_WINDOW_SHIFT;
endpoint->fReceiveWindowShift = 0;
fFlags &= ~FLAG_OPTION_WINDOW_SHIFT;
fReceiveWindowShift = 0;
}
}
TRACE(" ListenReceive() created new endpoint %p", endpoint);
endpoint->_UpdateTimestamps(segment, 0, false);
_UpdateTimestamps(segment, 0, false);
// send SYN+ACK
status_t status = endpoint->_SendQueued();
status_t status = _SendQueued();
endpoint->fInitialSendSequence = endpoint->fSendNext;
endpoint->fSendQueue.SetInitialSequence(endpoint->fSendNext);
fInitialSendSequence = fSendNext;
fSendQueue.SetInitialSequence(fSendNext);
if (status < B_OK)
return DROP;
@ -814,16 +830,14 @@ TCPEndpoint::ListenReceive(tcp_segment_header &segment, net_buffer *buffer)
segment.flags &= ~TCP_FLAG_SYNCHRONIZE;
// we handled this flag now, it must not be set for further processing
return endpoint->_Receive(segment, buffer);
// TODO: here, the ack/delayed ack call will be made on the parent socket!
return _Receive(segment, buffer);
}
int32
TCPEndpoint::SynchronizeSentReceive(tcp_segment_header &segment, net_buffer *buffer)
TCPEndpoint::_SynchronizeSentReceive(tcp_segment_header &segment, net_buffer *buffer)
{
TRACE("SynchronizeSentReceive(): packet %p (%lu bytes) with flags 0x%x, seq %lu, ack %lu!",
buffer, buffer->size, segment.flags, segment.sequence, segment.acknowledge);
TRACE("SynchronizeSentReceive()");
if ((segment.flags & TCP_FLAG_ACKNOWLEDGE) != 0
&& (fInitialSendSequence >= segment.acknowledge
@ -859,16 +873,7 @@ TCPEndpoint::SynchronizeSentReceive(tcp_segment_header &segment, net_buffer *buf
}
if (segment.flags & TCP_FLAG_ACKNOWLEDGE) {
// the connection has been established
fState = ESTABLISHED;
if (socket->parent != NULL) {
gSocketModule->set_connected(socket);
release_sem_etc(fAcceptSemaphore, 1, B_DO_NOT_RESCHEDULE);
}
fSendList.Signal();
_NotifyReader();
_MarkEstablished();
} else {
// simultaneous open
fState = SYNCHRONIZE_RECEIVED;
@ -884,13 +889,50 @@ TCPEndpoint::SynchronizeSentReceive(tcp_segment_header &segment, net_buffer *buf
int32
TCPEndpoint::Receive(tcp_segment_header &segment, net_buffer *buffer)
TCPEndpoint::SegmentReceived(tcp_segment_header &segment, net_buffer *buffer)
{
TRACE("Receive(): packet %p (%lu bytes) with flags 0x%x, seq %lu, ack %lu!",
buffer, buffer->size, segment.flags, segment.sequence, segment.acknowledge);
RecursiveLocker locker(fLock);
// TODO: rethink locking!
TRACE("SegmentReceived(): packet %p (%lu bytes) with flags 0x%x, seq %lu, "
"ack %lu", buffer, buffer->size, segment.flags, segment.sequence,
segment.acknowledge);
int32 segmentAction = DROP;
switch (fState) {
case LISTEN:
segmentAction = _ListenReceive(segment, buffer);
break;
case SYNCHRONIZE_SENT:
segmentAction = _SynchronizeSentReceive(segment, buffer);
break;
case SYNCHRONIZE_RECEIVED:
case ESTABLISHED:
case FINISH_RECEIVED:
case WAIT_FOR_FINISH_ACKNOWLEDGE:
case FINISH_SENT:
case FINISH_ACKNOWLEDGED:
case CLOSING:
case TIME_WAIT:
case CLOSED:
segmentAction = _SegmentReceived(segment, buffer);
break;
}
// process acknowledge action as asked for by the *Receive() method
if (segmentAction & IMMEDIATE_ACKNOWLEDGE)
SendAcknowledge();
else if (segmentAction & ACKNOWLEDGE)
DelayedAcknowledge();
return segmentAction;
}
int32
TCPEndpoint::_SegmentReceived(tcp_segment_header &segment, net_buffer *buffer)
{
uint32 advertisedWindow = (uint32)segment.advertised_window << fSendWindowShift;
// First, handle the most common case for uni-directional data transfer
@ -1287,23 +1329,15 @@ TCPEndpoint::_Receive(tcp_segment_header &segment, net_buffer *buffer)
if (fLastAcknowledgeSent <= segment.sequence
&& tcp_sequence(segment.sequence)
< (fLastAcknowledgeSent + fReceiveWindow)) {
if (fState == SYNCHRONIZE_RECEIVED) {
// TODO: if we came from SYN-SENT signal connection refused
// and remove all segments from tx queue
} else if (fState == ESTABLISHED || fState == FINISH_SENT
|| fState == FINISH_RECEIVED || fState == FINISH_ACKNOWLEDGED) {
// TODO: RFC 793 states that on ESTABLISHED, FIN-WAIT{1,2}
// or CLOSE-WAIT "All segment queues should be
// flushed".
}
if (fState == SYNCHRONIZE_RECEIVED)
fError = ECONNREFUSED;
else if (fState == CLOSING || fState == TIME_WAIT
|| fState == WAIT_FOR_FINISH_ACKNOWLEDGE)
fError = ENOTCONN;
else
fError = ECONNRESET;
if (fState != TIME_WAIT && fReceiveQueue.Available() > 0) {
_NotifyReader();
} else {
return DELETE | DROP;
}
fError = ECONNREFUSED;
_NotifyReader();
fState = CLOSED;
}
@ -1371,15 +1405,7 @@ TCPEndpoint::_Receive(tcp_segment_header &segment, net_buffer *buffer)
// process acknowledged data
if (fState == SYNCHRONIZE_RECEIVED) {
// TODO: window scaling!
if (socket->parent != NULL) {
gSocketModule->set_connected(socket);
release_sem_etc(fAcceptSemaphore, 1, B_DO_NOT_RESCHEDULE);
}
fState = ESTABLISHED;
fSendList.Signal();
_NotifyReader();
_MarkEstablished();
}
if (fSendMax < segment.acknowledge || fState == TIME_WAIT)
@ -1536,6 +1562,33 @@ TCPEndpoint::_UpdateTimestamps(tcp_segment_header &segment, size_t segmentLength
}
void
TCPEndpoint::_MarkEstablished()
{
fState = ESTABLISHED;
if (socket->parent != NULL) {
gSocketModule->set_connected(socket);
release_sem_etc(fAcceptSemaphore, 1, B_DO_NOT_RESCHEDULE);
}
fSendList.Signal();
}
status_t
TCPEndpoint::_WaitForEstablished(RecursiveLocker &locker, bigtime_t timeout)
{
while (fState != ESTABLISHED) {
status_t status = fSendList.Wait(locker, timeout);
if (status < B_OK)
return status;
}
return B_OK;
}
// #pragma mark - timer
@ -1544,7 +1597,7 @@ TCPEndpoint::_RetransmitTimer(net_timer *timer, void *data)
{
TCPEndpoint *endpoint = (TCPEndpoint *)data;
RecursiveLocker locker(endpoint->Lock());
RecursiveLocker locker(endpoint->fLock);
if (!locker.IsLocked())
return;
@ -1559,7 +1612,7 @@ TCPEndpoint::_PersistTimer(net_timer *timer, void *data)
{
TCPEndpoint *endpoint = (TCPEndpoint *)data;
RecursiveLocker locker(endpoint->Lock());
RecursiveLocker locker(endpoint->fLock);
if (!locker.IsLocked())
return;
@ -1572,7 +1625,7 @@ TCPEndpoint::_DelayedAcknowledgeTimer(struct net_timer *timer, void *data)
{
TCPEndpoint *endpoint = (TCPEndpoint *)data;
RecursiveLocker locker(endpoint->Lock());
RecursiveLocker locker(endpoint->fLock);
if (!locker.IsLocked())
return;
@ -1585,7 +1638,7 @@ TCPEndpoint::_TimeWaitTimer(struct net_timer *timer, void *data)
{
TCPEndpoint *endpoint = (TCPEndpoint *)data;
if (recursive_lock_lock(&endpoint->Lock()) < B_OK)
if (recursive_lock_lock(&endpoint->fLock) < B_OK)
return;
endpoint->DeleteSocket();

View File

@ -45,8 +45,6 @@ class TCPEndpoint : public net_protocol {
status_t InitCheck() const;
recursive_lock &Lock() { return fLock; }
status_t Open();
status_t Close();
status_t Free();
@ -69,10 +67,10 @@ class TCPEndpoint : public net_protocol {
status_t DelayedAcknowledge();
status_t SendAcknowledge();
status_t UpdateTimeWait();
int32 ListenReceive(tcp_segment_header& segment, net_buffer *buffer);
int32 SynchronizeSentReceive(tcp_segment_header& segment,
int32 SegmentReceived(tcp_segment_header& segment, net_buffer *buffer);
int32 Spawn(TCPEndpoint *parent, tcp_segment_header& segment,
net_buffer *buffer);
int32 Receive(tcp_segment_header& segment, net_buffer *buffer);
net_domain *Domain() const
{ return socket->first_protocol->module->get_domain(
@ -94,9 +92,15 @@ class TCPEndpoint : public net_protocol {
ssize_t _AvailableData() const;
void _NotifyReader();
bool _ShouldReceive() const;
int32 _ListenReceive(tcp_segment_header& segment, net_buffer *buffer);
int32 _SynchronizeSentReceive(tcp_segment_header& segment,
net_buffer *buffer);
int32 _SegmentReceived(tcp_segment_header& segment, net_buffer *buffer);
int32 _Receive(tcp_segment_header& segment, net_buffer *buffer);
void _UpdateTimestamps(tcp_segment_header& segment,
size_t segmentLength, bool checkSequence);
void _MarkEstablished();
status_t _WaitForEstablished(RecursiveLocker &lock, bigtime_t timeout);
static void _TimeWaitTimer(net_timer *timer, void *data);
static void _RetransmitTimer(net_timer *timer, void *data);

View File

@ -53,7 +53,7 @@ net_stack_module_info *gStackModule;
// protocol cookie, so we don't have to go through the list
// for each segment.
typedef DoublyLinkedList<EndpointManager> EndpointManagerList;
static benaphore sEndpointManagersLock;
static recursive_lock sEndpointManagersLock;
static EndpointManagerList sEndpointManagers;
@ -536,7 +536,7 @@ tcp_receive_data(net_buffer *buffer)
bufferHeader.Remove(headerLength);
// we no longer need to keep the header around
BenaphoreLocker _(sEndpointManagersLock);
RecursiveLocker _(sEndpointManagersLock);
EndpointManager *endpointManager = endpoint_manager_for(domain);
if (endpointManager == NULL)
@ -547,40 +547,9 @@ tcp_receive_data(net_buffer *buffer)
TCPEndpoint *endpoint = endpointManager->FindConnection(
(sockaddr *)&buffer->destination, (sockaddr *)&buffer->source);
if (endpoint != NULL) {
RecursiveLocker locker(endpoint->Lock());
TRACE(("Endpoint %p in state %s\n", endpoint, name_for_state(endpoint->State())));
switch (endpoint->State()) {
case LISTEN:
segmentAction = endpoint->ListenReceive(segment, buffer);
break;
case SYNCHRONIZE_SENT:
segmentAction = endpoint->SynchronizeSentReceive(segment, buffer);
break;
case SYNCHRONIZE_RECEIVED:
case ESTABLISHED:
case FINISH_RECEIVED:
case WAIT_FOR_FINISH_ACKNOWLEDGE:
case FINISH_SENT:
case FINISH_ACKNOWLEDGED:
case CLOSING:
case TIME_WAIT:
case CLOSED:
segmentAction = endpoint->Receive(segment, buffer);
break;
}
// process acknowledge action as asked for by the *Receive() method
if (segmentAction & IMMEDIATE_ACKNOWLEDGE)
endpoint->SendAcknowledge();
else if (segmentAction & ACKNOWLEDGE)
endpoint->DelayedAcknowledge();
else if (segmentAction & DELETE)
endpoint->DeleteSocket();
} else if ((segment.flags & TCP_FLAG_RESET) == 0)
if (endpoint != NULL)
segmentAction = endpoint->SegmentReceived(segment, buffer);
else if ((segment.flags & TCP_FLAG_RESET) == 0)
segmentAction = DROP | RESET;
if (segmentAction & RESET) {
@ -615,7 +584,7 @@ tcp_error_reply(net_protocol *protocol, net_buffer *causedError, uint32 code,
static status_t
tcp_init()
{
status_t status = benaphore_init(&sEndpointManagersLock,
status_t status = recursive_lock_init(&sEndpointManagersLock,
"endpoint managers lock");
if (status < B_OK)
@ -647,7 +616,7 @@ tcp_init()
static status_t
tcp_uninit()
{
benaphore_destroy(&sEndpointManagersLock);
recursive_lock_destroy(&sEndpointManagersLock);
return B_OK;
}

View File

@ -165,7 +165,6 @@ enum tcp_segment_action {
RESET = 0x02,
ACKNOWLEDGE = 0x04,
IMMEDIATE_ACKNOWLEDGE = 0x08,
DELETE = 0x10,
};