completed UDP's locking.

git-svn-id: file:///srv/svn/repos/haiku/haiku/trunk@20857 a95241bf-73f2-0310-859d-f6bbb57e9c96
This commit is contained in:
Hugo Santos 2007-04-27 11:30:03 +00:00
parent 5b1cb74b86
commit 727ad0b0a5

View File

@ -30,6 +30,13 @@
#include <string.h>
#include <utility>
// NOTE the locking protocol dictates that we must hold UdpDomainSupport's
// lock before holding a child UdpEndpoint's lock. This restriction
// is dictated by the receive path as blind access to the endpoint
// hash is required when holding the DomainSuppport's lock.
//#define TRACE_UDP
#ifdef TRACE_UDP
# define TRACE_BLOCK(x) dump_block x
@ -84,14 +91,14 @@ public:
status_t StoreData(net_buffer *buffer);
status_t DeliverData(net_buffer *buffer);
UdpDomainSupport *DomainSupport() const { return fManager; }
// only the domain support will change/check the Active flag so
// we don't really need to protect it with the socket lock.
bool IsActive() const { return fActive; }
void SetActive(bool newValue) { fActive = newValue; }
HashTableLink<UdpEndpoint> *HashTableLink() { return &fLink; }
private:
status_t _Activate();
status_t _Deactivate();
UdpDomainSupport *fManager;
bool fActive;
// an active UdpEndpoint is part of the endpoint
@ -146,39 +153,44 @@ struct UdpHashDefinition {
class UdpDomainSupport : public DoublyLinkedListLinkImpl<UdpDomainSupport> {
public:
UdpDomainSupport(net_domain *domain);
~UdpDomainSupport();
status_t InitCheck() const;
net_domain *Domain() const { return fDomain; }
void Ref() { fEndpointCount++; }
void Put() { fEndpointCount--; }
bool IsEmpty() const { return fEndpointCount == 0; }
bool Put() { fEndpointCount--; return fEndpointCount == 0; }
status_t DemuxIncomingBuffer(net_buffer *buffer);
status_t CheckBindRequest(sockaddr *address, int socketOptions);
status_t BindEndpoint(UdpEndpoint *endpoint, const sockaddr *address);
status_t ConnectEndpoint(UdpEndpoint *endpoint, const sockaddr *address);
status_t UnbindEndpoint(UdpEndpoint *endpoint);
status_t ActivateEndpoint(UdpEndpoint *endpoint);
status_t DeactivateEndpoint(UdpEndpoint *endpoint);
uint16 GetEphemeralPort();
private:
status_t _BindEndpoint(UdpEndpoint *endpoint, const sockaddr *address);
status_t _Bind(UdpEndpoint *endpoint, const sockaddr *address);
status_t _BindToEphemeral(UdpEndpoint *endpoint, const sockaddr *address);
status_t _FinishBind(UdpEndpoint *endpoint, const sockaddr *address);
UdpEndpoint *_FindActiveEndpoint(const sockaddr *ourAddress,
const sockaddr *peerAddress);
status_t _DemuxBroadcast(net_buffer *buffer);
status_t _DemuxMulticast(net_buffer *buffer);
status_t _DemuxUnicast(net_buffer *buffer);
uint16 _GetNextEphemeral();
UdpEndpoint *_EndpointWithPort(uint16 port) const;
net_address_module_info *AddressModule() const
{
return fDomain->address_module;
}
{ return fDomain->address_module; }
typedef OpenHashTable<UdpHashDefinition, false> EndpointTable;
benaphore fLock;
net_domain *fDomain;
uint16 fLastUsedEphemeral;
EndpointTable fActiveEndpoints;
@ -206,9 +218,6 @@ public:
UdpDomainSupport *OpenEndpoint(UdpEndpoint *endpoint);
status_t FreeEndpoint(UdpDomainSupport *domain);
uint16 GetEphemeralPort();
benaphore *Locker() { return &fLock; }
status_t InitCheck() const;
private:
@ -233,16 +242,27 @@ net_stack_module_info *gStackModule;
UdpDomainSupport::UdpDomainSupport(net_domain *domain)
:
fDomain(domain),
fLastUsedEphemeral(kLast),
fActiveEndpoints(domain->address_module, kNumHashBuckets),
fEndpointCount(0)
{
benaphore_init(&fLock, "udp domain");
fLastUsedEphemeral = kFirst + rand() % (kLast - kFirst);
}
UdpDomainSupport::~UdpDomainSupport()
{
benaphore_destroy(&fLock);
}
status_t
UdpDomainSupport::InitCheck() const
{
if (fLock.sem < B_OK)
return fLock.sem;
return fActiveEndpoints.InitCheck();
}
@ -250,26 +270,96 @@ UdpDomainSupport::InitCheck() const
status_t
UdpDomainSupport::DemuxIncomingBuffer(net_buffer *buffer)
{
// NOTE multicast is delivered directly to the endpoint
BenaphoreLocker _(fLock);
if (buffer->flags & MSG_BCAST)
return _DemuxBroadcast(buffer);
else if (buffer->flags & MSG_MCAST)
return _DemuxMulticast(buffer);
return B_ERROR;
return _DemuxUnicast(buffer);
}
status_t
UdpDomainSupport::CheckBindRequest(sockaddr *address, int socketOptions)
{ // sUdpEndpointManager->Locker() must be locked!
status_t status = B_OK;
UdpDomainSupport::BindEndpoint(UdpEndpoint *endpoint,
const sockaddr *address)
{
BenaphoreLocker _(fLock);
if (endpoint->IsActive())
return EINVAL;
return _BindEndpoint(endpoint, address);
}
status_t
UdpDomainSupport::ConnectEndpoint(UdpEndpoint *endpoint,
const sockaddr *address)
{
BenaphoreLocker _(fLock);
if (endpoint->IsActive()) {
fActiveEndpoints.Remove(endpoint);
endpoint->SetActive(false);
}
if (address->sa_family == AF_UNSPEC) {
// [Stevens-UNP1, p226]: specifying AF_UNSPEC requests a "disconnect",
// so we reset the peer address:
endpoint->PeerAddress().SetToEmpty();
} else {
status_t status = endpoint->PeerAddress().SetTo(address);
if (status < B_OK)
return status;
}
// we need to activate no matter whether or not we have just disconnected,
// as calling connect() always triggers an implicit bind():
return _BindEndpoint(endpoint, *endpoint->LocalAddress());
}
status_t
UdpDomainSupport::UnbindEndpoint(UdpEndpoint *endpoint)
{
BenaphoreLocker _(fLock);
if (endpoint->IsActive())
fActiveEndpoints.Remove(endpoint);
endpoint->SetActive(false);
return B_OK;
}
status_t
UdpDomainSupport::_BindEndpoint(UdpEndpoint *endpoint,
const sockaddr *address)
{
if (AddressModule()->get_port(address) == 0)
return _BindToEphemeral(endpoint, address);
return _Bind(endpoint, address);
}
status_t
UdpDomainSupport::_Bind(UdpEndpoint *endpoint, const sockaddr *address)
{
int socketOptions = endpoint->Socket()->options;
EndpointTable::Iterator it = fActiveEndpoints.GetIterator();
// Iterate over all active UDP-endpoints and check if the requested bind
// is allowed (see figure 22.24 in [Stevens - TCP2, p735]):
TRACE_DOMAIN("CheckBindRequest() for %s...",
AddressString(fDomain, address, true).Data());
TRACE_DOMAIN("CheckBindRequest() for %s...", AddressString(fDomain,
address, true).Data());
while (it.HasNext()) {
UdpEndpoint *otherEndpoint = it.Next();
@ -278,56 +368,55 @@ UdpDomainSupport::CheckBindRequest(sockaddr *address, int socketOptions)
if (otherEndpoint->LocalAddress().EqualPorts(address)) {
// port is already bound, SO_REUSEADDR or SO_REUSEPORT is required:
if (otherEndpoint->socket->options & (SO_REUSEADDR | SO_REUSEPORT) == 0
|| socketOptions & (SO_REUSEADDR | SO_REUSEPORT) == 0) {
status = EADDRINUSE;
break;
}
if (otherEndpoint->Socket()->options & (SO_REUSEADDR | SO_REUSEPORT) == 0
|| socketOptions & (SO_REUSEADDR | SO_REUSEPORT) == 0)
return EADDRINUSE;
// if both addresses are the same, SO_REUSEPORT is required:
if (otherEndpoint->LocalAddress().EqualTo(address, false)
&& (otherEndpoint->socket->options & SO_REUSEPORT == 0
|| socketOptions & SO_REUSEPORT == 0)) {
status = EADDRINUSE;
break;
}
&& (otherEndpoint->Socket()->options & SO_REUSEPORT == 0
|| socketOptions & SO_REUSEPORT == 0))
return EADDRINUSE;
}
}
TRACE_DOMAIN(" CheckBindRequest done (status=%lx)", status);
return status;
return _FinishBind(endpoint, address);
}
status_t
UdpDomainSupport::ActivateEndpoint(UdpEndpoint *endpoint)
{ // sUdpEndpointManager->Locker() must be locked!
TRACE_DOMAIN("Endpoint(%s) was activated",
AddressString(fDomain, *endpoint->LocalAddress(), true).Data());
UdpDomainSupport::_BindToEphemeral(UdpEndpoint *endpoint,
const sockaddr *address)
{
SocketAddressStorage newAddress(AddressModule());
status_t status = newAddress.SetTo(address);
if (status < B_OK)
return status;
uint16 allocedPort = _GetNextEphemeral();
if (allocedPort == 0)
return ENOBUFS;
newAddress.SetPort(allocedPort);
return _FinishBind(endpoint, *newAddress);
}
status_t
UdpDomainSupport::_FinishBind(UdpEndpoint *endpoint, const sockaddr *address)
{
status_t status = endpoint->next->module->bind(endpoint->next, address);
if (status < B_OK)
return status;
fActiveEndpoints.Insert(endpoint);
endpoint->SetActive(true);
return B_OK;
}
status_t
UdpDomainSupport::DeactivateEndpoint(UdpEndpoint *endpoint)
{ // sUdpEndpointManager->Locker() must be locked!
TRACE_DOMAIN("Endpoint(%s) was deactivated",
AddressString(fDomain, *endpoint->LocalAddress(), true).Data());
fActiveEndpoints.Remove(endpoint);
return B_OK;
}
uint16
UdpDomainSupport::GetEphemeralPort()
{
return _GetNextEphemeral();
}
UdpEndpoint *
UdpDomainSupport::_FindActiveEndpoint(const sockaddr *ourAddress,
const sockaddr *peerAddress)
@ -387,15 +476,6 @@ UdpDomainSupport::_DemuxBroadcast(net_buffer *buffer)
}
status_t
UdpDomainSupport::_DemuxMulticast(net_buffer *buffer)
{ // TODO: implement!
TRACE_DOMAIN("_DemuxMulticast(%p)", buffer);
return B_ERROR;
}
status_t
UdpDomainSupport::_DemuxUnicast(net_buffer *buffer)
{
@ -443,48 +523,38 @@ UdpDomainSupport::_GetNextEphemeral()
curr = kFirst;
}
TRACE_DOMAIN("_GetNextEphemeral()");
// TODO: a free list could be used to avoid the impact of these
// two nested loops most of the time... let's see how bad this really is
bool found = false;
EndpointTable::Iterator it = fActiveEndpoints.GetIterator();
TRACE_DOMAIN("_GetNextEphemeral(), last %hu, curr %hu, stop %hu",
fLastUsedEphemeral, curr, stop);
while (!found && curr != stop) {
for (; curr != stop; curr = (curr < kLast) ? (curr + 1) : kFirst) {
TRACE_DOMAIN(" _GetNextEphemeral(): trying port %hu...", curr);
it.Rewind();
while (!found) {
if (!it.HasNext()) {
found = true;
break;
}
UdpEndpoint *endpoint = it.Next();
uint16 endpointPort = endpoint->LocalAddress().Port();
TRACE_DOMAIN(" _GetNextEphemeral(): checking endpoint %p (port %hu)...",
endpoint, ntohs(endpointPort));
if (endpointPort == htons(curr))
break;
}
if (!found) {
if (curr < kLast)
curr++;
else
curr = kFirst;
if (_EndpointWithPort(htons(curr)) == NULL) {
TRACE_DOMAIN(" _GetNextEphemeral(): ...using port %hu", curr);
fLastUsedEphemeral = curr;
return curr;
}
}
if (!found)
return 0;
return 0;
}
TRACE_DOMAIN(" _GetNextEphemeral(): ...using port %hu", curr);
fLastUsedEphemeral = curr;
return curr;
UdpEndpoint *
UdpDomainSupport::_EndpointWithPort(uint16 port) const
{
EndpointTable::Iterator it = fActiveEndpoints.GetIterator();
while (it.HasNext()) {
UdpEndpoint *endpoint = it.Next();
if (endpoint->LocalAddress().Port() == port)
return endpoint;
}
return NULL;
}
@ -524,9 +594,18 @@ UdpEndpointManager::ReceiveData(net_buffer *buffer)
net_domain *domain = buffer->interface->domain;
BenaphoreLocker _(fLock);
UdpDomainSupport *domainSupport = NULL;
{
BenaphoreLocker _(fLock);
domainSupport = _GetDomain(domain, false);
// TODO we don't want to hold to the manager's lock
// during the whole RX path, we may not hold an
// endpoint's lock with the manager lock held.
// But we should increase the domain's refcount
// here.
}
UdpDomainSupport *domainSupport = _GetDomain(domain, false);
if (domainSupport == NULL) {
// we don't instantiate domain supports in the
// RX path as we are only interested in delivering
@ -619,9 +698,7 @@ UdpEndpointManager::FreeEndpoint(UdpDomainSupport *domain)
{
BenaphoreLocker _(fLock);
domain->Put();
if (domain->IsEmpty()) {
if (domain->Put()) {
fDomains.Remove(domain);
delete domain;
}
@ -679,32 +756,7 @@ status_t
UdpEndpoint::Bind(const sockaddr *address)
{
TRACE_EP("Bind(%s)", AddressString(Domain(), address, true).Data());
// let the underlying protocol check whether there is an interface that
// supports the given address
status_t status = next->module->bind(next, address);
if (status < B_OK)
return status;
BenaphoreLocker locker(sUdpEndpointManager->Locker());
if (fActive) {
// socket module should have called unbind() before!
return EINVAL;
}
if (AddressModule()->get_port(address) == 0) {
uint16 port = htons(fManager->GetEphemeralPort());
if (port == 0)
return ENOBUFS;
LocalAddress().SetPort(port);
} else {
status = fManager->CheckBindRequest(*LocalAddress(), socket->options);
if (status < B_OK)
return status;
}
return _Activate();
return fManager->BindEndpoint(this, address);
}
@ -712,10 +764,7 @@ status_t
UdpEndpoint::Unbind(sockaddr *address)
{
TRACE_EP("Unbind()");
BenaphoreLocker locker(sUdpEndpointManager->Locker());
return _Deactivate();
return fManager->UnbindEndpoint(this);
}
@ -723,24 +772,7 @@ status_t
UdpEndpoint::Connect(const sockaddr *address)
{
TRACE_EP("Connect(%s)", AddressString(Domain(), address, true).Data());
BenaphoreLocker locker(sUdpEndpointManager->Locker());
if (fActive)
_Deactivate();
if (address->sa_family == AF_UNSPEC) {
// [Stevens-UNP1, p226]: specifying AF_UNSPEC requests a "disconnect",
// so we reset the peer address:
PeerAddress().SetToEmpty();
} else {
// TODO check if `address' is compatible with AddressModule()
PeerAddress().SetTo(address);
}
// we need to activate no matter whether or not we have just disconnected,
// as calling connect() always triggers an implicit bind():
return _Activate();
return fManager->ConnectEndpoint(this, address);
}
@ -749,6 +781,8 @@ UdpEndpoint::Open()
{
TRACE_EP("Open()");
BenaphoreLocker _(fLock);
status_t status = ProtocolSocket::Open();
if (status < B_OK)
return status;
@ -765,11 +799,6 @@ status_t
UdpEndpoint::Close()
{
TRACE_EP("Close()");
BenaphoreLocker _(sUdpEndpointManager->Locker());
if (fActive)
_Deactivate();
return B_OK;
}
@ -778,32 +807,10 @@ status_t
UdpEndpoint::Free()
{
TRACE_EP("Free()");
return sUdpEndpointManager->FreeEndpoint(fManager);
}
status_t
UdpEndpoint::_Activate()
{
if (fActive)
return B_ERROR;
status_t status = fManager->ActivateEndpoint(this);
fActive = (status == B_OK);
return status;
}
status_t
UdpEndpoint::_Deactivate()
{
if (!fActive)
return B_ERROR;
fActive = false;
return fManager->DeactivateEndpoint(this);
}
// #pragma mark - outbound