introduced SocketAddress wrappers and use them in TCP.

git-svn-id: file:///srv/svn/repos/haiku/haiku/trunk@20818 a95241bf-73f2-0310-859d-f6bbb57e9c96
This commit is contained in:
Hugo Santos 2007-04-25 17:41:01 +00:00
parent c36e4d4c9f
commit 585195c28d
5 changed files with 254 additions and 67 deletions

View File

@ -0,0 +1,191 @@
/*
* Copyright 2007, Haiku, Inc. All Rights Reserved.
* Distributed under the terms of the MIT License.
*
* Authors:
* Hugo Santos, hugosantos@gmail.com
*/
#ifndef ADDRESS_UTILITIES_H
#define ADDRESS_UTILITIES_H
#include <net_datalink.h>
#include <sys/socket.h>
class SocketAddress {
public:
SocketAddress(net_address_module_info *module, sockaddr *address)
: fModule(module), fAddress(NULL)
{
SetAddressTo(address);
}
SocketAddress(net_address_module_info *module, sockaddr_storage *address)
: fModule(module), fAddress(NULL)
{
SetAddressTo((sockaddr *)address);
}
SocketAddress(const SocketAddress &address)
: fModule(address.fModule), fAddress(address.fAddress)
{}
void SetAddressTo(sockaddr *address)
{
fAddress = address;
}
bool IsEmpty(bool checkPort) const
{
return fModule->is_empty_address(fAddress, checkPort);
}
uint32 HashPair(const sockaddr *second) const
{
return fModule->hash_address_pair(fAddress, second);
}
bool EqualTo(const sockaddr *address, bool checkPort = false) const
{
if (checkPort)
return fModule->equal_addresses_and_ports(fAddress, address);
else
return fModule->equal_addresses(fAddress, address);
}
bool EqualTo(const SocketAddress &second, bool checkPort = false) const
{
return EqualTo(second.fAddress, checkPort);
}
uint16 GetPort() const
{
return fModule->get_port(fAddress);
}
status_t SetTo(const sockaddr *from)
{
return fModule->set_to(fAddress, from);
}
status_t SetTo(const sockaddr_storage *from)
{
return SetTo((sockaddr *)from);
}
status_t CopyTo(sockaddr *to) const
{
return fModule->set_to(to, fAddress);
}
status_t CopyTo(sockaddr_storage *to) const
{
return CopyTo((sockaddr *)to);
}
status_t SetPort(uint16 port)
{
return fModule->set_port(fAddress, port);
}
status_t SetToEmpty()
{
return fModule->set_to_empty_address(fAddress);
}
status_t UpdateTo(const sockaddr *from)
{
return fModule->update_to(fAddress, from);
}
const sockaddr *operator *() const { return fAddress; }
sockaddr *operator *() { return fAddress; }
net_address_module_info *Module() const { return fModule; }
private:
net_address_module_info *fModule;
sockaddr *fAddress;
};
class ConstSocketAddress {
public:
ConstSocketAddress(net_address_module_info *module,
const sockaddr *address)
: fModule(module), fAddress(NULL)
{
SetAddressTo(address);
}
ConstSocketAddress(net_address_module_info *module,
const sockaddr_storage *address)
: fModule(module), fAddress(NULL)
{
SetAddressTo((sockaddr *)address);
}
ConstSocketAddress(const ConstSocketAddress &address)
: fModule(address.fModule), fAddress(address.fAddress)
{}
ConstSocketAddress(const SocketAddress &address)
: fModule(address.Module()), fAddress(*address)
{}
void SetAddressTo(const sockaddr *address)
{
fAddress = address;
}
bool IsEmpty(bool checkPort) const
{
return fModule->is_empty_address(fAddress, checkPort);
}
uint32 HashPair(const sockaddr *second) const
{
return fModule->hash_address_pair(fAddress, second);
}
bool EqualTo(const sockaddr *address, bool checkPort = false) const
{
if (checkPort)
return fModule->equal_addresses_and_ports(fAddress, address);
else
return fModule->equal_addresses(fAddress, address);
}
uint16 GetPort() const
{
return fModule->get_port(fAddress);
}
status_t CopyTo(sockaddr *to) const
{
return fModule->set_to(to, fAddress);
}
status_t CopyTo(sockaddr_storage *to) const
{
return CopyTo((sockaddr *)to);
}
const sockaddr *operator *() const { return fAddress; }
private:
net_address_module_info *fModule;
const sockaddr *fAddress;
};
class SocketAddressStorage : public SocketAddress {
public:
SocketAddressStorage(net_address_module_info *module)
: SocketAddress(module, &fAddressStorage)
{}
private:
sockaddr_storage fAddressStorage;
};
#endif

