ipv4: changed the multicast filter to use an hash table to keep source states.

git-svn-id: file:///srv/svn/repos/haiku/haiku/trunk@20920 a95241bf-73f2-0310-859d-f6bbb57e9c96
This commit is contained in:
Hugo Santos 2007-04-30 12:31:31 +00:00
parent 4231541414
commit 0e30c21c70
4 changed files with 113 additions and 103 deletions

View File

@ -30,10 +30,10 @@
//
// HashTableDefinition(void *parent) {}
//
// size_t HashKey(int key) { return key >> 1; }
// size_t Hash(Foo *value) { return HashKey(value->bar); }
// bool Compare(int key, Foo *value) { return value->bar == key; }
// HashTableLink<Foo> *GetLink(Foo *value) { return value; }
// size_t HashKey(int key) const { return key >> 1; }
// size_t Hash(Foo *value) const { return HashKey(value->bar); }
// bool Compare(int key, Foo *value) const { return value->bar == key; }
// HashTableLink<Foo> *GetLink(Foo *value) const { return value; }
// };
template<typename Type>
@ -49,7 +49,7 @@ public:
typedef typename Definition::KeyType KeyType;
typedef typename Definition::ValueType ValueType;
static const size_t kMinimumSize = 32;
static const size_t kMinimumSize = 8;
// we use new [] / delete [] for allocation. If in the future this
// is revealed to be insufficient we can switch to a template based
@ -59,22 +59,18 @@ public:
// 50 / 256 = 19.53125%
OpenHashTable(size_t initialSize = kMinimumSize)
: fItemCount(0), fTable(NULL)
: fTableSize(0), fItemCount(0), fTable(NULL)
{
if (initialSize < kMinimumSize)
initialSize = kMinimumSize;
_Resize(initialSize);
if (initialSize > 0)
_Resize(initialSize);
}
OpenHashTable(typename Definition::ParentType *parent,
size_t initialSize = kMinimumSize)
: fDefinition(parent), fItemCount(0), fTable(NULL)
: fDefinition(parent), fTableSize(0), fItemCount(0), fTable(NULL)
{
if (initialSize < kMinimumSize)
initialSize = kMinimumSize;
_Resize(initialSize);
if (initialSize > 0)
_Resize(initialSize);
}
~OpenHashTable()
@ -82,10 +78,16 @@ public:
delete [] fTable;
}
status_t InitCheck() const { return fTable ? B_OK : B_NO_MEMORY; }
status_t InitCheck() const
{
return (fTableSize == 0 || fTable) ? B_OK : B_NO_MEMORY;
}
ValueType *Lookup(const KeyType &key) const
{
if (fTableSize == 0)
return NULL;
size_t index = fDefinition.HashKey(key) & (fTableSize - 1);
ValueType *slot = fTable[index];
@ -98,12 +100,16 @@ public:
return slot;
}
void Insert(ValueType *value)
status_t Insert(ValueType *value)
{
if (AutoExpand && fItemCount >= (fTableSize * 200 / 256))
if (fTableSize == 0) {
if (!_Resize(kMinimumSize))
return B_NO_MEMORY;
} else if (AutoExpand && fItemCount >= (fTableSize * 200 / 256))
_Resize(fTableSize * 2);
InsertUnchecked(value);
return B_OK;
}
void InsertUnchecked(ValueType *value)

View File

