introduced SocketAddress wrappers and use them in TCP.

git-svn-id: file:///srv/svn/repos/haiku/haiku/trunk@20818 a95241bf-73f2-0310-859d-f6bbb57e9c96
This commit is contained in:
Hugo Santos 2007-04-25 17:41:01 +00:00
parent c36e4d4c9f
commit 585195c28d
5 changed files with 254 additions and 67 deletions

View File

@ -0,0 +1,191 @@
/*
* Copyright 2007, Haiku, Inc. All Rights Reserved.
* Distributed under the terms of the MIT License.
*
* Authors:
* Hugo Santos, hugosantos@gmail.com
*/
#ifndef ADDRESS_UTILITIES_H
#define ADDRESS_UTILITIES_H
#include <net_datalink.h>
#include <sys/socket.h>
class SocketAddress {
public:
SocketAddress(net_address_module_info *module, sockaddr *address)
: fModule(module), fAddress(NULL)
{
SetAddressTo(address);
}
SocketAddress(net_address_module_info *module, sockaddr_storage *address)
: fModule(module), fAddress(NULL)
{
SetAddressTo((sockaddr *)address);
}
SocketAddress(const SocketAddress &address)
: fModule(address.fModule), fAddress(address.fAddress)
{}
void SetAddressTo(sockaddr *address)
{
fAddress = address;
}
bool IsEmpty(bool checkPort) const
{
return fModule->is_empty_address(fAddress, checkPort);
}
uint32 HashPair(const sockaddr *second) const
{
return fModule->hash_address_pair(fAddress, second);
}
bool EqualTo(const sockaddr *address, bool checkPort = false) const
{
if (checkPort)
return fModule->equal_addresses_and_ports(fAddress, address);
else
return fModule->equal_addresses(fAddress, address);
}
bool EqualTo(const SocketAddress &second, bool checkPort = false) const
{
return EqualTo(second.fAddress, checkPort);
}
uint16 GetPort() const
{
return fModule->get_port(fAddress);
}
status_t SetTo(const sockaddr *from)
{
return fModule->set_to(fAddress, from);
}
status_t SetTo(const sockaddr_storage *from)
{
return SetTo((sockaddr *)from);
}
status_t CopyTo(sockaddr *to) const
{
return fModule->set_to(to, fAddress);
}
status_t CopyTo(sockaddr_storage *to) const
{
return CopyTo((sockaddr *)to);
}
status_t SetPort(uint16 port)
{
return fModule->set_port(fAddress, port);
}
status_t SetToEmpty()
{
return fModule->set_to_empty_address(fAddress);
}
status_t UpdateTo(const sockaddr *from)
{
return fModule->update_to(fAddress, from);
}
const sockaddr *operator *() const { return fAddress; }
sockaddr *operator *() { return fAddress; }
net_address_module_info *Module() const { return fModule; }
private:
net_address_module_info *fModule;
sockaddr *fAddress;
};
class ConstSocketAddress {
public:
ConstSocketAddress(net_address_module_info *module,
const sockaddr *address)
: fModule(module), fAddress(NULL)
{
SetAddressTo(address);
}
ConstSocketAddress(net_address_module_info *module,
const sockaddr_storage *address)
: fModule(module), fAddress(NULL)
{
SetAddressTo((sockaddr *)address);
}
ConstSocketAddress(const ConstSocketAddress &address)
: fModule(address.fModule), fAddress(address.fAddress)
{}
ConstSocketAddress(const SocketAddress &address)
: fModule(address.Module()), fAddress(*address)
{}
void SetAddressTo(const sockaddr *address)
{
fAddress = address;
}
bool IsEmpty(bool checkPort) const
{
return fModule->is_empty_address(fAddress, checkPort);
}
uint32 HashPair(const sockaddr *second) const
{
return fModule->hash_address_pair(fAddress, second);
}
bool EqualTo(const sockaddr *address, bool checkPort = false) const
{
if (checkPort)
return fModule->equal_addresses_and_ports(fAddress, address);
else
return fModule->equal_addresses(fAddress, address);
}
uint16 GetPort() const
{
return fModule->get_port(fAddress);
}
status_t CopyTo(sockaddr *to) const
{
return fModule->set_to(to, fAddress);
}
status_t CopyTo(sockaddr_storage *to) const
{
return CopyTo((sockaddr *)to);
}
const sockaddr *operator *() const { return fAddress; }
private:
net_address_module_info *fModule;
const sockaddr *fAddress;
};
class SocketAddressStorage : public SocketAddress {
public:
SocketAddressStorage(net_address_module_info *module)
: SocketAddress(module, &fAddressStorage)
{}
private:
sockaddr_storage fAddressStorage;
};
#endif