View File

@ -32,15 +32,15 @@ static const uint16 kFirstEphemeralPort = 40000;
size_t size_t
ConnectionHashDefinition::HashKey(EndpointManager *manager, const KeyType &key) ConnectionHashDefinition::HashKey(EndpointManager *manager, const KeyType &key)
{ {
return manager->AddressModule()->hash_address_pair(key.first, key.second); return ConstSocketAddress(manager->AddressModule(),
key.first).HashPair(key.second);
} }
size_t size_t
ConnectionHashDefinition::Hash(EndpointManager *manager, TCPEndpoint *endpoint) ConnectionHashDefinition::Hash(EndpointManager *manager, TCPEndpoint *endpoint)
{ {
return manager->AddressModule()->hash_address_pair( return endpoint->LocalAddress().HashPair(*endpoint->PeerAddress());
endpoint->LocalAddress(), endpoint->PeerAddress());
} }
@ -48,10 +48,8 @@ bool
ConnectionHashDefinition::Compare(EndpointManager *manager, const KeyType &key, ConnectionHashDefinition::Compare(EndpointManager *manager, const KeyType &key,
TCPEndpoint *endpoint) TCPEndpoint *endpoint)
{ {
net_address_module_info *module = manager->AddressModule(); return endpoint->LocalAddress().EqualTo(key.first, true)
&& endpoint->PeerAddress().EqualTo(key.second, true);
return module->equal_addresses_and_ports(key.first, endpoint->LocalAddress())
&& module->equal_addresses_and_ports(key.second, endpoint->PeerAddress());
} }
@ -65,7 +63,7 @@ EndpointHashDefinition::HashKey(EndpointManager *manager, uint16 port)
size_t size_t
EndpointHashDefinition::Hash(EndpointManager *manager, TCPEndpoint *endpoint) EndpointHashDefinition::Hash(EndpointManager *manager, TCPEndpoint *endpoint)
{ {
return endpoint->AddressModule()->get_port(endpoint->LocalAddress()); return endpoint->LocalAddress().GetPort();
} }
@ -73,7 +71,7 @@ bool
EndpointHashDefinition::Compare(EndpointManager *manager, uint16 port, EndpointHashDefinition::Compare(EndpointManager *manager, uint16 port,
TCPEndpoint *endpoint) TCPEndpoint *endpoint)
{ {
return endpoint->AddressModule()->get_port(endpoint->LocalAddress()) == port; return endpoint->LocalAddress().GetPort() == port;
} }
@ -122,25 +120,26 @@ EndpointManager::_LookupConnection(const sockaddr *local, const sockaddr *peer)
status_t status_t
EndpointManager::SetConnection(TCPEndpoint *endpoint, EndpointManager::SetConnection(TCPEndpoint *endpoint,
const sockaddr *local, const sockaddr *peer, const sockaddr *interfaceLocal) const sockaddr *_local, const sockaddr *peer, const sockaddr *interfaceLocal)
{ {
TRACE(("EndpointManager::SetConnection(%p)\n", endpoint)); TRACE(("EndpointManager::SetConnection(%p)\n", endpoint));
BenaphoreLocker _(fLock); BenaphoreLocker _(fLock);
sockaddr localBuffer;
// need to associate this connection with a real address, not INADDR_ANY SocketAddressStorage local(AddressModule());
if (AddressModule()->is_empty_address(local, false)) { local.SetTo(_local);
AddressModule()->set_to(&localBuffer, interfaceLocal);
AddressModule()->set_port(&localBuffer, AddressModule()->get_port(local)); if (local.IsEmpty(false)) {
local = &localBuffer; uint16 port = local.GetPort();
local.SetTo(interfaceLocal);
local.SetPort(port);
} }
if (_LookupConnection(local, peer) != NULL) if (_LookupConnection(*local, peer) != NULL)
return EADDRINUSE; return EADDRINUSE;
AddressModule()->set_to(endpoint->LocalAddress(), local); endpoint->LocalAddress().SetTo(*local);
AddressModule()->set_to(endpoint->PeerAddress(), peer); endpoint->PeerAddress().SetTo(peer);
if (!fConnectionHash.Insert(endpoint)) if (!fConnectionHash.Insert(endpoint))
return B_NO_MEMORY; return B_NO_MEMORY;
@ -156,22 +155,21 @@ EndpointManager::SetPassive(TCPEndpoint *endpoint)
if (!endpoint->IsBound()) { if (!endpoint->IsBound()) {
// if the socket is unbound first bind it to ephemeral // if the socket is unbound first bind it to ephemeral
sockaddr_storage localAddress; SocketAddressStorage local(AddressModule());
AddressModule()->set_to_empty_address((sockaddr *)&localAddress); local.SetToEmpty();
status_t status = _BindToEphemeral(endpoint, status_t status = _BindToEphemeral(endpoint, *local);
(sockaddr *)&localAddress);
if (status < B_OK) if (status < B_OK)
return status; return status;
} }
sockaddr_storage passive; SocketAddressStorage passive(AddressModule());
AddressModule()->set_to_empty_address((sockaddr *)&passive); passive.SetToEmpty();
if (_LookupConnection(endpoint->LocalAddress(), (sockaddr *)&passive)) if (_LookupConnection(*endpoint->LocalAddress(), *passive))
return EADDRINUSE; return EADDRINUSE;
AddressModule()->set_to(endpoint->PeerAddress(), (sockaddr *)&passive); endpoint->PeerAddress().SetTo(*passive);
if (!fConnectionHash.Insert(endpoint)) if (!fConnectionHash.Insert(endpoint))
return B_NO_MEMORY; return B_NO_MEMORY;
@ -192,20 +190,20 @@ EndpointManager::FindConnection(sockaddr *local, sockaddr *peer)
// no explicit endpoint exists, check for wildcard endpoints // no explicit endpoint exists, check for wildcard endpoints
sockaddr wildcard; SocketAddressStorage wildcard(AddressModule());
AddressModule()->set_to_empty_address(&wildcard); wildcard.SetToEmpty();
endpoint = _LookupConnection(local, &wildcard); endpoint = _LookupConnection(local, *wildcard);
if (endpoint != NULL) { if (endpoint != NULL) {
TRACE(("TCP: Received packet corresponds to wildcard endpoint %p\n", endpoint)); TRACE(("TCP: Received packet corresponds to wildcard endpoint %p\n", endpoint));
return endpoint; return endpoint;
} }
sockaddr localWildcard; SocketAddressStorage localWildcard(AddressModule());
AddressModule()->set_to_empty_address(&localWildcard); localWildcard.SetToEmpty();
AddressModule()->set_port(&localWildcard, AddressModule()->get_port(local)); localWildcard.SetPort(AddressModule()->get_port(local));
endpoint = _LookupConnection(&localWildcard, &wildcard); endpoint = _LookupConnection(*localWildcard, *wildcard);
if (endpoint != NULL) { if (endpoint != NULL) {
TRACE(("TCP: Received packet corresponds to local wildcard endpoint %p\n", endpoint)); TRACE(("TCP: Received packet corresponds to local wildcard endpoint %p\n", endpoint));
return endpoint; return endpoint;
@ -276,15 +274,15 @@ EndpointManager::_BindToEphemeral(TCPEndpoint *endpoint,
TCPEndpoint *other = fEndpointHash.Lookup(port); TCPEndpoint *other = fEndpointHash.Lookup(port);
if (other == NULL) { if (other == NULL) {
sockaddr_storage newAddress; SocketAddressStorage newAddress(AddressModule());
AddressModule()->set_to((sockaddr *)&newAddress, address); newAddress.SetTo(address);
AddressModule()->set_port((sockaddr *)&newAddress, port); newAddress.SetPort(port);
// found a port // found a port
TRACE((" EndpointManager::BindToEphemeral(%p) -> %s\n", endpoint, TRACE((" EndpointManager::BindToEphemeral(%p) -> %s\n", endpoint,
AddressString(Domain(), (sockaddr *)&newAddress, true).Data())); AddressString(Domain(), *newAddress, true).Data()));
return _Bind(endpoint, (sockaddr *)&newAddress); return _Bind(endpoint, *newAddress);
} }
counter += step; counter += step;
@ -317,9 +315,9 @@ EndpointManager::_Bind(TCPEndpoint *endpoint, const sockaddr *address)
if (first != NULL) { if (first != NULL) {
while (true) { while (true) {
// check if this endpoint binds to a wildcard address // check if this endpoint binds to a wildcard address
if (AddressModule()->is_empty_address(first->LocalAddress(), false)) { if (first->LocalAddress().IsEmpty(false)) {
// you cannot specialize a wildcard endpoint - you have to open the // you cannot specialize a wildcard endpoint - you have to open
// wildcard endpoint last // the wildcard endpoint last
return B_PERMISSION_DENIED; return B_PERMISSION_DENIED;
} }
@ -361,8 +359,8 @@ EndpointManager::Unbind(TCPEndpoint *endpoint)
BenaphoreLocker _(fLock); BenaphoreLocker _(fLock);
TCPEndpoint *other = fEndpointHash.Lookup( TCPEndpoint *other =
AddressModule()->get_port(endpoint->LocalAddress())); fEndpointHash.Lookup(endpoint->LocalAddress().GetPort());
if (other != endpoint) { if (other != endpoint) {
// remove endpoint from the list of endpoints with the same port // remove endpoint from the list of endpoints with the same port
while (other != NULL && other->fEndpointNextWithSamePort != endpoint) while (other != NULL && other->fEndpointNextWithSamePort != endpoint)
@ -384,7 +382,7 @@ EndpointManager::Unbind(TCPEndpoint *endpoint)
endpoint->fEndpointNextWithSamePort = NULL; endpoint->fEndpointNextWithSamePort = NULL;
fConnectionHash.Remove(endpoint); fConnectionHash.Remove(endpoint);
endpoint->LocalAddress()->sa_len = 0; (*endpoint->LocalAddress())->sa_len = 0;
return B_OK; return B_OK;
} }

View File

@ -10,14 +10,12 @@
#include "tcp.h" #include "tcp.h"
#include <net_datalink.h> #include <AddressUtilities.h>
#include <lock.h> #include <lock.h>
#include <util/DoublyLinkedList.h> #include <util/DoublyLinkedList.h>
#include <util/OpenHashTable.h> #include <util/OpenHashTable.h>
#include <sys/socket.h>
#include <utility> #include <utility>

View File

@ -707,7 +707,7 @@ TCPEndpoint::SetReceiveBufferSize(size_t length)
bool bool
TCPEndpoint::IsBound() const TCPEndpoint::IsBound() const
{ {
return !AddressModule()->is_empty_address(LocalAddress(), true); return !LocalAddress().IsEmpty(true);
} }
@ -791,13 +791,12 @@ TCPEndpoint::_ListenReceive(tcp_segment_header &segment, net_buffer *buffer)
if (gSocketModule->spawn_pending_socket(socket, &newSocket) < B_OK) if (gSocketModule->spawn_pending_socket(socket, &newSocket) < B_OK)
return DROP; return DROP;
AddressModule()->set_to((sockaddr *)&newSocket->address, TCPEndpoint *newEndpoint = (TCPEndpoint *)newSocket->first_protocol;
(sockaddr *)&buffer->destination);
AddressModule()->set_to((sockaddr *)&newSocket->peer,
(sockaddr *)&buffer->source);
return ((TCPEndpoint *)newSocket->first_protocol)->Spawn(this, newEndpoint->LocalAddress().SetTo(&buffer->destination);
segment, buffer); newEndpoint->PeerAddress().SetTo(&buffer->source);
return newEndpoint->Spawn(this, segment, buffer);
} }
@ -814,7 +813,7 @@ TCPEndpoint::Spawn(TCPEndpoint *parent, tcp_segment_header &segment,
TRACE("Spawn()"); TRACE("Spawn()");
// TODO: proper error handling! // TODO: proper error handling!
if (_PrepareSendPath(PeerAddress()) < B_OK) if (_PrepareSendPath(*PeerAddress()) < B_OK)
return DROP; return DROP;
fOptions = parent->fOptions; fOptions = parent->fOptions;
@ -1158,8 +1157,8 @@ TCPEndpoint::_SendQueued(bool force, uint32 sendWindow)
return status; return status;
} }
AddressModule()->set_to((sockaddr *)&buffer->source, LocalAddress()); LocalAddress().CopyTo(&buffer->source);
AddressModule()->set_to((sockaddr *)&buffer->destination, PeerAddress()); PeerAddress().CopyTo(&buffer->destination);
uint32 size = buffer->size; uint32 size = buffer->size;
segment.sequence = fSendNext; segment.sequence = fSendNext;
@ -1616,7 +1615,7 @@ TCPEndpoint::_PrepareSendPath(const sockaddr *peer)
} }
// make sure connection does not already exist // make sure connection does not already exist
status_t status = fManager->SetConnection(this, LocalAddress(), peer, status_t status = fManager->SetConnection(this, *LocalAddress(), peer,
fRoute->interface->address); fRoute->interface->address);
if (status < B_OK) if (status < B_OK)
return status; return status;

View File

@ -13,6 +13,7 @@
#include "tcp.h" #include "tcp.h"
#include "BufferQueue.h" #include "BufferQueue.h"
#include <AddressUtilities.h>
#include <net_protocol.h> #include <net_protocol.h>
#include <net_stack.h> #include <net_stack.h>
#include <util/AutoLock.h> #include <util/AutoLock.h>
@ -65,15 +66,15 @@ class TCPEndpoint : public net_protocol {
tcp_state State() const { return fState; } tcp_state State() const { return fState; }
bool IsBound() const; bool IsBound() const;
const sockaddr *LocalAddress() const SocketAddress LocalAddress()
{ return (sockaddr *)&socket->address; } { return SocketAddress(AddressModule(), &socket->address); }
sockaddr *LocalAddress() ConstSocketAddress LocalAddress() const
{ return (sockaddr *)&socket->address; } { return ConstSocketAddress(AddressModule(), &socket->address); }
const sockaddr *PeerAddress() const SocketAddress PeerAddress()
{ return (sockaddr *)&socket->peer; } { return SocketAddress(AddressModule(), &socket->peer); }
sockaddr *PeerAddress() ConstSocketAddress PeerAddress() const
{ return (sockaddr *)&socket->peer; } { return ConstSocketAddress(AddressModule(), &socket->peer); }
void DeleteSocket(); void DeleteSocket();