diff --git a/src/add-ons/kernel/network/protocols/udp/udp.cpp b/src/add-ons/kernel/network/protocols/udp/udp.cpp index fdfd4cf709..1c04dddaf2 100644 --- a/src/add-ons/kernel/network/protocols/udp/udp.cpp +++ b/src/add-ons/kernel/network/protocols/udp/udp.cpp @@ -30,6 +30,13 @@ #include #include + +// NOTE the locking protocol dictates that we must hold UdpDomainSupport's +// lock before holding a child UdpEndpoint's lock. This restriction +// is dictated by the receive path as blind access to the endpoint +// hash is required when holding the DomainSuppport's lock. + + //#define TRACE_UDP #ifdef TRACE_UDP # define TRACE_BLOCK(x) dump_block x @@ -84,14 +91,14 @@ public: status_t StoreData(net_buffer *buffer); status_t DeliverData(net_buffer *buffer); - UdpDomainSupport *DomainSupport() const { return fManager; } + // only the domain support will change/check the Active flag so + // we don't really need to protect it with the socket lock. + bool IsActive() const { return fActive; } + void SetActive(bool newValue) { fActive = newValue; } HashTableLink *HashTableLink() { return &fLink; } private: - status_t _Activate(); - status_t _Deactivate(); - UdpDomainSupport *fManager; bool fActive; // an active UdpEndpoint is part of the endpoint @@ -146,39 +153,44 @@ struct UdpHashDefinition { class UdpDomainSupport : public DoublyLinkedListLinkImpl { public: UdpDomainSupport(net_domain *domain); + ~UdpDomainSupport(); status_t InitCheck() const; net_domain *Domain() const { return fDomain; } void Ref() { fEndpointCount++; } - void Put() { fEndpointCount--; } - - bool IsEmpty() const { return fEndpointCount == 0; } + bool Put() { fEndpointCount--; return fEndpointCount == 0; } status_t DemuxIncomingBuffer(net_buffer *buffer); - status_t CheckBindRequest(sockaddr *address, int socketOptions); + + status_t BindEndpoint(UdpEndpoint *endpoint, const sockaddr *address); + status_t ConnectEndpoint(UdpEndpoint *endpoint, const sockaddr *address); + status_t UnbindEndpoint(UdpEndpoint *endpoint); + status_t ActivateEndpoint(UdpEndpoint *endpoint); status_t DeactivateEndpoint(UdpEndpoint *endpoint); - uint16 GetEphemeralPort(); - private: + status_t _BindEndpoint(UdpEndpoint *endpoint, const sockaddr *address); + status_t _Bind(UdpEndpoint *endpoint, const sockaddr *address); + status_t _BindToEphemeral(UdpEndpoint *endpoint, const sockaddr *address); + status_t _FinishBind(UdpEndpoint *endpoint, const sockaddr *address); + UdpEndpoint *_FindActiveEndpoint(const sockaddr *ourAddress, const sockaddr *peerAddress); status_t _DemuxBroadcast(net_buffer *buffer); - status_t _DemuxMulticast(net_buffer *buffer); status_t _DemuxUnicast(net_buffer *buffer); uint16 _GetNextEphemeral(); + UdpEndpoint *_EndpointWithPort(uint16 port) const; net_address_module_info *AddressModule() const - { - return fDomain->address_module; - } + { return fDomain->address_module; } typedef OpenHashTable EndpointTable; + benaphore fLock; net_domain *fDomain; uint16 fLastUsedEphemeral; EndpointTable fActiveEndpoints; @@ -206,9 +218,6 @@ public: UdpDomainSupport *OpenEndpoint(UdpEndpoint *endpoint); status_t FreeEndpoint(UdpDomainSupport *domain); - uint16 GetEphemeralPort(); - - benaphore *Locker() { return &fLock; } status_t InitCheck() const; private: @@ -233,16 +242,27 @@ net_stack_module_info *gStackModule; UdpDomainSupport::UdpDomainSupport(net_domain *domain) : fDomain(domain), - fLastUsedEphemeral(kLast), fActiveEndpoints(domain->address_module, kNumHashBuckets), fEndpointCount(0) { + benaphore_init(&fLock, "udp domain"); + + fLastUsedEphemeral = kFirst + rand() % (kLast - kFirst); +} + + +UdpDomainSupport::~UdpDomainSupport() +{ + benaphore_destroy(&fLock); } status_t UdpDomainSupport::InitCheck() const { + if (fLock.sem < B_OK) + return fLock.sem; + return fActiveEndpoints.InitCheck(); } @@ -250,26 +270,96 @@ UdpDomainSupport::InitCheck() const status_t UdpDomainSupport::DemuxIncomingBuffer(net_buffer *buffer) { + // NOTE multicast is delivered directly to the endpoint + + BenaphoreLocker _(fLock); + if (buffer->flags & MSG_BCAST) return _DemuxBroadcast(buffer); else if (buffer->flags & MSG_MCAST) - return _DemuxMulticast(buffer); + return B_ERROR; return _DemuxUnicast(buffer); } status_t -UdpDomainSupport::CheckBindRequest(sockaddr *address, int socketOptions) -{ // sUdpEndpointManager->Locker() must be locked! - status_t status = B_OK; +UdpDomainSupport::BindEndpoint(UdpEndpoint *endpoint, + const sockaddr *address) +{ + BenaphoreLocker _(fLock); + + if (endpoint->IsActive()) + return EINVAL; + + return _BindEndpoint(endpoint, address); +} + + +status_t +UdpDomainSupport::ConnectEndpoint(UdpEndpoint *endpoint, + const sockaddr *address) +{ + BenaphoreLocker _(fLock); + + if (endpoint->IsActive()) { + fActiveEndpoints.Remove(endpoint); + endpoint->SetActive(false); + } + + if (address->sa_family == AF_UNSPEC) { + // [Stevens-UNP1, p226]: specifying AF_UNSPEC requests a "disconnect", + // so we reset the peer address: + endpoint->PeerAddress().SetToEmpty(); + } else { + status_t status = endpoint->PeerAddress().SetTo(address); + if (status < B_OK) + return status; + } + + // we need to activate no matter whether or not we have just disconnected, + // as calling connect() always triggers an implicit bind(): + return _BindEndpoint(endpoint, *endpoint->LocalAddress()); +} + + +status_t +UdpDomainSupport::UnbindEndpoint(UdpEndpoint *endpoint) +{ + BenaphoreLocker _(fLock); + + if (endpoint->IsActive()) + fActiveEndpoints.Remove(endpoint); + + endpoint->SetActive(false); + + return B_OK; +} + + +status_t +UdpDomainSupport::_BindEndpoint(UdpEndpoint *endpoint, + const sockaddr *address) +{ + if (AddressModule()->get_port(address) == 0) + return _BindToEphemeral(endpoint, address); + + return _Bind(endpoint, address); +} + + +status_t +UdpDomainSupport::_Bind(UdpEndpoint *endpoint, const sockaddr *address) +{ + int socketOptions = endpoint->Socket()->options; EndpointTable::Iterator it = fActiveEndpoints.GetIterator(); // Iterate over all active UDP-endpoints and check if the requested bind // is allowed (see figure 22.24 in [Stevens - TCP2, p735]): - TRACE_DOMAIN("CheckBindRequest() for %s...", - AddressString(fDomain, address, true).Data()); + TRACE_DOMAIN("CheckBindRequest() for %s...", AddressString(fDomain, + address, true).Data()); + while (it.HasNext()) { UdpEndpoint *otherEndpoint = it.Next(); @@ -278,56 +368,55 @@ UdpDomainSupport::CheckBindRequest(sockaddr *address, int socketOptions) if (otherEndpoint->LocalAddress().EqualPorts(address)) { // port is already bound, SO_REUSEADDR or SO_REUSEPORT is required: - if (otherEndpoint->socket->options & (SO_REUSEADDR | SO_REUSEPORT) == 0 - || socketOptions & (SO_REUSEADDR | SO_REUSEPORT) == 0) { - status = EADDRINUSE; - break; - } + if (otherEndpoint->Socket()->options & (SO_REUSEADDR | SO_REUSEPORT) == 0 + || socketOptions & (SO_REUSEADDR | SO_REUSEPORT) == 0) + return EADDRINUSE; // if both addresses are the same, SO_REUSEPORT is required: if (otherEndpoint->LocalAddress().EqualTo(address, false) - && (otherEndpoint->socket->options & SO_REUSEPORT == 0 - || socketOptions & SO_REUSEPORT == 0)) { - status = EADDRINUSE; - break; - } + && (otherEndpoint->Socket()->options & SO_REUSEPORT == 0 + || socketOptions & SO_REUSEPORT == 0)) + return EADDRINUSE; } } - TRACE_DOMAIN(" CheckBindRequest done (status=%lx)", status); - return status; + return _FinishBind(endpoint, address); } status_t -UdpDomainSupport::ActivateEndpoint(UdpEndpoint *endpoint) -{ // sUdpEndpointManager->Locker() must be locked! - TRACE_DOMAIN("Endpoint(%s) was activated", - AddressString(fDomain, *endpoint->LocalAddress(), true).Data()); +UdpDomainSupport::_BindToEphemeral(UdpEndpoint *endpoint, + const sockaddr *address) +{ + SocketAddressStorage newAddress(AddressModule()); + status_t status = newAddress.SetTo(address); + if (status < B_OK) + return status; + + uint16 allocedPort = _GetNextEphemeral(); + if (allocedPort == 0) + return ENOBUFS; + + newAddress.SetPort(allocedPort); + + return _FinishBind(endpoint, *newAddress); +} + + +status_t +UdpDomainSupport::_FinishBind(UdpEndpoint *endpoint, const sockaddr *address) +{ + status_t status = endpoint->next->module->bind(endpoint->next, address); + if (status < B_OK) + return status; fActiveEndpoints.Insert(endpoint); + endpoint->SetActive(true); + return B_OK; } -status_t -UdpDomainSupport::DeactivateEndpoint(UdpEndpoint *endpoint) -{ // sUdpEndpointManager->Locker() must be locked! - TRACE_DOMAIN("Endpoint(%s) was deactivated", - AddressString(fDomain, *endpoint->LocalAddress(), true).Data()); - - fActiveEndpoints.Remove(endpoint); - return B_OK; -} - - -uint16 -UdpDomainSupport::GetEphemeralPort() -{ - return _GetNextEphemeral(); -} - - UdpEndpoint * UdpDomainSupport::_FindActiveEndpoint(const sockaddr *ourAddress, const sockaddr *peerAddress) @@ -387,15 +476,6 @@ UdpDomainSupport::_DemuxBroadcast(net_buffer *buffer) } -status_t -UdpDomainSupport::_DemuxMulticast(net_buffer *buffer) -{ // TODO: implement! - TRACE_DOMAIN("_DemuxMulticast(%p)", buffer); - - return B_ERROR; -} - - status_t UdpDomainSupport::_DemuxUnicast(net_buffer *buffer) { @@ -443,48 +523,38 @@ UdpDomainSupport::_GetNextEphemeral() curr = kFirst; } - TRACE_DOMAIN("_GetNextEphemeral()"); // TODO: a free list could be used to avoid the impact of these // two nested loops most of the time... let's see how bad this really is - bool found = false; - EndpointTable::Iterator it = fActiveEndpoints.GetIterator(); + TRACE_DOMAIN("_GetNextEphemeral(), last %hu, curr %hu, stop %hu", + fLastUsedEphemeral, curr, stop); - while (!found && curr != stop) { + for (; curr != stop; curr = (curr < kLast) ? (curr + 1) : kFirst) { TRACE_DOMAIN(" _GetNextEphemeral(): trying port %hu...", curr); - it.Rewind(); - - while (!found) { - if (!it.HasNext()) { - found = true; - break; - } - - UdpEndpoint *endpoint = it.Next(); - uint16 endpointPort = endpoint->LocalAddress().Port(); - - TRACE_DOMAIN(" _GetNextEphemeral(): checking endpoint %p (port %hu)...", - endpoint, ntohs(endpointPort)); - - if (endpointPort == htons(curr)) - break; - } - - if (!found) { - if (curr < kLast) - curr++; - else - curr = kFirst; + if (_EndpointWithPort(htons(curr)) == NULL) { + TRACE_DOMAIN(" _GetNextEphemeral(): ...using port %hu", curr); + fLastUsedEphemeral = curr; + return curr; } } - if (!found) - return 0; + return 0; +} - TRACE_DOMAIN(" _GetNextEphemeral(): ...using port %hu", curr); - fLastUsedEphemeral = curr; - return curr; + +UdpEndpoint * +UdpDomainSupport::_EndpointWithPort(uint16 port) const +{ + EndpointTable::Iterator it = fActiveEndpoints.GetIterator(); + + while (it.HasNext()) { + UdpEndpoint *endpoint = it.Next(); + if (endpoint->LocalAddress().Port() == port) + return endpoint; + } + + return NULL; } @@ -524,9 +594,18 @@ UdpEndpointManager::ReceiveData(net_buffer *buffer) net_domain *domain = buffer->interface->domain; - BenaphoreLocker _(fLock); + UdpDomainSupport *domainSupport = NULL; + + { + BenaphoreLocker _(fLock); + domainSupport = _GetDomain(domain, false); + // TODO we don't want to hold to the manager's lock + // during the whole RX path, we may not hold an + // endpoint's lock with the manager lock held. + // But we should increase the domain's refcount + // here. + } - UdpDomainSupport *domainSupport = _GetDomain(domain, false); if (domainSupport == NULL) { // we don't instantiate domain supports in the // RX path as we are only interested in delivering @@ -619,9 +698,7 @@ UdpEndpointManager::FreeEndpoint(UdpDomainSupport *domain) { BenaphoreLocker _(fLock); - domain->Put(); - - if (domain->IsEmpty()) { + if (domain->Put()) { fDomains.Remove(domain); delete domain; } @@ -679,32 +756,7 @@ status_t UdpEndpoint::Bind(const sockaddr *address) { TRACE_EP("Bind(%s)", AddressString(Domain(), address, true).Data()); - - // let the underlying protocol check whether there is an interface that - // supports the given address - status_t status = next->module->bind(next, address); - if (status < B_OK) - return status; - - BenaphoreLocker locker(sUdpEndpointManager->Locker()); - - if (fActive) { - // socket module should have called unbind() before! - return EINVAL; - } - - if (AddressModule()->get_port(address) == 0) { - uint16 port = htons(fManager->GetEphemeralPort()); - if (port == 0) - return ENOBUFS; - LocalAddress().SetPort(port); - } else { - status = fManager->CheckBindRequest(*LocalAddress(), socket->options); - if (status < B_OK) - return status; - } - - return _Activate(); + return fManager->BindEndpoint(this, address); } @@ -712,10 +764,7 @@ status_t UdpEndpoint::Unbind(sockaddr *address) { TRACE_EP("Unbind()"); - - BenaphoreLocker locker(sUdpEndpointManager->Locker()); - - return _Deactivate(); + return fManager->UnbindEndpoint(this); } @@ -723,24 +772,7 @@ status_t UdpEndpoint::Connect(const sockaddr *address) { TRACE_EP("Connect(%s)", AddressString(Domain(), address, true).Data()); - - BenaphoreLocker locker(sUdpEndpointManager->Locker()); - - if (fActive) - _Deactivate(); - - if (address->sa_family == AF_UNSPEC) { - // [Stevens-UNP1, p226]: specifying AF_UNSPEC requests a "disconnect", - // so we reset the peer address: - PeerAddress().SetToEmpty(); - } else { - // TODO check if `address' is compatible with AddressModule() - PeerAddress().SetTo(address); - } - - // we need to activate no matter whether or not we have just disconnected, - // as calling connect() always triggers an implicit bind(): - return _Activate(); + return fManager->ConnectEndpoint(this, address); } @@ -749,6 +781,8 @@ UdpEndpoint::Open() { TRACE_EP("Open()"); + BenaphoreLocker _(fLock); + status_t status = ProtocolSocket::Open(); if (status < B_OK) return status; @@ -765,11 +799,6 @@ status_t UdpEndpoint::Close() { TRACE_EP("Close()"); - - BenaphoreLocker _(sUdpEndpointManager->Locker()); - if (fActive) - _Deactivate(); - return B_OK; } @@ -778,32 +807,10 @@ status_t UdpEndpoint::Free() { TRACE_EP("Free()"); - return sUdpEndpointManager->FreeEndpoint(fManager); } -status_t -UdpEndpoint::_Activate() -{ - if (fActive) - return B_ERROR; - status_t status = fManager->ActivateEndpoint(this); - fActive = (status == B_OK); - return status; -} - - -status_t -UdpEndpoint::_Deactivate() -{ - if (!fActive) - return B_ERROR; - fActive = false; - return fManager->DeactivateEndpoint(this); -} - - // #pragma mark - outbound