tcp: Replace custom WaitList with ConditionVariable.

The WaitList implementation had a race condition between checking for
the condition and acquiring the semaphore. If a thread was rescheduled
at that point, the signal could be missed, because Signal() calls
release_sem_etc() with the B_RELEASE_ALL flag and the thread was not
yet waiting for the semaphore. The transfer would subsequently stall.
This commit is contained in:
Michael Lotz 2015-08-02 15:10:06 +02:00
parent 2fdea65c3a
commit da8fbe0e59
2 changed files with 34 additions and 91 deletions

View File

@ -364,68 +364,10 @@ state_needs_finish(int32 state)
// #pragma mark -
/*!	Sets up the wait list: clears the condition flag and creates the
	semaphore, with an initial count of 0, that Wait() blocks on.
	\param name Name given to the semaphore.
	The result of create_sem() is stored unchecked; InitCheck() returns it.
*/
WaitList::WaitList(const char* name)
{
fCondition = 0;
fSem = create_sem(0, name);
}
//!	Deletes the semaphore that the constructor created.
WaitList::~WaitList()
{
delete_sem(fSem);
}
/*!	Returns the stored semaphore ID as the initialization status.
	Callers in this file treat a value below B_OK as failure; presumably
	create_sem() returns a negative error code when it fails -- confirm
	against the kernel API.
*/
status_t
WaitList::InitCheck() const
{
return fSem;
}
/*!	Releases the lock held by \a locker, waits until the condition flag can
	be consumed or the semaphore wait fails, then re-takes the lock.
	\param locker Lock that is released for the duration of the wait.
	\param timeout Handed to acquire_sem_etc() together with
		B_ABSOLUTE_TIMEOUT.
	\return B_OK when the loop ended because the flag test succeeded,
		otherwise the non-B_OK status returned by acquire_sem_etc().
*/
status_t
WaitList::Wait(MutexLocker& locker, bigtime_t timeout)
{
locker.Unlock();
status_t status = B_OK;
// Presumably atomic_test_and_set() resets the flag to 0 when it is 1 and
// returns the previous value, so the loop runs while the flag was not set
// -- confirm against the atomic API.
// NOTE: the flag test and the acquire_sem_etc() call below are separate
// steps. As the commit message above describes, a Signal() that runs
// between them releases the semaphore before this thread waits on it, so
// the wake-up is missed.
while (!atomic_test_and_set(&fCondition, 0, 1)) {
status = acquire_sem_etc(fSem, 1, B_ABSOLUTE_TIMEOUT | B_CAN_INTERRUPT,
timeout);
if (status != B_OK)
break;
}
locker.Lock();
return status;
}
/*!	Sets the condition flag and releases the semaphore for the threads
	waiting in Wait().
*/
void
WaitList::Signal()
{
atomic_or(&fCondition, 1);
#ifdef __HAIKU__
// Haiku build: a single call with B_RELEASE_ALL; going by the flag names
// it wakes every current waiter without forcing a reschedule.
release_sem_etc(fSem, 1, B_DO_NOT_RESCHEDULE | B_RELEASE_ALL);
#else
// Other builds: a negative semaphore count is taken as the number of
// waiting threads, and exactly that many units are released.
int32 count;
if (get_sem_count(fSem, &count) == B_OK && count < 0)
release_sem_etc(fSem, -count, B_DO_NOT_RESCHEDULE);
#endif
}
// #pragma mark -
TCPEndpoint::TCPEndpoint(net_socket* socket)
:
ProtocolSocket(socket),
fManager(NULL),
fReceiveList("tcp receive"),
fSendList("tcp send"),
fOptions(0),
fSendWindowShift(0),
fReceiveWindowShift(0),
@ -457,6 +399,9 @@ TCPEndpoint::TCPEndpoint(net_socket* socket)
// TODO: to be replaced with a real read/write locking strategy!
mutex_init(&fLock, "tcp lock");
fReceiveCondition.Init(this, "tcp receive");
fSendCondition.Init(this, "tcp send");
gStackModule->init_timer(&fPersistTimer, TCPEndpoint::_PersistTimer, this);
gStackModule->init_timer(&fRetransmitTimer, TCPEndpoint::_RetransmitTimer,
this);
@ -494,12 +439,6 @@ TCPEndpoint::~TCPEndpoint()
/*!	Reports the initialization status of the endpoint.
	\return The error from the receive wait list if it failed to initialize,
		else the error from the send wait list, else B_OK.
*/
status_t
TCPEndpoint::InitCheck() const
{
// The receive list is checked first, so its error wins when both failed.
if (fReceiveList.InitCheck() < B_OK)
return fReceiveList.InitCheck();
if (fSendList.InitCheck() < B_OK)
return fSendList.InitCheck();
return B_OK;
}
@ -551,7 +490,7 @@ TCPEndpoint::Close()
bigtime_t maximum = absolute_timeout(socket->linger * 1000000LL);
while (fSendQueue.Used() > 0) {
status = fSendList.Wait(locker, maximum);
status = _WaitForCondition(fSendCondition, locker, maximum);
if (status == B_TIMED_OUT || status == B_WOULD_BLOCK)
break;
else if (status < B_OK)
@ -817,7 +756,7 @@ TCPEndpoint::SendData(net_buffer *buffer)
while (left > 0) {
while (fSendQueue.Free() < socket->send.low_water_mark) {
// wait until enough space is available
status_t status = fSendList.Wait(lock, timeout);
status_t status = _WaitForCondition(fSendCondition, lock, timeout);
if (status < B_OK) {
TRACE(" SendData() returning %s (%d)",
strerror(posix_error(status)), (int)posix_error(status));
@ -969,7 +908,7 @@ TCPEndpoint::ReadData(size_t numBytes, uint32 flags, net_buffer** _buffer)
if ((fFlags & FLAG_NO_RECEIVE) != 0)
return B_OK;
status_t status = fReceiveList.Wait(locker, timeout);
status_t status = _WaitForCondition(fReceiveCondition, locker, timeout);
if (status < B_OK) {
// The Open Group base specification mentions that EINTR should be
// returned if the recv() is interrupted before _any data_ is
@ -987,7 +926,7 @@ TCPEndpoint::ReadData(size_t numBytes, uint32 flags, net_buffer** _buffer)
fReceiveQueue.Available());
if (numBytes < fReceiveQueue.Available())
fReceiveList.Signal();
fReceiveCondition.NotifyAll();
bool clone = (flags & MSG_PEEK) != 0;
@ -1198,7 +1137,7 @@ TCPEndpoint::_MarkEstablished()
release_sem_etc(fAcceptSemaphore, 1, B_DO_NOT_RESCHEDULE);
}
fSendList.Signal();
fSendCondition.NotifyAll();
gSocketModule->notify(socket, B_SELECT_WRITE, fSendQueue.Free());
}
@ -1212,7 +1151,7 @@ TCPEndpoint::_WaitForEstablished(MutexLocker &locker, bigtime_t timeout)
if (socket->error != B_OK)
return socket->error;
status_t status = fSendList.Wait(locker, timeout);
status_t status = _WaitForCondition(fSendCondition, locker, timeout);
if (status < B_OK)
return status;
}
@ -1233,7 +1172,7 @@ TCPEndpoint::_Close()
fFlags |= FLAG_DELETE_ON_CLOSE;
fSendList.Signal();
fSendCondition.NotifyAll();
_NotifyReader();
if (gSocketModule->has_parent(socket)) {
@ -1311,7 +1250,7 @@ TCPEndpoint::_AvailableData() const
/*!	Wakes the threads waiting for received data and sends a B_SELECT_READ
	notification for this socket, passing the value of _AvailableData().
*/
void
TCPEndpoint::_NotifyReader()
{
// NOTE(review): this is a diff view with the +/- markers stripped. The
// next two lines look like the old (WaitList) and the new
// (ConditionVariable) wake-up call; presumably only one of them exists in
// any single revision -- confirm against the actual file.
fReceiveList.Signal();
fReceiveCondition.NotifyAll();
gSocketModule->notify(socket, B_SELECT_READ, _AvailableData());
}
@ -2203,7 +2142,7 @@ TCPEndpoint::_Acknowledged(tcp_segment_header& segment)
if (is_writable(fState)) {
// notify threads waiting on the socket to become writable again
fSendList.Signal();
fSendCondition.NotifyAll();
gSocketModule->notify(socket, B_SELECT_WRITE, fSendQueue.Free());
}
@ -2342,6 +2281,21 @@ TCPEndpoint::_TimeWaitTimer(net_timer* timer, void* _endpoint)
}
/*!	Waits on \a condition while the lock held by \a locker is released, and
	re-takes the lock before returning.
	\param condition Condition variable to wait on.
	\param locker Lock that is dropped for the duration of the wait.
	\param timeout Handed to the wait together with B_ABSOLUTE_TIMEOUT.
	\return The status returned by ConditionVariableEntry::Wait().
*/
/*static*/ status_t
TCPEndpoint::_WaitForCondition(ConditionVariable& condition,
MutexLocker& locker, bigtime_t timeout)
{
ConditionVariableEntry entry;
// The entry is added while the lock is still held and the lock is dropped
// only afterwards. Presumably this ordering is what closes the
// missed-wake-up window of the old WaitList; it relies on notifiers
// running with the same lock held -- TODO confirm for all callers.
condition.Add(&entry);
locker.Unlock();
status_t result = entry.Wait(B_ABSOLUTE_TIMEOUT | B_CAN_INTERRUPT, timeout);
locker.Lock();
return result;
}
// #pragma mark -

View File

@ -25,22 +25,6 @@
#include <stddef.h>
/*!	Wake-up helper built from an atomic flag and a semaphore.
	Wait() blocks a thread until Signal() is called or the wait fails.
	This is the class the commit above replaces with ConditionVariable.
*/
class WaitList {
public:
// Creates the underlying semaphore with the given name.
WaitList(const char* name);
// Deletes the underlying semaphore.
~WaitList();
// Returns the semaphore ID as the initialization status.
status_t InitCheck() const;
// Releases the locker, waits for a signal or failure, then re-locks.
status_t Wait(MutexLocker& locker, bigtime_t timeout = B_INFINITE_TIMEOUT);
// Sets the flag and releases the semaphore for waiting threads.
void Signal();
private:
// Flag set by Signal() and tested by Wait().
int32 fCondition;
// Semaphore that Wait() blocks on.
sem_id fSem;
};
class TCPEndpoint : public net_protocol, public ProtocolSocket {
public:
TCPEndpoint(net_socket* socket);
@ -132,6 +116,9 @@ private:
static void _DelayedAcknowledgeTimer(net_timer* timer,
void* _endpoint);
static status_t _WaitForCondition(ConditionVariable& condition,
MutexLocker& locker, bigtime_t timeout);
private:
TCPEndpoint* fConnectionHashLink;
TCPEndpoint* fEndpointHashLink;
@ -141,8 +128,10 @@ private:
mutex fLock;
EndpointManager* fManager;
WaitList fReceiveList;
WaitList fSendList;
ConditionVariable
fReceiveCondition;
ConditionVariable
fSendCondition;
sem_id fAcceptSemaphore;
uint8 fOptions;