View File

@ -32,15 +32,15 @@ static const uint16 kFirstEphemeralPort = 40000;
size_t
ConnectionHashDefinition::HashKey(EndpointManager *manager, const KeyType &key)
{
return manager->AddressModule()->hash_address_pair(key.first, key.second);
return ConstSocketAddress(manager->AddressModule(),
key.first).HashPair(key.second);
}
size_t
ConnectionHashDefinition::Hash(EndpointManager *manager, TCPEndpoint *endpoint)
{
return manager->AddressModule()->hash_address_pair(
endpoint->LocalAddress(), endpoint->PeerAddress());
return endpoint->LocalAddress().HashPair(*endpoint->PeerAddress());
}
@ -48,10 +48,8 @@ bool
ConnectionHashDefinition::Compare(EndpointManager *manager, const KeyType &key,
TCPEndpoint *endpoint)
{
net_address_module_info *module = manager->AddressModule();
return module->equal_addresses_and_ports(key.first, endpoint->LocalAddress())
&& module->equal_addresses_and_ports(key.second, endpoint->PeerAddress());
return endpoint->LocalAddress().EqualTo(key.first, true)
&& endpoint->PeerAddress().EqualTo(key.second, true);
}
@ -65,7 +63,7 @@ EndpointHashDefinition::HashKey(EndpointManager *manager, uint16 port)
size_t
EndpointHashDefinition::Hash(EndpointManager *manager, TCPEndpoint *endpoint)
{
return endpoint->AddressModule()->get_port(endpoint->LocalAddress());
return endpoint->LocalAddress().GetPort();
}
@ -73,7 +71,7 @@ bool
EndpointHashDefinition::Compare(EndpointManager *manager, uint16 port,
TCPEndpoint *endpoint)
{
return endpoint->AddressModule()->get_port(endpoint->LocalAddress()) == port;
return endpoint->LocalAddress().GetPort() == port;
}
@ -122,25 +120,26 @@ EndpointManager::_LookupConnection(const sockaddr *local, const sockaddr *peer)
status_t
EndpointManager::SetConnection(TCPEndpoint *endpoint,
const sockaddr *local, const sockaddr *peer, const sockaddr *interfaceLocal)
const sockaddr *_local, const sockaddr *peer, const sockaddr *interfaceLocal)
{
TRACE(("EndpointManager::SetConnection(%p)\n", endpoint));
BenaphoreLocker _(fLock);
sockaddr localBuffer;
// need to associate this connection with a real address, not INADDR_ANY
if (AddressModule()->is_empty_address(local, false)) {
AddressModule()->set_to(&localBuffer, interfaceLocal);
AddressModule()->set_port(&localBuffer, AddressModule()->get_port(local));
local = &localBuffer;
SocketAddressStorage local(AddressModule());
local.SetTo(_local);
if (local.IsEmpty(false)) {
uint16 port = local.GetPort();
local.SetTo(interfaceLocal);
local.SetPort(port);
}
if (_LookupConnection(local, peer) != NULL)
if (_LookupConnection(*local, peer) != NULL)
return EADDRINUSE;
AddressModule()->set_to(endpoint->LocalAddress(), local);
AddressModule()->set_to(endpoint->PeerAddress(), peer);
endpoint->LocalAddress().SetTo(*local);
endpoint->PeerAddress().SetTo(peer);
if (!fConnectionHash.Insert(endpoint))
return B_NO_MEMORY;
@ -156,22 +155,21 @@ EndpointManager::SetPassive(TCPEndpoint *endpoint)
if (!endpoint->IsBound()) {
// if the socket is unbound first bind it to ephemeral
sockaddr_storage localAddress;
AddressModule()->set_to_empty_address((sockaddr *)&localAddress);
SocketAddressStorage local(AddressModule());
local.SetToEmpty();
status_t status = _BindToEphemeral(endpoint,
(sockaddr *)&localAddress);
status_t status = _BindToEphemeral(endpoint, *local);
if (status < B_OK)
return status;
}
sockaddr_storage passive;
AddressModule()->set_to_empty_address((sockaddr *)&passive);
SocketAddressStorage passive(AddressModule());
passive.SetToEmpty();
if (_LookupConnection(endpoint->LocalAddress(), (sockaddr *)&passive))
if (_LookupConnection(*endpoint->LocalAddress(), *passive))
return EADDRINUSE;
AddressModule()->set_to(endpoint->PeerAddress(), (sockaddr *)&passive);
endpoint->PeerAddress().SetTo(*passive);
if (!fConnectionHash.Insert(endpoint))
return B_NO_MEMORY;
@ -192,20 +190,20 @@ EndpointManager::FindConnection(sockaddr *local, sockaddr *peer)
// no explicit endpoint exists, check for wildcard endpoints
sockaddr wildcard;
AddressModule()->set_to_empty_address(&wildcard);
SocketAddressStorage wildcard(AddressModule());
wildcard.SetToEmpty();
endpoint = _LookupConnection(local, &wildcard);
endpoint = _LookupConnection(local, *wildcard);
if (endpoint != NULL) {
TRACE(("TCP: Received packet corresponds to wildcard endpoint %p\n", endpoint));
return endpoint;
}
sockaddr localWildcard;
AddressModule()->set_to_empty_address(&localWildcard);
AddressModule()->set_port(&localWildcard, AddressModule()->get_port(local));
SocketAddressStorage localWildcard(AddressModule());
localWildcard.SetToEmpty();
localWildcard.SetPort(AddressModule()->get_port(local));
endpoint = _LookupConnection(&localWildcard, &wildcard);
endpoint = _LookupConnection(*localWildcard, *wildcard);
if (endpoint != NULL) {
TRACE(("TCP: Received packet corresponds to local wildcard endpoint %p\n", endpoint));
return endpoint;
@ -276,15 +274,15 @@ EndpointManager::_BindToEphemeral(TCPEndpoint *endpoint,
TCPEndpoint *other = fEndpointHash.Lookup(port);
if (other == NULL) {
sockaddr_storage newAddress;
AddressModule()->set_to((sockaddr *)&newAddress, address);
AddressModule()->set_port((sockaddr *)&newAddress, port);
SocketAddressStorage newAddress(AddressModule());
newAddress.SetTo(address);
newAddress.SetPort(port);
// found a port
TRACE((" EndpointManager::BindToEphemeral(%p) -> %s\n", endpoint,
AddressString(Domain(), (sockaddr *)&newAddress, true).Data()));
AddressString(Domain(), *newAddress, true).Data()));
return _Bind(endpoint, (sockaddr *)&newAddress);
return _Bind(endpoint, *newAddress);
}
counter += step;
@ -317,9 +315,9 @@ EndpointManager::_Bind(TCPEndpoint *endpoint, const sockaddr *address)
if (first != NULL) {
while (true) {
// check if this endpoint binds to a wildcard address
if (AddressModule()->is_empty_address(first->LocalAddress(), false)) {
// you cannot specialize a wildcard endpoint - you have to open the
// wildcard endpoint last
if (first->LocalAddress().IsEmpty(false)) {
// you cannot specialize a wildcard endpoint - you have to open
// the wildcard endpoint last
return B_PERMISSION_DENIED;
}
@ -361,8 +359,8 @@ EndpointManager::Unbind(TCPEndpoint *endpoint)
BenaphoreLocker _(fLock);
TCPEndpoint *other = fEndpointHash.Lookup(
AddressModule()->get_port(endpoint->LocalAddress()));
TCPEndpoint *other =
fEndpointHash.Lookup(endpoint->LocalAddress().GetPort());
if (other != endpoint) {
// remove endpoint from the list of endpoints with the same port
while (other != NULL && other->fEndpointNextWithSamePort != endpoint)
@ -384,7 +382,7 @@ EndpointManager::Unbind(TCPEndpoint *endpoint)
endpoint->fEndpointNextWithSamePort = NULL;
fConnectionHash.Remove(endpoint);
endpoint->LocalAddress()->sa_len = 0;
(*endpoint->LocalAddress())->sa_len = 0;
return B_OK;
}

View File

@ -10,14 +10,12 @@
#include "tcp.h"
#include <net_datalink.h>
#include <AddressUtilities.h>
#include <lock.h>
#include <util/DoublyLinkedList.h>
#include <util/OpenHashTable.h>
#include <sys/socket.h>
#include <utility>

View File

@ -707,7 +707,7 @@ TCPEndpoint::SetReceiveBufferSize(size_t length)
bool
TCPEndpoint::IsBound() const
{
return !AddressModule()->is_empty_address(LocalAddress(), true);
return !LocalAddress().IsEmpty(true);
}
@ -791,13 +791,12 @@ TCPEndpoint::_ListenReceive(tcp_segment_header &segment, net_buffer *buffer)
if (gSocketModule->spawn_pending_socket(socket, &newSocket) < B_OK)
return DROP;
AddressModule()->set_to((sockaddr *)&newSocket->address,
(sockaddr *)&buffer->destination);
AddressModule()->set_to((sockaddr *)&newSocket->peer,
(sockaddr *)&buffer->source);
TCPEndpoint *newEndpoint = (TCPEndpoint *)newSocket->first_protocol;
return ((TCPEndpoint *)newSocket->first_protocol)->Spawn(this,
segment, buffer);
newEndpoint->LocalAddress().SetTo(&buffer->destination);
newEndpoint->PeerAddress().SetTo(&buffer->source);
return newEndpoint->Spawn(this, segment, buffer);
}
@ -814,7 +813,7 @@ TCPEndpoint::Spawn(TCPEndpoint *parent, tcp_segment_header &segment,
TRACE("Spawn()");
// TODO: proper error handling!
if (_PrepareSendPath(PeerAddress()) < B_OK)
if (_PrepareSendPath(*PeerAddress()) < B_OK)
return DROP;
fOptions = parent->fOptions;
@ -1158,8 +1157,8 @@ TCPEndpoint::_SendQueued(bool force, uint32 sendWindow)
return status;
}
AddressModule()->set_to((sockaddr *)&buffer->source, LocalAddress());
AddressModule()->set_to((sockaddr *)&buffer->destination, PeerAddress());
LocalAddress().CopyTo(&buffer->source);
PeerAddress().CopyTo(&buffer->destination);
uint32 size = buffer->size;
segment.sequence = fSendNext;
@ -1616,7 +1615,7 @@ TCPEndpoint::_PrepareSendPath(const sockaddr *peer)
}
// make sure connection does not already exist
status_t status = fManager->SetConnection(this, LocalAddress(), peer,
status_t status = fManager->SetConnection(this, *LocalAddress(), peer,
fRoute->interface->address);
if (status < B_OK)
return status;

View File

@ -13,6 +13,7 @@
#include "tcp.h"
#include "BufferQueue.h"
#include <AddressUtilities.h>
#include <net_protocol.h>
#include <net_stack.h>
#include <util/AutoLock.h>
@ -65,15 +66,15 @@ class TCPEndpoint : public net_protocol {
tcp_state State() const { return fState; }
bool IsBound() const;
const sockaddr *LocalAddress() const
{ return (sockaddr *)&socket->address; }
sockaddr *LocalAddress()
{ return (sockaddr *)&socket->address; }
SocketAddress LocalAddress()
{ return SocketAddress(AddressModule(), &socket->address); }
ConstSocketAddress LocalAddress() const
{ return ConstSocketAddress(AddressModule(), &socket->address); }
const sockaddr *PeerAddress() const
{ return (sockaddr *)&socket->peer; }
sockaddr *PeerAddress()
{ return (sockaddr *)&socket->peer; }
SocketAddress PeerAddress()
{ return SocketAddress(AddressModule(), &socket->peer); }
ConstSocketAddress PeerAddress() const
{ return ConstSocketAddress(AddressModule(), &socket->peer); }
void DeleteSocket();