@ -116,17 +116,16 @@ class FragmentPacket {
net_timer fTimer;
};
class MulticastGroup {
public:
typedef MulticastGroupLink<IPv4Multicast> Link;
MulticastGroup(const in_addr &address);
status_t Deliver(net_protocol_module_info *module, net_buffer *buffer,
bool raw);
void Add(Link *link);
void Remove(Link *link);
void Add(IPv4Multicast::GroupState *groupState);
void Remove(IPv4Multicast::GroupState *groupState);
bool IsEmpty() const { return fLinks.IsEmpty(); }
struct HashDefinition {
@ -148,8 +147,7 @@ private:
// for g++ 2.95
friend class HashDefinition;
typedef DoublyLinkedListCLink<Link> LinkLink;
typedef DoublyLinkedList<Link, LinkLink> Links;
typedef DoublyLinkedList<IPv4Multicast::GroupState> Links;
in_addr fMulticastAddress;
Links fLinks;
@ -157,6 +155,8 @@ private:
HashTableLink<MulticastGroup> fLink;
};
class RawSocket : public DoublyLinkedListLinkImpl<RawSocket>, public DatagramSocket<> {
public:
RawSocket(net_socket *socket);
@ -456,26 +456,19 @@ status_t
MulticastGroup::Deliver(net_protocol_module_info *module, net_buffer *buffer,
bool deliverToRaw)
{
if (module->deliver_data == NULL)
return B_OK;
Links::Iterator iterator = fLinks.GetIterator();
while (iterator.HasNext()) {
Link *link = iterator.Next();
IPv4Multicast::GroupState *groupState = iterator.Next();
// we are pretty sure of this cast since multicast filters
// are installed with the IPv4 protocol reference
ipv4_protocol *protocol = (ipv4_protocol *)link->group->Socket();
if (deliverToRaw && protocol->raw == NULL)
if (deliverToRaw && groupState->Socket()->raw == NULL)
continue;
if (link->group->FilterAccepts(buffer)) {
if (groupState->FilterAccepts(buffer)) {
// as Multicast filters are installed with an IPv4 protocol
// reference, we need to go and find the appropriate instance
// related to the 'receiving protocol' with module 'module'.
net_protocol *proto = link->group->Socket();
net_protocol *proto = groupState->Socket()->socket->first_protocol;
while (proto && proto->module != module)
proto = proto->next;
@ -490,16 +483,16 @@ MulticastGroup::Deliver(net_protocol_module_info *module, net_buffer *buffer,
void
MulticastGroup::Add(Link *link)
MulticastGroup::Add(IPv4Multicast::GroupState *groupState)
{
fLinks.Add(link);
fLinks.Add(groupState);
}
void
MulticastGroup::Remove(Link *link)
MulticastGroup::Remove(IPv4Multicast::GroupState *groupState)
{
fLinks.Remove(link);
fLinks.Remove(groupState);
}
@ -700,6 +693,9 @@ static status_t
deliver_multicast(net_protocol_module_info *module, net_buffer *buffer,
bool deliverToRaw)
{
if (module->deliver_data == NULL)
return B_OK;
BenaphoreLocker _(sMulticastGroupsLock);
MulticastGroup *group = sMulticastGroups->Lookup(
@ -742,34 +738,34 @@ raw_receive_data(net_buffer *buffer)
status_t
IPv4Multicast::JoinGroup(const in_addr &groupAddr, MulticastGroup::Link *link)
IPv4Multicast::JoinGroup(GroupState *groupState)
{
BenaphoreLocker _(sMulticastGroupsLock);
MulticastGroup *group = sMulticastGroups->Lookup(groupAddr);
MulticastGroup *group = sMulticastGroups->Lookup(groupState->Address());
if (group == NULL) {
group = new (std::nothrow) MulticastGroup(groupAddr);
group = new (std::nothrow) MulticastGroup(groupState->Address());
if (group == NULL)
return B_NO_MEMORY;
sMulticastGroups->Insert(group);
}
group->Add(link);
group->Add(groupState);
return B_OK;
}
status_t
IPv4Multicast::LeaveGroup(const in_addr &groupAddr, MulticastGroup::Link *link)
IPv4Multicast::LeaveGroup(GroupState *groupState)
{
BenaphoreLocker _(sMulticastGroupsLock);
MulticastGroup *group = sMulticastGroups->Lookup(groupAddr);
MulticastGroup *group = sMulticastGroups->Lookup(groupState->Address());
if (group == NULL)
return ENOENT;
group->Remove(link);
group->Remove(groupState);
if (group->IsEmpty()) {
sMulticastGroups->Remove(group);
delete group;

View File

@ -93,7 +93,7 @@ MulticastGroupInterfaceState<AddressType>::_Remove(Source *state)
template<typename Addressing>
MulticastGroupState<Addressing>::MulticastGroupState(net_protocol *socket,
MulticastGroupState<Addressing>::MulticastGroupState(ProtocolType *socket,
const AddressType &address)
: fSocket(socket), fMulticastAddress(address), fFilterMode(kInclude)
{
@ -210,7 +210,7 @@ MulticastGroupState<Addressing>::FilterAccepts(net_buffer *buffer)
if (state == NULL)
return false;
bool has = state->Contains(*Addressing::AddressFromSockAddr(
bool has = state->Contains(Addressing::AddressFromSockAddr(
(sockaddr *)&buffer->source));
return (has && fFilterMode == kInclude) || (!has && fFilterMode == kExclude);
@ -248,8 +248,8 @@ MulticastGroupState<Addressing>::_RemoveInterface(InterfaceState *state)
template<typename Addressing>
MulticastFilter<Addressing>::MulticastFilter(net_protocol *socket)
: fParent(socket)
MulticastFilter<Addressing>::MulticastFilter(ProtocolType *socket)
: fParent(socket), fStates((size_t)0)
{
}
@ -262,6 +262,7 @@ MulticastFilter<Addressing>::~MulticastFilter()
GroupState *state = iterator.Next();
state->Clear();
ReturnGroup(state);
iterator.Rewind();
}
}
@ -270,28 +271,25 @@ template<typename Addressing> typename MulticastFilter<Addressing>::GroupState *
MulticastFilter<Addressing>::GetGroup(const AddressType &groupAddress,
bool create)
{
typename States::Iterator iterator = fStates.GetIterator();
GroupState *state = fStates.Lookup(groupAddress);
if (state)
return state;
while (iterator.HasNext()) {
GroupState *state = iterator.Next();
if (state->Address() == groupAddress)
return state;
}
if (create) {
state = new (nothrow) GroupState(fParent, groupAddress);
if (state) {
if (fStates.Insert(state) >= B_OK) {
if (Addressing::JoinGroup(state) >= B_OK)
return state;
if (!create)
return NULL;
fStates.Remove(state);
}
GroupState *state = new (nothrow) GroupState(fParent, groupAddress);
if (state) {
if (Addressing::JoinGroup(groupAddress, state->ProtocolLink()) < B_OK) {
delete state;
return NULL;
}
fStates.Add(state);
}
return state;
return NULL;
}
@ -299,8 +297,7 @@ template<typename Addressing> void
MulticastFilter<Addressing>::ReturnGroup(GroupState *group)
{
if (group->IsEmpty()) {
Addressing::LeaveGroup(group->Address(), group->ProtocolLink());
Addressing::LeaveGroup(group);
fStates.Remove(group);
delete group;
}

View File

@ -10,7 +10,7 @@
#define _PRIVATE_MULTICAST_H_
#include <util/DoublyLinkedList.h>
#include <util/list.h>
#include <util/OpenHashTable.h>
#include <netinet/in.h>
@ -21,30 +21,32 @@ struct net_protocol;
// This code is template'ized as it is reusable for IPv6
template<typename Addressing> class MulticastFilter;
template<typename Addressing> struct MulticastGroupLink;
template<typename Addressing> class MulticastGroupState;
// TODO move this elsewhere...
struct IPv4Multicast {
typedef struct in_addr AddressType;
typedef struct ipv4_protocol ProtocolType;
typedef MulticastGroupState<IPv4Multicast> GroupState;
static status_t JoinGroup(const in_addr &, MulticastGroupLink<IPv4Multicast> *);
static status_t LeaveGroup(const in_addr &, MulticastGroupLink<IPv4Multicast> *);
static status_t JoinGroup(GroupState *);
static status_t LeaveGroup(GroupState *);
static in_addr *AddressFromSockAddr(sockaddr *sockaddr)
{
return &((sockaddr_in *)sockaddr)->sin_addr;
}
static const in_addr &AddressFromSockAddr(const sockaddr *sockaddr)
{ return ((const sockaddr_in *)sockaddr)->sin_addr; }
static size_t HashAddress(const in_addr &address)
{ return address.s_addr; }
};
template<typename AddressType>
struct MulticastSource {
struct MulticastSource
: DoublyLinkedListLinkImpl< MulticastSource<AddressType> > {
AddressType address;
list_link link;
};
template<typename AddressType>
class MulticastGroupInterfaceState {
class MulticastGroupInterfaceState
: public DoublyLinkedListLinkImpl< MulticastGroupInterfaceState<AddressType> > {
public:
MulticastGroupInterfaceState(net_interface *interface);
~MulticastGroupInterfaceState();
@ -57,11 +59,9 @@ public:
bool Contains(const AddressType &address)
{ return _Get(address, false) != NULL; }
list_link link;
private:
typedef MulticastSource<AddressType> Source;
typedef DoublyLinkedListCLink<Source> SourceLink;
typedef DoublyLinkedList<Source, SourceLink> SourceList;
typedef DoublyLinkedList<Source> SourceList;
Source *_Get(const AddressType &address, bool create);
void _Remove(Source *state);
@ -72,20 +72,18 @@ private:
};
template<typename Addressing>
struct MulticastGroupLink {
MulticastGroupState<Addressing> *group;
list_link link;
};
template<typename Addressing>
class MulticastGroupState {
class MulticastGroupState
: public DoublyLinkedListLinkImpl< MulticastGroupState<Addressing> > {
public:
typedef MulticastGroupState<Addressing> ThisType;
typedef HashTableLink<ThisType> HashLink;
typedef typename Addressing::AddressType AddressType;
typedef typename Addressing::ProtocolType ProtocolType;
MulticastGroupState(net_protocol *parent, const AddressType &address);
MulticastGroupState(ProtocolType *parent, const AddressType &address);
~MulticastGroupState();
net_protocol *Socket() const { return fSocket; }
ProtocolType *Socket() const { return fSocket; }
const AddressType &Address() const { return fMulticastAddress; }
bool IsEmpty() const
@ -106,13 +104,27 @@ public:
bool FilterAccepts(net_buffer *buffer);
MulticastGroupLink<Addressing> *ProtocolLink() { return &fInternalLink; }
struct HashDefinition {
typedef void ParentType;
typedef typename MulticastGroupState::AddressType KeyType;
typedef typename MulticastGroupState::ThisType ValueType;
typedef typename MulticastGroupState::HashLink HashLink;
size_t HashKey(const KeyType &key) const
{ return Addressing::HashAddress(key); }
size_t Hash(ValueType *value) const
{ return HashKey(value->Address()); }
bool Compare(const KeyType &key, ValueType *value) const
{ return key == value->Address(); }
HashLink *GetLink(ValueType *value) const { return &value->fHashLink; }
};
list_link link;
private:
// for g++ 2.95
friend class HashDefinition;
typedef MulticastGroupInterfaceState<AddressType> InterfaceState;
typedef DoublyLinkedListCLink<InterfaceState> InterfaceStateLink;
typedef DoublyLinkedList<InterfaceState, InterfaceStateLink> InterfaceList;
typedef DoublyLinkedList<InterfaceState> InterfaceList;
InterfaceState *_GetInterface(net_interface *interface, bool create);
void _RemoveInterface(InterfaceState *state);
@ -122,34 +134,33 @@ private:
kExclude
};
net_protocol *fSocket;
ProtocolType *fSocket;
AddressType fMulticastAddress;
FilterMode fFilterMode;
InterfaceList fInterfaces;
MulticastGroupLink<Addressing> fInternalLink;
HashLink fHashLink;
};
template<typename Addressing>
class MulticastFilter {
public:
typedef typename Addressing::AddressType AddressType;
typedef typename Addressing::ProtocolType ProtocolType;
typedef MulticastGroupState<Addressing> GroupState;
MulticastFilter(net_protocol *parent);
MulticastFilter(ProtocolType *parent);
~MulticastFilter();
net_protocol *Parent() const { return fParent; }
ProtocolType *Parent() const { return fParent; }
GroupState *GetGroup(const AddressType &groupAddress, bool create);
void ReturnGroup(GroupState *group);
private:
typedef DoublyLinkedListCLink<GroupState> GroupStateLink;
typedef DoublyLinkedList<GroupState, GroupStateLink> States;
typedef typename GroupState::HashDefinition GroupHashDefinition;
typedef OpenHashTable<GroupHashDefinition> States;
net_protocol *fParent;
// TODO change this into an hash table or tree
ProtocolType *fParent;
States fStates;
};