* Moved the TCPConnection class into its own file.

* Added some missing result checks, mostly for allocations.
* Fixed a wrong precendence with the ?: operator
* Some minor cleanup.
* Renamed sBufferModule to gBufferModule - the header expects it to be a global,
  so it should be named like one.


git-svn-id: file:///srv/svn/repos/haiku/haiku/trunk@19178 a95241bf-73f2-0310-859d-f6bbb57e9c96
This commit is contained in:
Axel Dörfler 2006-11-02 17:27:13 +00:00
parent 262fc9994c
commit c35b04de31
14 changed files with 1100 additions and 983 deletions

View File

@ -9,11 +9,11 @@
#include <net_buffer.h> #include <net_buffer.h>
extern net_buffer_module_info *sBufferModule; extern net_buffer_module_info *gBufferModule;
class NetBufferModuleGetter { class NetBufferModuleGetter {
public: public:
static net_buffer_module_info *Get() { return sBufferModule; } static net_buffer_module_info *Get() { return gBufferModule; }
}; };
//! A class to retrieve and remove a header from a buffer //! A class to retrieve and remove a header from a buffer

View File

@ -98,7 +98,7 @@ struct arp_protocol : net_datalink_protocol {
static void arp_timer(struct net_timer *timer, void *data); static void arp_timer(struct net_timer *timer, void *data);
struct net_buffer_module_info *sBufferModule; struct net_buffer_module_info *gBufferModule;
static net_stack_module_info *sStackModule; static net_stack_module_info *sStackModule;
static hash_table *sCache; static hash_table *sCache;
static benaphore sCacheLock; static benaphore sCacheLock;
@ -242,7 +242,7 @@ arp_update_entry(in_addr_t protocolAddress, sockaddr_dl *hardwareAddress,
} }
if (entry->request_buffer != NULL) { if (entry->request_buffer != NULL) {
sBufferModule->free(entry->request_buffer); gBufferModule->free(entry->request_buffer);
entry->request_buffer = NULL; entry->request_buffer = NULL;
} }
@ -401,7 +401,7 @@ arp_receive(void *cookie, net_buffer *buffer)
return B_ERROR; return B_ERROR;
} }
sBufferModule->free(buffer); gBufferModule->free(buffer);
return B_OK; return B_OK;
} }
@ -447,7 +447,7 @@ arp_timer(struct net_timer *timer, void *data)
if (entry->timer_state < ARP_STATE_LAST_REQUEST) { if (entry->timer_state < ARP_STATE_LAST_REQUEST) {
// we'll still need our buffer, so in order to prevent it being // we'll still need our buffer, so in order to prevent it being
// freed by a successful send, we need to clone it // freed by a successful send, we need to clone it
request = sBufferModule->clone(request, true); request = gBufferModule->clone(request, true);
if (request == NULL) { if (request == NULL) {
// cloning failed - that means we won't be able to send as // cloning failed - that means we won't be able to send as
// many requests as originally planned // many requests as originally planned
@ -460,7 +460,7 @@ arp_timer(struct net_timer *timer, void *data)
status_t status = entry->protocol->next->module->send_data( status_t status = entry->protocol->next->module->send_data(
entry->protocol->next, request); entry->protocol->next, request);
if (status < B_OK) if (status < B_OK)
sBufferModule->free(request); gBufferModule->free(request);
if (entry->timer_state == ARP_STATE_LAST_REQUEST) { if (entry->timer_state == ARP_STATE_LAST_REQUEST) {
// buffer has been freed on send // buffer has been freed on send
@ -535,7 +535,7 @@ arp_resolve(net_datalink_protocol *protocol, in_addr_t address, arp_entry **_ent
// prepare ARP request // prepare ARP request
entry->request_buffer = sBufferModule->create(256); entry->request_buffer = gBufferModule->create(256);
if (entry->request_buffer == NULL) { if (entry->request_buffer == NULL) {
// TODO: do something with the entry // TODO: do something with the entry
return B_NO_MEMORY; return B_NO_MEMORY;
@ -725,7 +725,7 @@ arp_init()
status_t status = get_module(NET_STACK_MODULE_NAME, (module_info **)&sStackModule); status_t status = get_module(NET_STACK_MODULE_NAME, (module_info **)&sStackModule);
if (status < B_OK) if (status < B_OK)
return status; return status;
status = get_module(NET_BUFFER_MODULE_NAME, (module_info **)&sBufferModule); status = get_module(NET_BUFFER_MODULE_NAME, (module_info **)&gBufferModule);
if (status < B_OK) if (status < B_OK)
goto err1; goto err1;

View File

@ -30,7 +30,7 @@ struct ethernet_frame_protocol : net_datalink_protocol {
static const uint8 kBroadcastAddress[6] = {0xff, 0xff, 0xff, 0xff, 0xff, 0xff}; static const uint8 kBroadcastAddress[6] = {0xff, 0xff, 0xff, 0xff, 0xff, 0xff};
struct net_buffer_module_info *sBufferModule; struct net_buffer_module_info *gBufferModule;
int32 int32
@ -174,7 +174,7 @@ ethernet_frame_std_ops(int32 op, ...)
{ {
switch (op) { switch (op) {
case B_MODULE_INIT: case B_MODULE_INIT:
return get_module(NET_BUFFER_MODULE_NAME, (module_info **)&sBufferModule); return get_module(NET_BUFFER_MODULE_NAME, (module_info **)&gBufferModule);
case B_MODULE_UNINIT: case B_MODULE_UNINIT:
put_module(NET_BUFFER_MODULE_NAME); put_module(NET_BUFFER_MODULE_NAME);
return B_OK; return B_OK;

View File

@ -28,7 +28,7 @@ struct loopback_frame_protocol : net_datalink_protocol {
}; };
struct net_buffer_module_info *sBufferModule; struct net_buffer_module_info *gBufferModule;
int32 int32
@ -135,7 +135,7 @@ loopback_frame_std_ops(int32 op, ...)
{ {
switch (op) { switch (op) {
case B_MODULE_INIT: case B_MODULE_INIT:
return get_module(NET_BUFFER_MODULE_NAME, (module_info **)&sBufferModule); return get_module(NET_BUFFER_MODULE_NAME, (module_info **)&gBufferModule);
case B_MODULE_UNINIT: case B_MODULE_UNINIT:
put_module(NET_BUFFER_MODULE_NAME); put_module(NET_BUFFER_MODULE_NAME);
return B_OK; return B_OK;

View File

@ -29,7 +29,7 @@ struct ethernet_device : net_device {
}; };
struct net_buffer_module_info *sBufferModule; struct net_buffer_module_info *gBufferModule;
status_t status_t
@ -41,7 +41,7 @@ ethernet_init(const char *name, net_device **_device)
|| !strcmp(name, "/dev/net/userland_server")) || !strcmp(name, "/dev/net/userland_server"))
return B_BAD_VALUE; return B_BAD_VALUE;
status_t status = get_module(NET_BUFFER_MODULE_NAME, (module_info **)&sBufferModule); status_t status = get_module(NET_BUFFER_MODULE_NAME, (module_info **)&gBufferModule);
if (status < B_OK) if (status < B_OK)
return status; return status;
@ -134,13 +134,13 @@ dprintf("try to send ethernet packet of %lu bytes (flags %ld):\n", buffer->size,
if (buffer->size > device->frame_size || buffer->size < ETHER_HEADER_LENGTH) if (buffer->size > device->frame_size || buffer->size < ETHER_HEADER_LENGTH)
return B_BAD_VALUE; return B_BAD_VALUE;
if (sBufferModule->count_iovecs(buffer) > 1) { if (gBufferModule->count_iovecs(buffer) > 1) {
dprintf("scattered I/O is not yet supported by ethernet device.\n"); dprintf("scattered I/O is not yet supported by ethernet device.\n");
return B_NOT_SUPPORTED; return B_NOT_SUPPORTED;
} }
struct iovec iovec; struct iovec iovec;
sBufferModule->get_iovecs(buffer, &iovec, 1); gBufferModule->get_iovecs(buffer, &iovec, 1);
dump_block((const char *)iovec.iov_base, buffer->size, " "); dump_block((const char *)iovec.iov_base, buffer->size, " ");
ssize_t bytesWritten = write(device->fd, iovec.iov_base, iovec.iov_len); ssize_t bytesWritten = write(device->fd, iovec.iov_base, iovec.iov_len);
@ -153,7 +153,7 @@ dprintf("sent: %ld\n", bytesWritten);
device->stats.send.packets++; device->stats.send.packets++;
device->stats.send.bytes += bytesWritten; device->stats.send.bytes += bytesWritten;
sBufferModule->free(buffer); gBufferModule->free(buffer);
return B_OK; return B_OK;
} }
@ -164,7 +164,7 @@ ethernet_receive_data(net_device *_device, net_buffer **_buffer)
ethernet_device *device = (ethernet_device *)_device; ethernet_device *device = (ethernet_device *)_device;
// TODO: better header space // TODO: better header space
net_buffer *buffer = sBufferModule->create(256); net_buffer *buffer = gBufferModule->create(256);
if (buffer == NULL) if (buffer == NULL)
return ENOBUFS; return ENOBUFS;
@ -176,7 +176,7 @@ ethernet_receive_data(net_device *_device, net_buffer **_buffer)
ssize_t bytesRead; ssize_t bytesRead;
void *data; void *data;
status_t status = sBufferModule->append_size(buffer, device->frame_size, &data); status_t status = gBufferModule->append_size(buffer, device->frame_size, &data);
if (status == B_OK && data == NULL) { if (status == B_OK && data == NULL) {
dprintf("scattered I/O is not yet supported by ethernet device.\n"); dprintf("scattered I/O is not yet supported by ethernet device.\n");
status = B_NOT_SUPPORTED; status = B_NOT_SUPPORTED;
@ -191,7 +191,7 @@ ethernet_receive_data(net_device *_device, net_buffer **_buffer)
goto err; goto err;
} }
status = sBufferModule->trim(buffer, bytesRead); status = gBufferModule->trim(buffer, bytesRead);
if (status < B_OK) { if (status < B_OK) {
device->stats.receive.dropped++; device->stats.receive.dropped++;
goto err; goto err;
@ -204,7 +204,7 @@ ethernet_receive_data(net_device *_device, net_buffer **_buffer)
return B_OK; return B_OK;
err: err:
sBufferModule->free(buffer); gBufferModule->free(buffer);
return status; return status;
} }

View File

@ -25,7 +25,7 @@ struct loopback_device : net_device {
}; };
struct net_buffer_module_info *sBufferModule; struct net_buffer_module_info *gBufferModule;
static struct net_stack_module_info *sStackModule; static struct net_stack_module_info *sStackModule;
@ -72,7 +72,7 @@ loopback_init(const char *name, net_device **_device)
if (status < B_OK) if (status < B_OK)
return status; return status;
status = get_module(NET_BUFFER_MODULE_NAME, (module_info **)&sBufferModule); status = get_module(NET_BUFFER_MODULE_NAME, (module_info **)&gBufferModule);
if (status < B_OK) if (status < B_OK)
goto err1; goto err1;

View File

@ -54,7 +54,7 @@ struct icmp_protocol : net_protocol {
static net_stack_module_info *sStackModule; static net_stack_module_info *sStackModule;
struct net_buffer_module_info *sBufferModule; struct net_buffer_module_info *gBufferModule;
net_protocol * net_protocol *
@ -214,8 +214,8 @@ icmp_receive_data(net_buffer *buffer)
dprintf(" got type %u, code %u, checksum %u\n", header.type, header.code, dprintf(" got type %u, code %u, checksum %u\n", header.type, header.code,
ntohs(header.checksum)); ntohs(header.checksum));
dprintf(" computed checksum: %ld\n", sBufferModule->checksum(buffer, 0, buffer->size, true)); dprintf(" computed checksum: %ld\n", gBufferModule->checksum(buffer, 0, buffer->size, true));
if (sBufferModule->checksum(buffer, 0, buffer->size, true) != 0) if (gBufferModule->checksum(buffer, 0, buffer->size, true) != 0)
return B_BAD_DATA; return B_BAD_DATA;
switch (header.type) { switch (header.type) {
@ -232,7 +232,7 @@ icmp_receive_data(net_buffer *buffer)
if (domain == NULL || domain->module == NULL) if (domain == NULL || domain->module == NULL)
break; break;
net_buffer *reply = sBufferModule->duplicate(buffer); net_buffer *reply = gBufferModule->duplicate(buffer);
if (reply == NULL) if (reply == NULL)
return B_NO_MEMORY; return B_NO_MEMORY;
@ -242,26 +242,26 @@ icmp_receive_data(net_buffer *buffer)
// There already is an ICMP header, and we'll reuse it // There already is an ICMP header, and we'll reuse it
icmp_header *header; icmp_header *header;
status_t status = sBufferModule->direct_access(reply, status_t status = gBufferModule->direct_access(reply,
0, sizeof(icmp_header), (void **)&header); 0, sizeof(icmp_header), (void **)&header);
if (status == B_OK) { if (status == B_OK) {
header->type = ICMP_TYPE_ECHO_REPLY; header->type = ICMP_TYPE_ECHO_REPLY;
header->code = 0; header->code = 0;
header->checksum = 0; header->checksum = 0;
header->checksum = sBufferModule->checksum(reply, 0, reply->size, true); header->checksum = gBufferModule->checksum(reply, 0, reply->size, true);
} }
if (status == B_OK) if (status == B_OK)
status = domain->module->send_data(NULL, reply); status = domain->module->send_data(NULL, reply);
if (status < B_OK) { if (status < B_OK) {
sBufferModule->free(reply); gBufferModule->free(reply);
return status; return status;
} }
} }
} }
sBufferModule->free(buffer); gBufferModule->free(buffer);
return B_OK; return B_OK;
} }
@ -293,7 +293,7 @@ icmp_std_ops(int32 op, ...)
status_t status = get_module(NET_STACK_MODULE_NAME, (module_info **)&sStackModule); status_t status = get_module(NET_STACK_MODULE_NAME, (module_info **)&sStackModule);
if (status < B_OK) if (status < B_OK)
return status; return status;
status = get_module(NET_BUFFER_MODULE_NAME, (module_info **)&sBufferModule); status = get_module(NET_BUFFER_MODULE_NAME, (module_info **)&gBufferModule);
if (status < B_OK) { if (status < B_OK) {
put_module(NET_STACK_MODULE_NAME); put_module(NET_STACK_MODULE_NAME);
return status; return status;

View File

@ -146,7 +146,7 @@ extern net_protocol_module_info gIPv4Module;
static struct net_domain *sDomain; static struct net_domain *sDomain;
static net_datalink_module_info *sDatalinkModule; static net_datalink_module_info *sDatalinkModule;
static net_stack_module_info *sStackModule; static net_stack_module_info *sStackModule;
struct net_buffer_module_info *sBufferModule; struct net_buffer_module_info *gBufferModule;
static int32 sPacketID; static int32 sPacketID;
static RawSocketList sRawSockets; static RawSocketList sRawSockets;
static benaphore sRawSocketsLock; static benaphore sRawSocketsLock;
@ -190,7 +190,7 @@ RawSocket::Read(size_t numBytes, uint32 flags, bigtime_t timeout,
if (numBytes < buffer->size) { if (numBytes < buffer->size) {
// discard any data behind the amount requested // discard any data behind the amount requested
sBufferModule->trim(buffer, numBytes); gBufferModule->trim(buffer, numBytes);
} }
*_buffer = buffer; *_buffer = buffer;
@ -209,7 +209,7 @@ status_t
RawSocket::Write(net_buffer *source) RawSocket::Write(net_buffer *source)
{ {
// we need to make a clone for that buffer and pass it to the socket // we need to make a clone for that buffer and pass it to the socket
net_buffer *buffer = sBufferModule->clone(source, false); net_buffer *buffer = gBufferModule->clone(source, false);
TRACE(("ipv4::RawSocket::Write(): cloned buffer %p\n", buffer)); TRACE(("ipv4::RawSocket::Write(): cloned buffer %p\n", buffer));
if (buffer == NULL) if (buffer == NULL)
return B_NO_MEMORY; return B_NO_MEMORY;
@ -218,7 +218,7 @@ RawSocket::Write(net_buffer *source)
if (status >= B_OK) if (status >= B_OK)
sStackModule->notify_socket(fSocket, B_SELECT_READ, BytesAvailable()); sStackModule->notify_socket(fSocket, B_SELECT_READ, BytesAvailable());
else else
sBufferModule->free(buffer); gBufferModule->free(buffer);
return status; return status;
} }
@ -246,7 +246,7 @@ FragmentPacket::~FragmentPacket()
ipv4_fragment *fragment; ipv4_fragment *fragment;
while ((fragment = fFragments.RemoveHead()) != NULL) { while ((fragment = fFragments.RemoveHead()) != NULL) {
if (fragment->buffer != NULL) if (fragment->buffer != NULL)
sBufferModule->free(fragment->buffer); gBufferModule->free(fragment->buffer);
delete fragment; delete fragment;
} }
} }
@ -283,7 +283,7 @@ FragmentPacket::AddFragment(uint16 start, uint16 end, net_buffer *buffer,
if (previous != NULL && previous->start <= start && previous->end >= end) { if (previous != NULL && previous->start <= start && previous->end >= end) {
// we do, so we can just drop this fragment // we do, so we can just drop this fragment
sBufferModule->free(buffer); gBufferModule->free(buffer);
return B_OK; return B_OK;
} }
@ -293,12 +293,12 @@ FragmentPacket::AddFragment(uint16 start, uint16 end, net_buffer *buffer,
if (previous != NULL && previous->end > start) { if (previous != NULL && previous->end > start) {
TRACE((" remove header %d bytes\n", previous->end - start)); TRACE((" remove header %d bytes\n", previous->end - start));
sBufferModule->remove_header(buffer, previous->end - start); gBufferModule->remove_header(buffer, previous->end - start);
start = previous->end; start = previous->end;
} }
if (next != NULL && next->start < end) { if (next != NULL && next->start < end) {
TRACE((" remove trailer %d bytes\n", next->start - end)); TRACE((" remove trailer %d bytes\n", next->start - end));
sBufferModule->remove_trailer(buffer, next->start - end); gBufferModule->remove_trailer(buffer, next->start - end);
end = next->start; end = next->start;
} }
@ -308,7 +308,7 @@ FragmentPacket::AddFragment(uint16 start, uint16 end, net_buffer *buffer,
// report an error (in which case we're not responsible for freeing it) // report an error (in which case we're not responsible for freeing it)
if (previous != NULL && previous->end == start) { if (previous != NULL && previous->end == start) {
status_t status = sBufferModule->merge(buffer, previous->buffer, false); status_t status = gBufferModule->merge(buffer, previous->buffer, false);
TRACE((" merge previous: %s\n", strerror(status))); TRACE((" merge previous: %s\n", strerror(status)));
if (status < B_OK) if (status < B_OK)
return status; return status;
@ -328,7 +328,7 @@ FragmentPacket::AddFragment(uint16 start, uint16 end, net_buffer *buffer,
return B_OK; return B_OK;
} else if (next != NULL && next->start == end) { } else if (next != NULL && next->start == end) {
status_t status = sBufferModule->merge(buffer, next->buffer, true); status_t status = gBufferModule->merge(buffer, next->buffer, true);
TRACE((" merge next: %s\n", strerror(status))); TRACE((" merge next: %s\n", strerror(status)));
if (status < B_OK) if (status < B_OK)
return status; return status;
@ -392,10 +392,10 @@ FragmentPacket::Reassemble(net_buffer *to)
if (buffer != NULL) { if (buffer != NULL) {
status_t status; status_t status;
if (to == fragment->buffer) { if (to == fragment->buffer) {
status = sBufferModule->merge(fragment->buffer, buffer, false); status = gBufferModule->merge(fragment->buffer, buffer, false);
buffer = fragment->buffer; buffer = fragment->buffer;
} else } else
status = sBufferModule->merge(buffer, fragment->buffer, true); status = gBufferModule->merge(buffer, fragment->buffer, true);
if (status < B_OK) if (status < B_OK)
return status; return status;
} else } else
@ -534,7 +534,7 @@ reassemble_fragments(const ipv4_header &header, net_buffer **_buffer)
// Remove header unless this is the first fragment // Remove header unless this is the first fragment
if (start != 0) if (start != 0)
sBufferModule->remove_header(buffer, header.HeaderLength()); gBufferModule->remove_header(buffer, header.HeaderLength());
status = packet->AddFragment(start, end, buffer, lastFragment); status = packet->AddFragment(start, end, buffer, lastFragment);
if (status != B_OK) if (status != B_OK)
@ -580,7 +580,7 @@ send_fragments(ipv4_protocol *protocol, struct net_route *route,
uint32 fragmentOffset = 0; uint32 fragmentOffset = 0;
status_t status = B_OK; status_t status = B_OK;
net_buffer *headerBuffer = sBufferModule->split(buffer, headerLength); net_buffer *headerBuffer = gBufferModule->split(buffer, headerLength);
if (headerBuffer == NULL) if (headerBuffer == NULL)
return B_NO_MEMORY; return B_NO_MEMORY;
@ -610,7 +610,7 @@ send_fragments(ipv4_protocol *protocol, struct net_route *route,
net_buffer *fragmentBuffer; net_buffer *fragmentBuffer;
if (!lastFragment) { if (!lastFragment) {
fragmentBuffer = sBufferModule->split(buffer, fragmentLength); fragmentBuffer = gBufferModule->split(buffer, fragmentLength);
fragmentOffset += fragmentLength; fragmentOffset += fragmentLength;
} else } else
fragmentBuffer = buffer; fragmentBuffer = buffer;
@ -621,7 +621,7 @@ send_fragments(ipv4_protocol *protocol, struct net_route *route,
} }
// copy header to fragment // copy header to fragment
status = sBufferModule->prepend(fragmentBuffer, header, headerLength); status = gBufferModule->prepend(fragmentBuffer, header, headerLength);
// send fragment // send fragment
if (status == B_OK) if (status == B_OK)
@ -633,12 +633,12 @@ send_fragments(ipv4_protocol *protocol, struct net_route *route,
} }
if (status < B_OK) { if (status < B_OK) {
sBufferModule->free(fragmentBuffer); gBufferModule->free(fragmentBuffer);
break; break;
} }
} }
sBufferModule->free(headerBuffer); gBufferModule->free(headerBuffer);
return status; return status;
} }
@ -884,7 +884,7 @@ ipv4_send_routed_data(net_protocol *_protocol, struct net_route *route,
// always use the actual used source address // always use the actual used source address
header.destination = ((sockaddr_in *)&buffer->destination)->sin_addr.s_addr; header.destination = ((sockaddr_in *)&buffer->destination)->sin_addr.s_addr;
header.checksum = sBufferModule->checksum(buffer, 0, sizeof(ipv4_header), true); header.checksum = gBufferModule->checksum(buffer, 0, sizeof(ipv4_header), true);
dump_ipv4_header(header); dump_ipv4_header(header);
bufferHeader.Detach(); bufferHeader.Detach();
@ -892,8 +892,8 @@ ipv4_send_routed_data(net_protocol *_protocol, struct net_route *route,
} }
TRACE(("header chksum: %ld, buffer checksum: %ld\n", TRACE(("header chksum: %ld, buffer checksum: %ld\n",
sBufferModule->checksum(buffer, 0, sizeof(ipv4_header), true), gBufferModule->checksum(buffer, 0, sizeof(ipv4_header), true),
sBufferModule->checksum(buffer, 0, buffer->size, true))); gBufferModule->checksum(buffer, 0, buffer->size, true)));
TRACE(("destination-IP: buffer=%p addr=%p %08lx\n", buffer, &buffer->destination, TRACE(("destination-IP: buffer=%p addr=%p %08lx\n", buffer, &buffer->destination,
ntohl(((sockaddr_in *)&buffer->destination)->sin_addr.s_addr))); ntohl(((sockaddr_in *)&buffer->destination)->sin_addr.s_addr)));
@ -1007,7 +1007,7 @@ ipv4_receive_data(net_buffer *buffer)
return B_BAD_DATA; return B_BAD_DATA;
// TODO: would be nice to have a direct checksum function somewhere // TODO: would be nice to have a direct checksum function somewhere
if (sBufferModule->checksum(buffer, 0, headerLength, true) != 0) if (gBufferModule->checksum(buffer, 0, headerLength, true) != 0)
return B_BAD_DATA; return B_BAD_DATA;
struct sockaddr_in &source = *(struct sockaddr_in *)&buffer->source; struct sockaddr_in &source = *(struct sockaddr_in *)&buffer->source;
@ -1036,7 +1036,7 @@ ipv4_receive_data(net_buffer *buffer)
uint8 protocol = buffer->protocol = header.protocol; uint8 protocol = buffer->protocol = header.protocol;
// remove any trailing/padding data // remove any trailing/padding data
status_t status = sBufferModule->trim(buffer, packetLength); status_t status = gBufferModule->trim(buffer, packetLength);
if (status < B_OK) if (status < B_OK)
return status; return status;
@ -1067,7 +1067,7 @@ ipv4_receive_data(net_buffer *buffer)
raw_receive_data(buffer); raw_receive_data(buffer);
} }
sBufferModule->remove_header(buffer, headerLength); gBufferModule->remove_header(buffer, headerLength);
// the header is of variable size and may include IP options // the header is of variable size and may include IP options
// (that we ignore for now) // (that we ignore for now)
@ -1111,7 +1111,7 @@ init_ipv4()
status_t status = get_module(NET_STACK_MODULE_NAME, (module_info **)&sStackModule); status_t status = get_module(NET_STACK_MODULE_NAME, (module_info **)&sStackModule);
if (status < B_OK) if (status < B_OK)
return status; return status;
status = get_module(NET_BUFFER_MODULE_NAME, (module_info **)&sBufferModule); status = get_module(NET_BUFFER_MODULE_NAME, (module_info **)&gBufferModule);
if (status < B_OK) if (status < B_OK)
goto err1; goto err1;
status = get_module(NET_DATALINK_MODULE_NAME, (module_info **)&sDatalinkModule); status = get_module(NET_DATALINK_MODULE_NAME, (module_info **)&sDatalinkModule);

View File

@ -14,6 +14,7 @@ UsePrivateHeaders kernel net ;
KernelAddon tcp : KernelAddon tcp :
tcp.cpp tcp.cpp
TCPConnection.cpp
; ;
# Installation # Installation

View File

@ -0,0 +1,811 @@
/*
* Copyright 2006, Haiku, Inc. All Rights Reserved.
* Distributed under the terms of the MIT License.
*
* Authors:
* Andrew Galante, haiku.galante@gmail.com
* Axel Dörfler, axeld@pinc-software.de
*/
#include "TCPConnection.h"
#include <net_buffer.h>
#include <net_datalink.h>
#include <KernelExport.h>
#include <util/list.h>
#include <netinet/in.h>
#include <netinet/ip.h>
#include <new>
#include <stdlib.h>
#include <string.h>
#include <lock.h>
#include <util/AutoLock.h>
#include <util/khash.h>
#include <NetBufferUtilities.h>
#include <NetUtilities.h>
#define TRACE_TCP
#ifdef TRACE_TCP
# define TRACE(x) dprintf x
#else
# define TRACE(x)
#endif
// Initial estimate for packet round trip time (RTT)
#define TCP_INITIAL_RTT 120000000LL
// Estimate for Maximum segment lifetime in the internet
#define TCP_MAX_SEGMENT_LIFETIME (2 * TCP_INITIAL_RTT)
// keep maximum buffer sizes < max net_buffer size for now
#define TCP_MAX_SEND_BUF 1024
#define TCP_MAX_RECV_BUF TCP_MAX_SEND_BUF
#define TCP_IS_GOOD_ACK(x) (fLastByteAckd > fNextByteToSend ? \
((x) >= fLastByteAckd || (x) <= fNextByteToSend) : \
((x) >= fLastByteAckd && (x) <= fNextByteToSend))
#define TCP_IS_GOOD_SEQ(x,y) (fNextByteToRead < fNextByteToRead + TCP_MAX_RECV_BUF ? \
(x) >= fNextByteToRead && (x) + (y) <= fNextByteToRead + TCP_MAX_RECV_BUF : \
((x) >= fNextByteToRead || (x) <= fNextByteToRead + TCP_MAX_RECV_BUF) && \
((x) + (y) >= fNextByteToRead || (x) + (y) <= fNextByteToRead + TCP_MAX_RECV_BUF))
struct tcp_segment {
struct list_link link;
net_buffer *buffer;
bigtime_t time;
uint32 sequence;
net_timer timer;
bool timed_out;
tcp_segment(net_buffer *buffer, uint32 sequenceNumber, bigtime_t timeout);
~tcp_segment();
};
tcp_segment::tcp_segment(net_buffer *_buffer, uint32 sequenceNumber, bigtime_t timeout)
:
buffer(_buffer),
time(system_time()),
sequence(sequenceNumber),
timed_out(false)
{
if (timeout > 0) {
gStackModule->init_timer(&timer, &TCPConnection::ResendSegment, this);
gStackModule->set_timer(&timer, timeout);
}
}
tcp_segment::~tcp_segment()
{
gStackModule->set_timer(&timer, -1);
}
// #pragma mark -
TCPConnection::TCPConnection(net_socket *socket)
:
fLastByteAckd(0), //system_time()),
fNextByteToSend(fLastByteAckd),
fNextByteToWrite(fLastByteAckd + 1),
fNextByteToRead(0),
fNextByteExpected(0),
fLastByteReceived(0),
fAvgRTT(TCP_INITIAL_RTT),
fSendBuffer(NULL),
fReceiveBuffer(NULL),
fState(CLOSED),
fError(B_OK),
fRoute(NULL)
{
benaphore_init(&fLock, "TCPConnection");
gStackModule->init_timer(&fTimer, _TimeWait, this);
list_init(&fReorderQueue);
list_init(&fWaitQueue);
}
TCPConnection::~TCPConnection()
{
benaphore_destroy(&fLock);
}
status_t
TCPConnection::Open()
{
TRACE(("%p.Open()\n", this));
if (gAddressModule == NULL)
return B_ERROR;
TRACE(("TCP: Open(): Using Address Module %p\n", gAddressModule));
BenaphoreLocker lock(&fLock);
gAddressModule->set_to_empty_address((sockaddr *)&socket->address);
gAddressModule->set_port((sockaddr *)&socket->address, 0);
gAddressModule->set_to_empty_address((sockaddr *)&socket->peer);
gAddressModule->set_port((sockaddr *)&socket->peer, 0);
return B_OK;
}
status_t
TCPConnection::Close()
{
BenaphoreLocker lock(&fLock);
TRACE(("TCP:%p.Close()\n", this));
if (fState == SYN_SENT || fState == LISTEN) {
fState = CLOSED;
return B_OK;
}
tcp_state nextState = CLOSED;
if (fState == SYN_RCVD || fState == ESTABLISHED)
nextState = FIN_WAIT1;
if (fState == CLOSE_WAIT)
nextState = LAST_ACK;
status_t status = _SendQueuedData(TCP_FLG_FIN | TCP_FLG_ACK, false);
if (status != B_OK)
return status;
fState = nextState;
TRACE(("TCP: %p.Close(): Entering state %d\n", this, fState));
//do i need to wait until fState returns to CLOSED?
return B_OK;
}
status_t
TCPConnection::Free()
{
TRACE(("TCP:%p.Free()\n", this));
BenaphoreLocker hashLock(&gConnectionLock);
BenaphoreLocker lock(&fLock);
tcp_connection_key key;
key.local = (sockaddr *)&socket->address;
key.peer = (sockaddr *)&socket->peer;
if (hash_lookup(gConnectionHash, &key) != NULL) {
return hash_remove(gConnectionHash, (void *)this);
}
return B_OK;
}
/*!
Creates and sends a SYN packet to /a address
*/
status_t
TCPConnection::Connect(const struct sockaddr *address)
{
TRACE(("TCP:%p.Connect() on address %s\n", this,
AddressString(gDomain, address, true).Data()));
if (address->sa_family != AF_INET)
return EAFNOSUPPORT;
benaphore_lock(&gConnectionLock); // want to release lock later, so no autolock
BenaphoreLocker lock(&fLock);
// Can only call Connect from CLOSED or LISTEN states
// otherwise connection is considered already connected
if (fState != CLOSED && fState != LISTEN) {
benaphore_unlock(&gConnectionLock);
return EISCONN;
}
TRACE(("TCP: Connect(): in state %d\n", fState));
// get a net_route if there isn't one
if (fRoute == NULL) {
fRoute = gDatalinkModule->get_route(gDomain, (sockaddr *)address);
TRACE(("TCP: Connect(): Using Route %p\n", fRoute));
if (fRoute == NULL) {
benaphore_unlock(&gConnectionLock);
return ENETUNREACH;
}
}
// need to associate this connection with a real address, not INADDR_ANY
if (gAddressModule->is_empty_address((sockaddr *)&socket->address, false)) {
TRACE(("TCP: Connect(): Local Address is INADDR_ANY\n"));
gAddressModule->set_to((sockaddr *)&socket->address, (sockaddr *)fRoute->interface->address);
// since most stacks terminate connections from port 0
// use port 40000 for now. This should be moved to Bind(), and Bind() called before Connect().
gAddressModule->set_port((sockaddr *)&socket->address, htons(40000));
}
// make sure connection does not already exist
tcp_connection_key key;
key.local = (sockaddr *)&socket->address;
key.peer = address;
if (hash_lookup(gConnectionHash, &key) != NULL) {
benaphore_unlock(&gConnectionLock);
return EADDRINUSE;
}
TRACE(("TCP: Connect(): connecting...\n"));
status_t status;
gAddressModule->set_to((sockaddr *)&socket->peer, address);
status = hash_insert(gConnectionHash, (void *)this);
if (status != B_OK) {
TRACE(("TCP: Connect(): Error inserting connection into hash!\n"));
benaphore_unlock(&gConnectionLock);
return status;
}
// done manipulating the hash, release the lock
benaphore_unlock(&gConnectionLock);
TRACE(("TCP: Connect(): starting 3-way handshake...\n"));
// send SYN
status = _SendQueuedData(TCP_FLG_SYN, false);
if (status != B_OK)
return status;
fState = SYN_SENT;
// TODO: Should Connect() not return until 3-way handshake is complete?
TRACE(("TCP: Connect(): Connection complete\n"));
return B_OK;
}
status_t
TCPConnection::Accept(struct net_socket **_acceptedSocket)
{
TRACE(("TCP:%p.Accept()\n", this));
return B_ERROR;
}
status_t
TCPConnection::Bind(sockaddr *address)
{
TRACE(("TCP:%p.Bind() on address %s\n", this,
AddressString(gDomain, address, true).Data()));
if (address->sa_family != AF_INET)
return EAFNOSUPPORT;
BenaphoreLocker hashLock(&gConnectionLock);
BenaphoreLocker lock(&fLock);
// let IP check whether there is an interface that supports the given address:
status_t status = next->module->bind(next, address);
if (status < B_OK)
return status;
gAddressModule->set_to((sockaddr *)&socket->address, address);
// for now, leave port=0. TCP should still work 1 connection at a time
if (0) { //gAddressModule->get_port((sockaddr *)&socket->address) == 0) {
// assign ephemeral port
} else {
// TODO: Check for Socket flags
tcp_connection_key key;
key.peer = (sockaddr *)&socket->peer;
key.local = (sockaddr *)&socket->address;
if (hash_lookup(gConnectionHash, &key) == NULL) {
hash_insert(gConnectionHash, (void *)this);
} else
return EADDRINUSE;
}
return B_OK;
}
status_t
TCPConnection::Unbind(struct sockaddr *address)
{
TRACE(("TCP:%p.Unbind()\n", this ));
BenaphoreLocker hashLock(&gConnectionLock);
BenaphoreLocker lock(&fLock);
status_t status = hash_remove(gConnectionHash, (void *)this);
if (status != B_OK)
return status;
gAddressModule->set_to_empty_address((sockaddr *)&socket->address);
gAddressModule->set_port((sockaddr *)&socket->address, 0);
return B_OK;
}
status_t
TCPConnection::Listen(int count)
{
TRACE(("TCP:%p.Listen()\n", this));
BenaphoreLocker lock(&fLock);
if (fState != CLOSED)
return B_ERROR;
fState = LISTEN;
return B_OK;
}
status_t
TCPConnection::Shutdown(int direction)
{
TRACE(("TCP:%p.Shutdown()\n", this));
return B_ERROR;
}
/*!
Puts data contained in \a buffer into send buffer
*/
status_t
TCPConnection::SendData(net_buffer *buffer)
{
TRACE(("TCP:%p.SendData()\n", this));
size_t bufferSize = buffer->size;
BenaphoreLocker lock(&fLock);
if (fSendBuffer != NULL) {
status_t status = gBufferModule->merge(fSendBuffer, buffer, true);
if (status != B_OK)
return status;
} else
fSendBuffer = buffer;
fNextByteToWrite += bufferSize;
return _SendQueuedData(TCP_FLG_ACK, false);
}
status_t
TCPConnection::SendRoutedData(net_route *route, net_buffer *buffer)
{
TRACE(("TCP:%p.SendRoutedData()\n", this));
{
BenaphoreLocker lock(&fLock);
fRoute = route;
}
return SendData(buffer);
}
size_t
TCPConnection::SendAvailable()
{
TRACE(("TCP:%p.SendAvailable()\n", this));
BenaphoreLocker lock(&fLock);
if (fSendBuffer != NULL)
return TCP_MAX_SEND_BUF - fSendBuffer->size;
return TCP_MAX_SEND_BUF;
}
status_t
TCPConnection::ReadData(size_t numBytes, uint32 flags, net_buffer** _buffer)
{
TRACE(("TCP:%p.ReadData()\n", this));
BenaphoreLocker lock(&fLock);
// must be in a synchronous state
if (fState != ESTABLISHED || fState != FIN_WAIT1 || fState != FIN_WAIT2) {
// is this correct semantics?
dprintf(" TCP state = %d\n", fState);
return B_ERROR;
}
dprintf(" TCP error = %ld\n", fError);
if (fError != B_OK)
return fError;
if (fReceiveBuffer->size < numBytes)
numBytes = fReceiveBuffer->size;
*_buffer = gBufferModule->split(fReceiveBuffer, numBytes);
if (*_buffer == NULL)
return B_NO_MEMORY;
return B_OK;
}
size_t
TCPConnection::ReadAvailable()
{
TRACE(("TCP:%p.ReadAvailable()\n", this));
BenaphoreLocker lock(&fLock);
if (fReceiveBuffer != NULL)
return fReceiveBuffer->size;
return 0;
}
/*!
You must hold the connection's lock when calling this method
*/
status_t
TCPConnection::_EnqueueReceivedData(net_buffer *buffer, uint32 sequence)
{
TRACE(("TCP:%p.EnqueueReceivedData(%p, %lu)\n", this, buffer, sequence));
status_t status;
if (sequence == fNextByteExpected) {
// first check if the received buffer meets up with the first
// segment in the ReorderQueue
tcp_segment *next;
while ((next = (tcp_segment *)list_get_first_item(&fReorderQueue)) != NULL) {
if (sequence + buffer->size >= next->sequence) {
if (sequence + buffer->size > next->sequence) {
status = gBufferModule->trim(buffer, sequence - next->sequence);
if (status != B_OK)
return status;
}
status = gBufferModule->merge(buffer, next->buffer, true);
if (status != B_OK)
return status;
list_remove_item(&fReorderQueue, next);
delete next;
} else
break;
}
fNextByteExpected += buffer->size;
if (fReceiveBuffer == NULL)
fReceiveBuffer = buffer;
else {
status = gBufferModule->merge(fReceiveBuffer, buffer, true);
if (status < B_OK) {
fNextByteExpected -= buffer->size;
return status;
}
}
} else {
// add this buffer into the ReorderQueue in the correct place
// creating a new tcp_segment if necessary
tcp_segment *next = NULL;
do {
next = (tcp_segment *)list_get_next_item(&fReorderQueue, next);
if (next != NULL && next->sequence < sequence)
continue;
if (next != NULL && sequence + buffer->size >= next->sequence) {
// merge the new buffer with the next buffer
if (sequence + buffer->size > next->sequence) {
status = gBufferModule->trim(buffer, sequence - next->sequence);
if (status != B_OK)
return status;
}
status = gBufferModule->merge(buffer, next->buffer, true);
if (status != B_OK)
return status;
next->buffer = buffer;
next->sequence = sequence;
break;
}
tcp_segment *segment = new(std::nothrow) tcp_segment(buffer, sequence, -1);
if (segment == NULL)
return B_NO_MEMORY;
if (next == NULL)
list_add_item(&fReorderQueue, segment);
else
list_insert_item_before(&fReorderQueue, next, segment);
} while (next != NULL);
}
return B_OK;
}
status_t
TCPConnection::ReceiveData(net_buffer *buffer)
{
BenaphoreLocker lock(&fLock);
TRACE(("TCP:%p.ReceiveData()\n", this));
NetBufferHeader<tcp_header> bufferHeader(buffer);
if (bufferHeader.Status() < B_OK)
return bufferHeader.Status();
tcp_header &header = bufferHeader.Data();
uint16 flags = 0x0;
tcp_state nextState = fState;
status_t status = B_OK;
uint32 byteAckd = ntohl(header.acknowledge_num);
uint32 byteRcvd = ntohl(header.sequence_num);
uint32 headerLength = (uint32)header.header_length << 2;
uint32 payloadLength = buffer->size - headerLength;
TRACE(("TCP: ReceiveData(): Connection in state %d received packet %p with flags %X!\n", fState, buffer, header.flags));
switch (fState) {
case CLOSED:
case TIME_WAIT:
gBufferModule->free(buffer);
if (header.flags & TCP_FLG_ACK)
return _Reset(byteAckd, 0);
return _Reset(0, byteRcvd + payloadLength);
case LISTEN:
// if packet is SYN, spawn new TCPConnection in SYN_RCVD state
// and add it to the Connection Queue. The new TCPConnection
// must continue the handshake by replying with SYN+ACK. Any
// data in the packet must go into the new TCPConnection's receive
// buffer.
// Otherwise, RST+ACK is sent.
// The current TCPConnection always remains in LISTEN state.
return B_ERROR;
case SYN_SENT:
if (header.flags & TCP_FLG_RST) {
fError = ECONNREFUSED;
fState = CLOSED;
return B_ERROR;
}
if (header.flags & TCP_FLG_ACK && !TCP_IS_GOOD_ACK(byteAckd))
return _Reset(byteAckd, 0);
if (header.flags & TCP_FLG_SYN) {
fNextByteToRead = fNextByteExpected = ntohl(header.sequence_num) + 1;
flags |= TCP_FLG_ACK;
fLastByteAckd = byteAckd;
// cancel resend of this segment
if (header.flags & TCP_FLG_ACK)
nextState = ESTABLISHED;
else {
nextState = SYN_RCVD;
flags |= TCP_FLG_SYN;
}
}
break;
case SYN_RCVD:
if (header.flags & TCP_FLG_ACK && TCP_IS_GOOD_ACK(byteAckd))
fState = ESTABLISHED;
else
_Reset(byteAckd, 0);
break;
default:
// In a synchronized state.
// first check that the received sequence number is good
if (TCP_IS_GOOD_SEQ(byteRcvd, payloadLength)) {
// If a valid RST was received, terminate the connection.
if (header.flags & TCP_FLG_RST) {
fError = ECONNREFUSED;
fState = CLOSED;
return B_ERROR;
}
if (header.flags & TCP_FLG_ACK && TCP_IS_GOOD_ACK(byteAckd) ) {
fLastByteAckd = byteAckd;
if (fLastByteAckd == fNextByteToWrite) {
if (fState == LAST_ACK ) {
nextState = CLOSED;
status = hash_remove(gConnectionHash, this);
if (status != B_OK)
return status;
}
if (fState == CLOSING) {
nextState = TIME_WAIT;
status = hash_remove(gConnectionHash, this);
if (status != B_OK)
return status;
}
if (fState == FIN_WAIT1) {
nextState = FIN_WAIT2;
}
}
}
if (header.flags & TCP_FLG_FIN) {
// other side is closing connection. change states
switch (fState) {
case ESTABLISHED:
nextState = CLOSE_WAIT;
fNextByteExpected++;
break;
case FIN_WAIT2:
nextState = TIME_WAIT;
fNextByteExpected++;
break;
case FIN_WAIT1:
if (fLastByteAckd == fNextByteToWrite) {
// our FIN has been ACKd: go to TIME_WAIT
nextState = TIME_WAIT;
status = hash_remove(gConnectionHash, this);
if (status != B_OK)
return status;
gStackModule->set_timer(&fTimer, TCP_MAX_SEGMENT_LIFETIME);
} else
nextState = CLOSING;
fNextByteExpected++;
break;
default:
break;
}
}
flags |= TCP_FLG_ACK;
} else {
// out-of-order packet received. remind the other side of where we are
return _SendQueuedData(TCP_FLG_ACK, true);
}
break;
}
TRACE(("TCP %p.ReceiveData():Entering state %d\n", this, fState));
// state machine is done switching states and the data is good.
// put it in the receive buffer
// TODO: This isn't the most efficient way to do it, and will need to be changed
// to deal with Silly Window Syndrome
bufferHeader.Remove(headerLength);
if (buffer->size > 0) {
status = _EnqueueReceivedData(buffer, byteRcvd);
if (status != B_OK)
return status;
} else
gBufferModule->free(buffer);
if (fState != CLOSING && fState != LAST_ACK) {
status = _SendQueuedData(flags, false);
if (status != B_OK)
return status;
}
fState = nextState;
return B_OK;
}
status_t
TCPConnection::_Reset(uint32 sequence, uint32 acknowledge)
{
TRACE(("TCP:%p.Reset()\n", this));
net_buffer *reply = gBufferModule->create(512);
if (reply == NULL)
return B_NO_MEMORY;
gAddressModule->set_to((sockaddr *)&reply->source, (sockaddr *)&socket->address);
gAddressModule->set_to((sockaddr *)&reply->destination, (sockaddr *)&socket->peer);
status_t status = add_tcp_header(reply,
TCP_FLG_RST | (acknowledge == 0 ? 0 : TCP_FLG_ACK), sequence, acknowledge, 0);
if (status != B_OK) {
gBufferModule->free(reply);
return status;
}
TRACE(("TCP: Reset():Sending RST...\n"));
status = next->module->send_routed_data(next, fRoute, reply);
if (status != B_OK) {
// if sending failed, we stay responsible for the buffer
gBufferModule->free(reply);
}
return status;
}
/*!
Resends a sent segment (\a data) if the segment's ACK wasn't received
before the timeout (eg \a timer expired)
*/
void
TCPConnection::ResendSegment(struct net_timer *timer, void *data)
{
TRACE(("TCP:ResendSegment(%p)\n", data));
if (data == NULL)
return;
// TODO: implement me!
}
/*!
Sends a TCP packet with the specified \a flags. If there is any data in
the send buffer and \a empty is false, fEffectiveWindow bytes or less of it are sent as well.
Sequence and Acknowledgement numbers are filled in accordingly.
The fLock benaphore must be held before calling.
*/
status_t
TCPConnection::_SendQueuedData(uint16 flags, bool empty)
{
TRACE(("TCP:%p.SendQueuedData(%X,%s)\n", this, flags, empty ? "1" : "0"));
if (fRoute == NULL)
return B_ERROR;
net_buffer *buffer;
uint32 effectiveWindow = min_c(next->module->get_mtu(next,
(sockaddr *)&socket->address), fNextByteToWrite - fNextByteToSend);
if (empty || effectiveWindow == 0 || fSendBuffer == NULL || fSendBuffer->size == 0) {
buffer = gBufferModule->create(256);
TRACE(("TCP: Sending Buffer %p\n", buffer));
if (buffer == NULL)
return ENOBUFS;
} else {
buffer = fSendBuffer;
if (effectiveWindow == fSendBuffer->size)
fSendBuffer = NULL;
else
fSendBuffer = gBufferModule->split(fSendBuffer, effectiveWindow);
}
gAddressModule->set_to((sockaddr *)&buffer->source, (sockaddr *)&socket->address);
gAddressModule->set_to((sockaddr *)&buffer->destination, (sockaddr *)&socket->peer);
TRACE(("TCP:%p.SendQueuedData() to address %s\n", this,
AddressString(gDomain, (sockaddr *)&buffer->destination, true).Data()));
TRACE(("TCP:%p.SendQueuedData() from address %s\n", this,
AddressString(gDomain, (sockaddr *)&buffer->source, true).Data()));
uint16 advertisedWindow = TCP_MAX_RECV_BUF - (fNextByteExpected - fNextByteToRead);
uint32 size = buffer->size;
status_t status = add_tcp_header(buffer, flags, fNextByteToSend,
fNextByteExpected, advertisedWindow);
if (status != B_OK) {
gBufferModule->free(buffer);
return status;
}
// Only count 1 SYN, the 1 sent when transitioning from CLOSED or LISTEN
if (TCP_FLG_SYN & flags && (fState == CLOSED || fState == LISTEN))
fNextByteToSend++;
// Only count 1 FIN, the 1 sent when transitioning from ESTABLISHED, SYN_RCVD or CLOSE_WAIT
if (TCP_FLG_FIN & flags && (fState == SYN_RCVD || fState == ESTABLISHED || fState == CLOSE_WAIT))
fNextByteToSend++;
fNextByteToSend += size;
#if 0
tcp_segment *segment = new(std::nothrow)
tcp_segment(sequenceNum, 0, 2*fAvgRTT);
#endif
return next->module->send_routed_data(next, fRoute, buffer);
}
void
TCPConnection::_TimeWait(struct net_timer *timer, void *data)
{
}
int
TCPConnection::Compare(void *_connection, const void *_key)
{
const tcp_connection_key *key = (tcp_connection_key *)_key;
TCPConnection *connection= ((TCPConnection *)_connection);
if (gAddressModule->equal_addresses_and_ports(key->local,
(sockaddr *)&connection->socket->address)
&& gAddressModule->equal_addresses_and_ports(key->peer,
(sockaddr *)&connection->socket->peer))
return 0;
return 1;
}
uint32
TCPConnection::Hash(void *_connection, const void *_key, uint32 range)
{
if (_connection != NULL) {
TCPConnection *connection = (TCPConnection *)_connection;
return gAddressModule->hash_address_pair(
(sockaddr *)&connection->socket->address, (sockaddr *)&connection->socket->peer) % range;
}
const tcp_connection_key *key = (tcp_connection_key *)_key;
return gAddressModule->hash_address_pair(
key->local, key->peer) % range;
}

View File

@ -4,31 +4,18 @@
* *
* Authors: * Authors:
* Andrew Galante, haiku.galante@gmail.com * Andrew Galante, haiku.galante@gmail.com
* Axel Dörfler, axeld@pinc-software.de
*/ */
#ifndef TCP_CONNECTION_H
#define TCP_CONNECTION_H
// Initial estimate for packet round trip time (RTT)
#define TCP_INITIAL_RTT 120000000LL
// Estimate for Maximum segment lifetime in the internet #include "tcp.h"
#define TCP_MAX_SEGMENT_LIFETIME (2 * TCP_INITIAL_RTT)
// keep maximum buffer sizes < max net_buffer size for now #include <net_protocol.h>
#define TCP_MAX_SEND_BUF 1024 #include <net_stack.h>
#define TCP_MAX_RECV_BUF TCP_MAX_SEND_BUF
#define TCP_IS_GOOD_ACK(x) (fLastByteAckd > fNextByteToSend ? \ #include <stddef.h>
((x) >= fLastByteAckd || (x) <= fNextByteToSend) : \
((x) >= fLastByteAckd && (x) <= fNextByteToSend))
#define TCP_IS_GOOD_SEQ(x,y) (fNextByteToRead < fNextByteToRead + TCP_MAX_RECV_BUF ? \
(x) >= fNextByteToRead && (x) + (y) <= fNextByteToRead + TCP_MAX_RECV_BUF : \
((x) >= fNextByteToRead || (x) <= fNextByteToRead + TCP_MAX_RECV_BUF) && \
((x) + (y) >= fNextByteToRead || (x) + (y) <= fNextByteToRead + TCP_MAX_RECV_BUF))
typedef struct {
const sockaddr *local;
const sockaddr *peer;
} tcp_connection_key;
class TCPConnection : public net_protocol { class TCPConnection : public net_protocol {
@ -57,27 +44,13 @@ public:
static int Compare(void *_packet, const void *_key); static int Compare(void *_packet, const void *_key);
static uint32 Hash(void *_packet, const void *_key, uint32 range); static uint32 Hash(void *_packet, const void *_key, uint32 range);
static int32 HashOffset() { return offsetof(TCPConnection, fHashLink); } static int32 HashOffset() { return offsetof(TCPConnection, fHashLink); }
private: private:
status_t _SendQueuedData(uint16 flags, bool empty);
status_t _EnqueueReceivedData(net_buffer *buffer, uint32 sequenceNumber);
status_t _Reset(uint32 sequenceNum, uint32 acknowledgeNum);
class TCPSegment { static void _TimeWait(struct net_timer *timer, void *data);
public:
struct list_link link;
TCPSegment(net_buffer *buffer, uint32 sequenceNumber, bigtime_t timeout);
~TCPSegment();
net_buffer *fBuffer;
bigtime_t fTime;
uint32 fSequenceNumber;
net_timer fTimer;
bool fTimedOut;
};
status_t SendQueuedData(uint16 flags, bool empty);
status_t EnqueueReceivedData(net_buffer *buffer, uint32 sequenceNumber);
status_t Reset(uint32 sequenceNum, uint32 acknowledgeNum);
static void TimeWait(struct net_timer *timer, void *data);
uint32 fLastByteAckd; uint32 fLastByteAckd;
uint32 fNextByteToSend; uint32 fNextByteToSend;
@ -102,720 +75,7 @@ private:
net_timer fTimer; net_timer fTimer;
net_route *fRoute; net_route *fRoute;
// TODO: don't use a net_route, but a net_route_info!!!
}; };
#endif // TCP_CONNECTION_H
TCPConnection::TCPSegment::TCPSegment(net_buffer *buffer, uint32 sequenceNumber, bigtime_t timeout)
:
fBuffer(buffer),
fTime(system_time()),
fSequenceNumber(sequenceNumber),
fTimedOut(false)
{
if (timeout > 0) {
sStackModule->init_timer(&fTimer, &TCPConnection::ResendSegment, this);
sStackModule->set_timer(&fTimer, timeout);
}
}
TCPConnection::TCPSegment::~TCPSegment()
{
sStackModule->set_timer(&fTimer, -1);
}
TCPConnection::TCPConnection(net_socket *socket)
:
fLastByteAckd(0), //system_time()),
fNextByteToSend(fLastByteAckd),
fNextByteToWrite(fLastByteAckd + 1),
fNextByteToRead(0),
fNextByteExpected(0),
fLastByteReceived(0),
fAvgRTT(TCP_INITIAL_RTT),
fSendBuffer(NULL),
fReceiveBuffer(NULL),
fState(CLOSED),
fError(B_OK),
fRoute(NULL)
{
benaphore_init(&fLock, "TCPConnection");
sStackModule->init_timer(&fTimer, TimeWait, this);
list_init(&fReorderQueue);
list_init(&fWaitQueue);
}
TCPConnection::~TCPConnection()
{
benaphore_destroy(&fLock);
}
status_t
TCPConnection::Open()
{
TRACE(("%p.Open()\n", this));
if (sAddressModule == NULL)
return B_ERROR;
TRACE(("TCP: Open(): Using Address Module %p\n", sAddressModule));
BenaphoreLocker lock(&fLock);
sAddressModule->set_to_empty_address((sockaddr *)&socket->address);
sAddressModule->set_port((sockaddr *)&socket->address, 0);
sAddressModule->set_to_empty_address((sockaddr *)&socket->peer);
sAddressModule->set_port((sockaddr *)&socket->peer, 0);
return B_OK;
}
status_t
TCPConnection::Close()
{
BenaphoreLocker lock(&fLock);
TRACE(("TCP:%p.Close()\n", this));
if (fState == SYN_SENT || fState == LISTEN) {
fState = CLOSED;
return B_OK;
}
tcp_state nextState = CLOSED;
if (fState == SYN_RCVD || fState == ESTABLISHED)
nextState = FIN_WAIT1;
if (fState == CLOSE_WAIT)
nextState = LAST_ACK;
status_t status = SendQueuedData(TCP_FLG_FIN | TCP_FLG_ACK, false);
if (status != B_OK)
return status;
fState = nextState;
TRACE(("TCP: %p.Close(): Entering state %d\n", this, fState));
//do i need to wait until fState returns to CLOSED?
return B_OK;
}
status_t
TCPConnection::Free()
{
TRACE(("TCP:%p.Free()\n", this));
BenaphoreLocker hashLock(&sTCPLock);
BenaphoreLocker lock(&fLock);
tcp_connection_key key;
key.local = (sockaddr *)&socket->address;
key.peer = (sockaddr *)&socket->peer;
if (hash_lookup(sTCPHash, &key) != NULL) {
return hash_remove(sTCPHash, (void *)this);
}
return B_OK;
}
/*!
Creates and sends a SYN packet to /a address
*/
status_t
TCPConnection::Connect(const struct sockaddr *address)
{
TRACE(("TCP:%p.Connect() on address %s\n", this,
AddressString(sDomain, address, true).Data()));
if (address->sa_family != AF_INET)
return EAFNOSUPPORT;
benaphore_lock(&sTCPLock); // want to release lock later, so no autolock
BenaphoreLocker lock(&fLock);
// Can only call Connect from CLOSED or LISTEN states
// otherwise connection is considered already connected
if (fState != CLOSED && fState != LISTEN) {
benaphore_unlock(&sTCPLock);
return EISCONN;
}
TRACE(("TCP: Connect(): in state %d\n", fState));
// get a net_route if there isn't one
if (fRoute == NULL) {
fRoute = sDatalinkModule->get_route(sDomain, (sockaddr *)address);
TRACE(("TCP: Connect(): Using Route %p\n", fRoute));
if (fRoute == NULL) {
benaphore_unlock(&sTCPLock);
return ENETUNREACH;
}
}
// need to associate this connection with a real address, not INADDR_ANY
if (sAddressModule->is_empty_address((sockaddr *)&socket->address, false)) {
TRACE(("TCP: Connect(): Local Address is INADDR_ANY\n"));
sAddressModule->set_to((sockaddr *)&socket->address, (sockaddr *)fRoute->interface->address);
// since most stacks terminate connections from port 0
// use port 40000 for now. This should be moved to Bind(), and Bind() called before Connect().
sAddressModule->set_port((sockaddr *)&socket->address, htons(40000));
}
// make sure connection does not already exist
tcp_connection_key key;
key.local = (sockaddr *)&socket->address;
key.peer = address;
if (hash_lookup(sTCPHash, &key) != NULL) {
benaphore_unlock(&sTCPLock);
return EADDRINUSE;
}
TRACE(("TCP: Connect(): connecting...\n"));
status_t status;
sAddressModule->set_to((sockaddr *)&socket->peer, address);
status = hash_insert(sTCPHash, (void *)this);
if (status != B_OK) {
TRACE(("TCP: Connect(): Error inserting connection into hash!\n"));
benaphore_unlock(&sTCPLock);
return status;
}
// done manipulating the hash, release the lock
benaphore_unlock(&sTCPLock);
TRACE(("TCP: Connect(): starting 3-way handshake...\n"));
// send SYN
status = SendQueuedData(TCP_FLG_SYN, false);
if (status != B_OK)
return status;
fState = SYN_SENT;
// TODO: Should Connect() not return until 3-way handshake is complete?
TRACE(("TCP: Connect(): Connection complete\n"));
return B_OK;
}
status_t
TCPConnection::Accept(struct net_socket **_acceptedSocket)
{
TRACE(("TCP:%p.Accept()\n", this));
return B_ERROR;
}
status_t
TCPConnection::Bind(sockaddr *address)
{
TRACE(("TCP:%p.Bind() on address %s\n", this,
AddressString(sDomain, address, true).Data()));
if (address->sa_family != AF_INET)
return EAFNOSUPPORT;
BenaphoreLocker hashLock(&sTCPLock);
BenaphoreLocker lock(&fLock);
// let IP check whether there is an interface that supports the given address:
status_t status = next->module->bind(next, address);
if (status < B_OK)
return status;
sAddressModule->set_to((sockaddr *)&socket->address, address);
// for now, leave port=0. TCP should still work 1 connection at a time
if (0) { //sAddressModule->get_port((sockaddr *)&socket->address) == 0) {
//assign ephemeral port
} else {
//TODO:Check for Socket flags
tcp_connection_key key;
key.peer = (sockaddr *)&socket->peer;
key.local = (sockaddr *)&socket->address;
if (hash_lookup(sTCPHash, &key) == NULL) {
hash_insert(sTCPHash, (void *)this);
} else
return EADDRINUSE;
}
return B_OK;
}
status_t
TCPConnection::Unbind(struct sockaddr *address)
{
TRACE(("TCP:%p.Unbind()\n", this ));
BenaphoreLocker hashLock(&sTCPLock);
BenaphoreLocker lock(&fLock);
status_t status = hash_remove(sTCPHash, (void *)this);
if (status != B_OK)
return status;
sAddressModule->set_to_empty_address((sockaddr *)&socket->address);
sAddressModule->set_port((sockaddr *)&socket->address, 0);
return B_OK;
}
status_t
TCPConnection::Listen(int count)
{
TRACE(("TCP:%p.Listen()\n", this));
BenaphoreLocker lock(&fLock);
if (fState != CLOSED)
return B_ERROR;
fState = LISTEN;
return B_OK;
}
status_t
TCPConnection::Shutdown(int direction)
{
TRACE(("TCP:%p.Shutdown()\n", this));
return B_ERROR;
}
/*!
Puts data contained in \a buffer into send buffer
*/
status_t
TCPConnection::SendData(net_buffer *buffer)
{
TRACE(("TCP:%p.SendData()\n", this));
size_t bufferSize = buffer->size;
BenaphoreLocker lock(&fLock);
if (fSendBuffer != NULL) {
status_t status = sBufferModule->merge(fSendBuffer, buffer, true);
if (status != B_OK)
return status;
} else
fSendBuffer = buffer;
fNextByteToWrite += bufferSize;
return SendQueuedData(TCP_FLG_ACK, false);
}
status_t
TCPConnection::SendRoutedData(net_route *route, net_buffer *buffer)
{
TRACE(("TCP:%p.SendRoutedData()\n", this));
{
BenaphoreLocker lock(&fLock);
fRoute = route;
}
return SendData(buffer);
}
size_t
TCPConnection::SendAvailable()
{
TRACE(("TCP:%p.SendAvailable()\n", this));
BenaphoreLocker lock(&fLock);
if (fSendBuffer != NULL)
return TCP_MAX_SEND_BUF - fSendBuffer->size;
return TCP_MAX_SEND_BUF;
}
status_t
TCPConnection::ReadData(size_t numBytes, uint32 flags, net_buffer** _buffer)
{
TRACE(("TCP:%p.ReadData()\n", this));
BenaphoreLocker lock(&fLock);
// must be in a synchronous state
if (fState != ESTABLISHED || fState != FIN_WAIT1 || fState != FIN_WAIT2) {
// is this correct semantics?
dprintf(" TCP state = %d\n", fState);
return B_ERROR;
}
dprintf(" TCP error = %ld\n", fError);
if (fError != B_OK)
return fError;
if (fReceiveBuffer->size < numBytes)
numBytes = fReceiveBuffer->size;
*_buffer = sBufferModule->split(fReceiveBuffer, numBytes);
if (*_buffer == NULL)
return B_NO_MEMORY;
return B_OK;
}
size_t
TCPConnection::ReadAvailable()
{
TRACE(("TCP:%p.ReadAvailable()\n", this));
BenaphoreLocker lock(&fLock);
if (fReceiveBuffer != NULL)
return fReceiveBuffer->size;
return 0;
}
status_t
TCPConnection::EnqueueReceivedData(net_buffer *buffer, uint32 sequenceNumber)
{
TRACE(("TCP:%p.EnqueueReceivedData(%p, %lu)\n", this, buffer, sequenceNumber));
status_t status;
if (sequenceNumber == fNextByteExpected) {
// first check if the received buffer meets up with the first
// segment in the ReorderQueue
TCPSegment *next;
while ((next = (TCPSegment *)list_get_first_item(&fReorderQueue)) != NULL) {
if (sequenceNumber + buffer->size >= next->fSequenceNumber) {
if (sequenceNumber + buffer->size > next->fSequenceNumber) {
status = sBufferModule->trim(buffer, sequenceNumber - next->fSequenceNumber);
if (status != B_OK)
return status;
}
status = sBufferModule->merge(buffer, next->fBuffer, true);
if (status != B_OK)
return status;
list_remove_item(&fReorderQueue, next);
delete next;
} else
break;
}
fNextByteExpected += buffer->size;
if (fReceiveBuffer == NULL)
fReceiveBuffer = buffer;
else {
status = sBufferModule->merge(fReceiveBuffer, buffer, true);
if (status < B_OK) {
fNextByteExpected -= buffer->size;
return status;
}
}
} else {
// add this buffer into the ReorderQueue in the correct place
// creating a new TCPSegment if necessary
TCPSegment *next = NULL;
do {
next = (TCPSegment *)list_get_next_item(&fReorderQueue, next);
if (next != NULL && next->fSequenceNumber < sequenceNumber)
continue;
if (next != NULL && sequenceNumber + buffer->size >= next->fSequenceNumber) {
// merge the new buffer with the next buffer
if (sequenceNumber + buffer->size > next->fSequenceNumber) {
status = sBufferModule->trim(buffer, sequenceNumber - next->fSequenceNumber);
if (status != B_OK)
return status;
}
status = sBufferModule->merge(buffer, next->fBuffer, true);
if (status != B_OK)
return status;
next->fBuffer = buffer;
next->fSequenceNumber = sequenceNumber;
break;
}
TCPSegment *segment = new(std::nothrow) TCPSegment(buffer, sequenceNumber, -1);
if (next == NULL)
list_add_item(&fReorderQueue, segment);
else
list_insert_item_before(&fReorderQueue, next, segment);
} while (next != NULL);
}
return B_OK;
}
status_t
TCPConnection::ReceiveData(net_buffer *buffer)
{
BenaphoreLocker lock(&fLock);
TRACE(("TCP:%p.ReceiveData()\n", this));
NetBufferHeader<tcp_header> bufferHeader(buffer);
if (bufferHeader.Status() < B_OK)
return bufferHeader.Status();
tcp_header &header = bufferHeader.Data();
uint16 flags = 0x0;
tcp_state nextState = fState;
status_t status = B_OK;
uint32 byteAckd = ntohl(header.acknowledge_num);
uint32 byteRcvd = ntohl(header.sequence_num);
uint32 headerLength = (uint32)header.header_length << 2;
uint32 payloadLength = buffer->size - headerLength;
TRACE(("TCP: ReceiveData(): Connection in state %d received packet %p with flags %X!\n", fState, buffer, header.flags));
switch (fState) {
case CLOSED:
case TIME_WAIT:
sBufferModule->free(buffer);
if (header.flags & TCP_FLG_ACK)
return Reset(byteAckd, 0);
return Reset(0, byteRcvd + payloadLength);
case LISTEN:
// if packet is SYN, spawn new TCPConnection in SYN_RCVD state
// and add it to the Connection Queue. The new TCPConnection
// must continue the handshake by replying with SYN+ACK. Any
// data in the packet must go into the new TCPConnection's receive
// buffer.
// Otherwise, RST+ACK is sent.
// The current TCPConnection always remains in LISTEN state.
return B_ERROR;
case SYN_SENT:
if (header.flags & TCP_FLG_RST) {
fError = ECONNREFUSED;
fState = CLOSED;
return B_ERROR;
}
if (header.flags & TCP_FLG_ACK && !TCP_IS_GOOD_ACK(byteAckd))
return Reset(byteAckd, 0);
if (header.flags & TCP_FLG_SYN) {
fNextByteToRead = fNextByteExpected = ntohl(header.sequence_num) + 1;
flags |= TCP_FLG_ACK;
fLastByteAckd = byteAckd;
// cancel resend of this segment
if (header.flags & TCP_FLG_ACK)
nextState = ESTABLISHED;
else {
nextState = SYN_RCVD;
flags |= TCP_FLG_SYN;
}
}
break;
case SYN_RCVD:
if (header.flags & TCP_FLG_ACK && TCP_IS_GOOD_ACK(byteAckd))
fState = ESTABLISHED;
else
Reset(byteAckd, 0);
break;
default:
// In a synchronized state.
// first check that the received sequence number is good
if (TCP_IS_GOOD_SEQ(byteRcvd, payloadLength)) {
// If a valid RST was received, terminate the connection.
if (header.flags & TCP_FLG_RST) {
fError = ECONNREFUSED;
fState = CLOSED;
return B_ERROR;
}
if (header.flags & TCP_FLG_ACK && TCP_IS_GOOD_ACK(byteAckd) ) {
fLastByteAckd = byteAckd;
if (fLastByteAckd == fNextByteToWrite) {
if (fState == LAST_ACK ) {
nextState = CLOSED;
status = hash_remove(sTCPHash, this);
if (status != B_OK)
return status;
}
if (fState == CLOSING) {
nextState = TIME_WAIT;
status = hash_remove(sTCPHash, this);
if (status != B_OK)
return status;
}
if (fState == FIN_WAIT1) {
nextState = FIN_WAIT2;
}
}
}
if (header.flags & TCP_FLG_FIN) {
// other side is closing connection. change states
switch (fState) {
case ESTABLISHED:
nextState = CLOSE_WAIT;
fNextByteExpected++;
break;
case FIN_WAIT2:
nextState = TIME_WAIT;
fNextByteExpected++;
break;
case FIN_WAIT1:
if (fLastByteAckd == fNextByteToWrite) {
// our FIN has been ACKd: go to TIME_WAIT
nextState = TIME_WAIT;
status = hash_remove(sTCPHash, this);
if (status != B_OK)
return status;
sStackModule->set_timer(&fTimer, TCP_MAX_SEGMENT_LIFETIME);
} else
nextState = CLOSING;
fNextByteExpected++;
break;
default:
break;
}
}
flags |= TCP_FLG_ACK;
} else {
// out-of-order packet received. remind the other side of where we are
return SendQueuedData(TCP_FLG_ACK, true);
}
break;
}
TRACE(("TCP %p.ReceiveData():Entering state %d\n", this, fState));
// state machine is done switching states and the data is good.
// put it in the receive buffer
// TODO: This isn't the most efficient way to do it, and will need to be changed
// to deal with Silly Window Syndrome
bufferHeader.Remove(headerLength);
if (buffer->size > 0) {
status = EnqueueReceivedData(buffer, byteRcvd);
if (status != B_OK)
return status;
} else
sBufferModule->free(buffer);
if (fState != CLOSING && fState != LAST_ACK) {
status = SendQueuedData(flags, false);
if (status != B_OK)
return status;
}
fState = nextState;
return B_OK;
}
status_t
TCPConnection::Reset(uint32 sequenceNum, uint32 acknowledgeNum)
{
TRACE(("TCP:%p.Reset()\n", this));
net_buffer *reply_buf = sBufferModule->create(512);
sAddressModule->set_to((sockaddr *)&reply_buf->source, (sockaddr *)&socket->address);
sAddressModule->set_to((sockaddr *)&reply_buf->destination, (sockaddr *)&socket->peer);
uint16 flags = TCP_FLG_RST | acknowledgeNum == 0 ? 0 : TCP_FLG_ACK;
status_t status = tcp_segment(reply_buf, flags , sequenceNum, acknowledgeNum, 0);
if (status != B_OK) {
sBufferModule->free(reply_buf);
return status;
}
TRACE(("TCP: Reset():Sending RST...\n"));
status = next->module->send_routed_data(next, fRoute, reply_buf);
if (status !=B_OK)
sBufferModule->free(reply_buf);
return status;
}
/*!
Resends a sent segment (\a data) if the segment's ACK wasn't received
before the timeout (eg \a timer expired)
*/
void
TCPConnection::ResendSegment(struct net_timer *timer, void *data)
{
TRACE(("TCP:ResendSegment(%p)\n", data));
if (data == NULL)
return;
}
/*!
Sends a TCP packet with the specified \a flags. If there is any data in
the send buffer and \a empty is false, fEffectiveWindow bytes or less of it are sent as well.
Sequence and Acknowledgement numbers are filled in accordingly.
The fLock benaphore must be held before calling.
*/
status_t
TCPConnection::SendQueuedData(uint16 flags, bool empty)
{
TRACE(("TCP:%p.SendQueuedData(%X,%s)\n", this, flags, empty ? "1" : "0"));
if (fRoute == NULL)
return B_ERROR;
net_buffer *buffer;
uint32 effectiveWindow = min_c(tcp_get_mtu(this, (sockaddr *)&socket->address), fNextByteToWrite - fNextByteToSend);
if (empty || effectiveWindow == 0 || fSendBuffer == NULL || fSendBuffer->size == 0) {
buffer = sBufferModule->create(256);
TRACE(("TCP: Sending Buffer %p\n", buffer));
if (buffer == NULL)
return ENOBUFS;
} else {
buffer = fSendBuffer;
if (effectiveWindow == fSendBuffer->size)
fSendBuffer = NULL;
else
fSendBuffer = sBufferModule->split(fSendBuffer, effectiveWindow);
}
sAddressModule->set_to((sockaddr *)&buffer->source, (sockaddr *)&socket->address);
sAddressModule->set_to((sockaddr *)&buffer->destination, (sockaddr *)&socket->peer);
TRACE(("TCP:%p.SendQueuedData() to address %s\n", this,
AddressString(sDomain, (sockaddr *)&buffer->destination, true).Data()));
TRACE(("TCP:%p.SendQueuedData() from address %s\n", this,
AddressString(sDomain, (sockaddr *)&buffer->source, true).Data()));
uint16 advWin = TCP_MAX_RECV_BUF - (fNextByteExpected - fNextByteToRead);
uint32 size = buffer->size;
status_t status = tcp_segment(buffer, flags, fNextByteToSend, fNextByteExpected, advWin);
if (status != B_OK) {
sBufferModule->free(buffer);
return status;
}
// Only count 1 SYN, the 1 sent when transitioning from CLOSED or LISTEN
if (TCP_FLG_SYN & flags && (fState == CLOSED || fState == LISTEN))
fNextByteToSend++;
// Only count 1 FIN, the 1 sent when transitioning from ESTABLISHED, SYN_RCVD or CLOSE_WAIT
if (TCP_FLG_FIN & flags && (fState == SYN_RCVD || fState == ESTABLISHED || fState == CLOSE_WAIT))
fNextByteToSend++;
fNextByteToSend += size;
#if 0
TCPSegment *segment = new(std::nothrow)
TCPSegment(sequenceNum, 0, 2*fAvgRTT);
#endif
return next->module->send_routed_data(next, fRoute, buffer);
}
void
TCPConnection::TimeWait(struct net_timer *timer, void *data)
{
}
int
TCPConnection::Compare(void *_connection, const void *_key)
{
const tcp_connection_key *key = (tcp_connection_key *)_key;
TCPConnection *connection= ((TCPConnection *)_connection);
if (sAddressModule->equal_addresses_and_ports(key->local,
(sockaddr *)&connection->socket->address)
&& sAddressModule->equal_addresses_and_ports(key->peer,
(sockaddr *)&connection->socket->peer))
return 0;
return 1;
}
uint32
TCPConnection::Hash(void *_connection, const void *_key, uint32 range)
{
if (_connection != NULL) {
TCPConnection *connection = (TCPConnection *)_connection;
return sAddressModule->hash_address_pair(
(sockaddr *)&connection->socket->address, (sockaddr *)&connection->socket->peer) % range;
}
const tcp_connection_key *key = (tcp_connection_key *)_key;
return sAddressModule->hash_address_pair(
key->local, key->peer) % range;
}

View File

@ -4,13 +4,13 @@
* *
* Authors: * Authors:
* Axel Dörfler, axeld@pinc-software.de * Axel Dörfler, axeld@pinc-software.de
* Andrew Galante, haiku.galante@gmail.com
*/ */
#include <net_buffer.h> #include "TCPConnection.h"
#include <net_datalink.h>
#include <net_protocol.h> #include <net_protocol.h>
#include <net_stack.h>
#include <KernelExport.h> #include <KernelExport.h>
#include <util/list.h> #include <util/list.h>
@ -23,7 +23,6 @@
#include <lock.h> #include <lock.h>
#include <util/AutoLock.h> #include <util/AutoLock.h>
#include <util/khash.h>
#include <NetBufferUtilities.h> #include <NetBufferUtilities.h>
#include <NetUtilities.h> #include <NetUtilities.h>
@ -40,51 +39,112 @@
#define MAX_HASH_TCP 64 #define MAX_HASH_TCP 64
static net_domain *sDomain;
static net_address_module_info *sAddressModule;
net_buffer_module_info *sBufferModule;
static net_datalink_module_info *sDatalinkModule;
static net_stack_module_info *sStackModule;
static hash_table *sTCPHash;
static benaphore sTCPLock;
status_t net_domain *gDomain;
tcp_segment(net_buffer *buffer, uint16 flags, uint32 seq, uint32 ack, uint16 adv_win); net_address_module_info *gAddressModule;
net_buffer_module_info *gBufferModule;
size_t net_datalink_module_info *gDatalinkModule;
tcp_get_mtu(net_protocol *protocol, const struct sockaddr *address); net_stack_module_info *gStackModule;
hash_table *gConnectionHash;
#include "tcp.h" benaphore gConnectionLock;
#include "TCPConnection.h"
#ifdef TRACE_TCP #ifdef TRACE_TCP
# define DUMP_TCP_HASH tcp_dump_hash() # define DUMP_TCP_HASH tcp_dump_hash()
// Dumps the TCP Connection hash. sTCPLock must NOT be held when calling // Dumps the TCP Connection hash. gConnectionLock must NOT be held when calling
void void
tcp_dump_hash(){ tcp_dump_hash()
BenaphoreLocker lock(&sTCPLock); {
if (sDomain == NULL) { BenaphoreLocker lock(&gConnectionLock);
if (gDomain == NULL) {
TRACE(("Unable to dump TCP Connections!\n")); TRACE(("Unable to dump TCP Connections!\n"));
return; return;
} }
struct hash_iterator iterator; struct hash_iterator iterator;
hash_open(sTCPHash, &iterator); hash_open(gConnectionHash, &iterator);
TCPConnection *connection; TCPConnection *connection;
hash_rewind(sTCPHash, &iterator); hash_rewind(gConnectionHash, &iterator);
TRACE(("Active TCP Connections:\n")); TRACE(("Active TCP Connections:\n"));
while ((connection = (TCPConnection *)hash_next(sTCPHash, &iterator)) != NULL) { while ((connection = (TCPConnection *)hash_next(gConnectionHash, &iterator)) != NULL) {
TRACE((" TCPConnection %p: %s, %s\n", connection, TRACE((" TCPConnection %p: %s, %s\n", connection,
AddressString(sDomain, (sockaddr *)&connection->socket->address, true).Data(), AddressString(gDomain, (sockaddr *)&connection->socket->address, true).Data(),
AddressString(sDomain, (sockaddr *)&connection->socket->peer, true).Data())); AddressString(gDomain, (sockaddr *)&connection->socket->peer, true).Data()));
} }
hash_close(sTCPHash, &iterator, false); hash_close(gConnectionHash, &iterator, false);
} }
#else #else
# define DUMP_TCP_HASH 0 # define DUMP_TCP_HASH 0
#endif #endif
status_t
set_domain(net_interface *interface = NULL)
{
if (gDomain == NULL) {
// domain and address module are not known yet, we copy them from
// the buffer's interface (if any):
if (interface == NULL || interface->domain == NULL)
gDomain = gStackModule->get_domain(AF_INET);
else
gDomain = interface->domain;
if (gDomain == NULL) {
// this shouldn't occur, of course, but who knows...
return B_BAD_VALUE;
}
gAddressModule = gDomain->address_module;
}
return B_OK;
}
/*!
Constructs a TCP header on \a buffer with the specified values
for \a flags, \a seq \a ack and \a advertisedWindow.
*/
status_t
add_tcp_header(net_buffer *buffer, uint16 flags, uint32 sequence, uint32 ack,
uint16 advertisedWindow)
{
buffer->protocol = IPPROTO_TCP;
NetBufferPrepend<tcp_header> bufferHeader(buffer);
if (bufferHeader.Status() != B_OK)
return bufferHeader.Status();
tcp_header &header = bufferHeader.Data();
header.source_port = gAddressModule->get_port((sockaddr *)&buffer->source);
header.destination_port = gAddressModule->get_port((sockaddr *)&buffer->destination);
header.sequence_num = htonl(sequence);
header.acknowledge_num = htonl(ack);
header.reserved = 0;
header.header_length = 5;
// currently no options supported
header.flags = (uint8)flags;
header.advertised_window = htons(advertisedWindow);
header.checksum = 0;
header.urgent_ptr = 0;
// urgent pointer not supported
// compute and store checksum
Checksum checksum;
gAddressModule->checksum_address(&checksum, (sockaddr *)&buffer->source);
gAddressModule->checksum_address(&checksum, (sockaddr *)&buffer->destination);
checksum
<< (uint16)htons(IPPROTO_TCP)
<< (uint16)htons(buffer->size)
<< Checksum::BufferHelper(buffer, gBufferModule);
header.checksum = checksum;
TRACE(("TCP: Checksum for segment %p is %X\n", buffer, header.checksum));
return B_OK;
}
// #pragma mark - protocol API
net_protocol * net_protocol *
tcp_init_protocol(net_socket *socket) tcp_init_protocol(net_socket *socket)
{ {
@ -109,10 +169,9 @@ tcp_uninit_protocol(net_protocol *protocol)
status_t status_t
tcp_open(net_protocol *protocol) tcp_open(net_protocol *protocol)
{ {
if (!sDomain) if (gDomain == NULL && set_domain() != B_OK)
sDomain = sStackModule->get_domain(AF_INET); return B_ERROR;
if (!sAddressModule)
sAddressModule = sDomain->address_module;
DUMP_TCP_HASH; DUMP_TCP_HASH;
return ((TCPConnection *)protocol)->Open(); return ((TCPConnection *)protocol)->Open();
@ -240,63 +299,13 @@ tcp_get_mtu(net_protocol *protocol, const struct sockaddr *address)
} }
/*!
Constructs a TCP header on \a buffer with the specified values
for \a flags, \a seq \a ack and \a adv_win.
*/
status_t
tcp_segment(net_buffer *buffer, uint16 flags, uint32 seq, uint32 ack, uint16 adv_win)
{
buffer->protocol = IPPROTO_TCP;
NetBufferPrepend<tcp_header> bufferHeader(buffer);
if (bufferHeader.Status() != B_OK)
return bufferHeader.Status();
tcp_header &header = bufferHeader.Data();
header.source_port = sAddressModule->get_port((sockaddr *)&buffer->source);
header.destination_port = sAddressModule->get_port((sockaddr *)&buffer->destination);
header.sequence_num = htonl(seq);
header.acknowledge_num = htonl(ack);
header.reserved = 0;
header.header_length = 5;// currently no options supported
header.flags = (uint8)flags;
header.advertised_window = htons(adv_win);
header.checksum = 0;
header.urgent_ptr = 0;// urgent pointer not supported
// compute and store checksum
Checksum checksum;
sAddressModule->checksum_address(&checksum, (sockaddr *)&buffer->source);
sAddressModule->checksum_address(&checksum, (sockaddr *)&buffer->destination);
checksum
<< (uint16)htons(IPPROTO_TCP)
<< (uint16)htons(buffer->size)
<< Checksum::BufferHelper(buffer, sBufferModule);
header.checksum = checksum;
TRACE(("TCP: Checksum for segment %p is %X\n", buffer, header.checksum));
return B_OK;
}
status_t status_t
tcp_receive_data(net_buffer *buffer) tcp_receive_data(net_buffer *buffer)
{ {
TRACE(("TCP: Received buffer %p\n", buffer)); TRACE(("TCP: Received buffer %p\n", buffer));
if (!sDomain) {
// domain and address module are not known yet, we copy them from if (gDomain == NULL && set_domain(buffer->interface) != B_OK)
// the buffer's interface (if any): return B_ERROR;
if (buffer->interface == NULL || buffer->interface->domain == NULL)
sDomain = sStackModule->get_domain(AF_INET);
else
sDomain = buffer->interface->domain;
if (sDomain == NULL) {
// this shouldn't occur, of course, but who knows...
return B_BAD_VALUE;
}
sAddressModule = sDomain->address_module;
}
NetBufferHeader<tcp_header> bufferHeader(buffer); NetBufferHeader<tcp_header> bufferHeader(buffer);
if (bufferHeader.Status() < B_OK) if (bufferHeader.Status() < B_OK)
@ -310,13 +319,13 @@ tcp_receive_data(net_buffer *buffer)
// TODO: check TCP Checksum // TODO: check TCP Checksum
sAddressModule->set_port((struct sockaddr *)&buffer->source, header.source_port); gAddressModule->set_port((struct sockaddr *)&buffer->source, header.source_port);
sAddressModule->set_port((struct sockaddr *)&buffer->destination, header.destination_port); gAddressModule->set_port((struct sockaddr *)&buffer->destination, header.destination_port);
DUMP_TCP_HASH; DUMP_TCP_HASH;
BenaphoreLocker hashLock(&sTCPLock); BenaphoreLocker hashLock(&gConnectionLock);
TCPConnection *connection = (TCPConnection *)hash_lookup(sTCPHash, &key); TCPConnection *connection = (TCPConnection *)hash_lookup(gConnectionHash, &key);
TRACE(("TCP: Received packet corresponds to connection %p\n", connection)); TRACE(("TCP: Received packet corresponds to connection %p\n", connection));
if (connection != NULL){ if (connection != NULL){
return connection->ReceiveData(buffer); return connection->ReceiveData(buffer);
@ -330,31 +339,37 @@ tcp_receive_data(net_buffer *buffer)
// If no connection exists (and RST is not set) send RST // If no connection exists (and RST is not set) send RST
if (!(header.flags & TCP_FLG_RST)) { if (!(header.flags & TCP_FLG_RST)) {
TRACE(("TCP: Connection does not exist!\n")); TRACE(("TCP: Connection does not exist!\n"));
net_buffer *reply_buf = sBufferModule->create(512); net_buffer *reply = gBufferModule->create(512);
sAddressModule->set_to((sockaddr *)&reply_buf->source, (sockaddr *)&buffer->destination); if (reply == NULL)
sAddressModule->set_to((sockaddr *)&reply_buf->destination, (sockaddr *)&buffer->source); return B_NO_MEMORY;
uint32 sequenceNum, acknowledgeNum; gAddressModule->set_to((sockaddr *)&reply->source,
(sockaddr *)&buffer->destination);
gAddressModule->set_to((sockaddr *)&reply->destination,
(sockaddr *)&buffer->source);
uint32 sequence, acknowledge;
uint16 flags; uint16 flags;
if (header.flags & TCP_FLG_ACK) { if (header.flags & TCP_FLG_ACK) {
sequenceNum = ntohl(header.acknowledge_num); sequence = ntohl(header.acknowledge_num);
acknowledgeNum = 0; acknowledge = 0;
flags = TCP_FLG_RST; flags = TCP_FLG_RST;
} else { } else {
sequenceNum = 0; sequence = 0;
acknowledgeNum = ntohl(header.sequence_num) + 1 + buffer->size - ((uint32)header.header_length<<2); acknowledge = ntohl(header.sequence_num) + 1
+ buffer->size - ((uint32)header.header_length << 2);
flags = TCP_FLG_RST | TCP_FLG_ACK; flags = TCP_FLG_RST | TCP_FLG_ACK;
} }
status_t status = tcp_segment(reply_buf, flags, sequenceNum, acknowledgeNum, 0); status_t status = add_tcp_header(reply, flags, sequence, acknowledge, 0);
if (status != B_OK) {
sBufferModule->free(reply_buf); if (status == B_OK) {
return status;
}
TRACE(("TCP: Sending RST...\n")); TRACE(("TCP: Sending RST...\n"));
status = sDomain->module->send_data(NULL, reply_buf); status = gDomain->module->send_data(NULL, reply);
}
if (status != B_OK) { if (status != B_OK) {
sBufferModule->free(reply_buf); gBufferModule->free(reply);
return status; return status;
} }
} }
@ -386,43 +401,43 @@ tcp_init()
{ {
status_t status; status_t status;
sDomain = NULL; gDomain = NULL;
sAddressModule = NULL; gAddressModule = NULL;
status = get_module(NET_STACK_MODULE_NAME, (module_info **)&sStackModule); status = get_module(NET_STACK_MODULE_NAME, (module_info **)&gStackModule);
if (status < B_OK) if (status < B_OK)
return status; return status;
status = get_module(NET_BUFFER_MODULE_NAME, (module_info **)&sBufferModule); status = get_module(NET_BUFFER_MODULE_NAME, (module_info **)&gBufferModule);
if (status < B_OK) if (status < B_OK)
goto err1; goto err1;
status = get_module(NET_DATALINK_MODULE_NAME, (module_info **)&sDatalinkModule); status = get_module(NET_DATALINK_MODULE_NAME, (module_info **)&gDatalinkModule);
if (status < B_OK) if (status < B_OK)
goto err2; goto err2;
sTCPHash = hash_init(MAX_HASH_TCP, TCPConnection::HashOffset(), gConnectionHash = hash_init(MAX_HASH_TCP, TCPConnection::HashOffset(),
&TCPConnection::Compare, &TCPConnection::Hash); &TCPConnection::Compare, &TCPConnection::Hash);
if (sTCPHash == NULL) if (gConnectionHash == NULL)
goto err3; goto err3;
status = benaphore_init(&sTCPLock, "TCP Hash Lock"); status = benaphore_init(&gConnectionLock, "TCP Hash Lock");
if (status < B_OK) if (status < B_OK)
goto err4; goto err4;
status = sStackModule->register_domain_protocols(AF_INET, SOCK_STREAM, IPPROTO_IP, status = gStackModule->register_domain_protocols(AF_INET, SOCK_STREAM, IPPROTO_IP,
"network/protocols/tcp/v1", "network/protocols/tcp/v1",
"network/protocols/ipv4/v1", "network/protocols/ipv4/v1",
NULL); NULL);
if (status < B_OK) if (status < B_OK)
goto err5; goto err5;
status = sStackModule->register_domain_protocols(AF_INET, SOCK_STREAM, IPPROTO_TCP, status = gStackModule->register_domain_protocols(AF_INET, SOCK_STREAM, IPPROTO_TCP,
"network/protocols/tcp/v1", "network/protocols/tcp/v1",
"network/protocols/ipv4/v1", "network/protocols/ipv4/v1",
NULL); NULL);
if (status < B_OK) if (status < B_OK)
goto err5; goto err5;
status = sStackModule->register_domain_receiving_protocol(AF_INET, IPPROTO_TCP, status = gStackModule->register_domain_receiving_protocol(AF_INET, IPPROTO_TCP,
"network/protocols/tcp/v1"); "network/protocols/tcp/v1");
if (status < B_OK) if (status < B_OK)
goto err5; goto err5;
@ -430,9 +445,9 @@ tcp_init()
return B_OK; return B_OK;
err5: err5:
benaphore_destroy(&sTCPLock); benaphore_destroy(&gConnectionLock);
err4: err4:
hash_uninit(sTCPHash); hash_uninit(gConnectionHash);
err3: err3:
put_module(NET_DATALINK_MODULE_NAME); put_module(NET_DATALINK_MODULE_NAME);
err2: err2:
@ -448,8 +463,8 @@ err1:
static status_t static status_t
tcp_uninit() tcp_uninit()
{ {
benaphore_destroy(&sTCPLock); benaphore_destroy(&gConnectionLock);
hash_uninit(sTCPHash); hash_uninit(gConnectionHash);
put_module(NET_DATALINK_MODULE_NAME); put_module(NET_DATALINK_MODULE_NAME);
put_module(NET_BUFFER_MODULE_NAME); put_module(NET_BUFFER_MODULE_NAME);
put_module(NET_STACK_MODULE_NAME); put_module(NET_STACK_MODULE_NAME);

View File

@ -5,8 +5,18 @@
* Authors: * Authors:
* Andrew Galante, haiku.galante@gmail.com * Andrew Galante, haiku.galante@gmail.com
*/ */
#ifndef TCP_H
#define TCP_H
#include <net_buffer.h>
#include <net_datalink.h>
#include <net_stack.h>
#include <util/khash.h>
#include <sys/socket.h>
#include <ByteOrder.h>
typedef enum { typedef enum {
CLOSED, CLOSED,
@ -51,3 +61,23 @@ struct tcp_header {
#define TCP_FLG_RST 0x04 // ReSeT #define TCP_FLG_RST 0x04 // ReSeT
#define TCP_FLG_SYN 0x02 // SYNchronize #define TCP_FLG_SYN 0x02 // SYNchronize
#define TCP_FLG_FIN 0x01 // FINish #define TCP_FLG_FIN 0x01 // FINish
struct tcp_connection_key {
const sockaddr *local;
const sockaddr *peer;
};
extern net_domain *gDomain;
extern net_address_module_info *gAddressModule;
extern net_buffer_module_info *gBufferModule;
extern net_datalink_module_info *gDatalinkModule;
extern net_stack_module_info *gStackModule;
extern hash_table *gConnectionHash;
extern benaphore gConnectionLock;
status_t add_tcp_header(net_buffer *buffer, uint16 flags, uint32 sequence,
uint32 ack, uint16 advertisedWindow);
#endif TCP_H

View File

@ -147,7 +147,7 @@ static UdpEndpointManager *sUdpEndpointManager;
static net_domain *sDomain; static net_domain *sDomain;
static net_address_module_info *sAddressModule; static net_address_module_info *sAddressModule;
net_buffer_module_info *sBufferModule; net_buffer_module_info *gBufferModule;
static net_datalink_module_info *sDatalinkModule; static net_datalink_module_info *sDatalinkModule;
static net_stack_module_info *sStackModule; static net_stack_module_info *sStackModule;
@ -495,7 +495,7 @@ UdpEndpointManager::ReceiveData(net_buffer *buffer)
if (buffer->size > udpLength) { if (buffer->size > udpLength) {
TRACE(("buffer %p is too long (%lu instead of %u), trimming it.\n", TRACE(("buffer %p is too long (%lu instead of %u), trimming it.\n",
buffer, buffer->size, udpLength)); buffer, buffer->size, udpLength));
sBufferModule->trim(buffer, udpLength); gBufferModule->trim(buffer, udpLength);
} }
if (header.udp_checksum != 0) { if (header.udp_checksum != 0) {
@ -508,7 +508,7 @@ UdpEndpointManager::ReceiveData(net_buffer *buffer)
<< header.udp_length << header.udp_length
// peculiar but correct: UDP-len is used twice for checksum // peculiar but correct: UDP-len is used twice for checksum
// (as it is already contained in udp_header) // (as it is already contained in udp_header)
<< Checksum::BufferHelper(buffer, sBufferModule); << Checksum::BufferHelper(buffer, gBufferModule);
uint16 sum = udpChecksum; uint16 sum = udpChecksum;
if (sum != 0) { if (sum != 0) {
TRACE(("buffer %p has bad checksum (%u), we drop it!\n", buffer, sum)); TRACE(("buffer %p has bad checksum (%u), we drop it!\n", buffer, sum));
@ -813,7 +813,7 @@ UdpEndpoint::SendData(net_buffer *buffer, net_route *route)
<< (uint16)htons(buffer->size) << (uint16)htons(buffer->size)
// peculiar but correct: UDP-len is used twice for checksum // peculiar but correct: UDP-len is used twice for checksum
// (as it is already contained in udp_header) // (as it is already contained in udp_header)
<< Checksum::BufferHelper(buffer, sBufferModule); << Checksum::BufferHelper(buffer, gBufferModule);
header.udp_checksum = udpChecksum; header.udp_checksum = udpChecksum;
if (header.udp_checksum == 0) if (header.udp_checksum == 0)
header.udp_checksum = 0xFFFF; header.udp_checksum = 0xFFFF;
@ -851,7 +851,7 @@ UdpEndpoint::FetchData(size_t numBytes, uint32 flags, net_buffer **_buffer)
if (numBytes < buffer->size) { if (numBytes < buffer->size) {
// discard any data behind the amount requested // discard any data behind the amount requested
sBufferModule->trim(buffer, numBytes); gBufferModule->trim(buffer, numBytes);
// TODO: we should indicate MSG_TRUNC to application! // TODO: we should indicate MSG_TRUNC to application!
} }
@ -866,7 +866,7 @@ UdpEndpoint::StoreData(net_buffer *_buffer)
{ {
TRACE(("buffer %p passed to endpoint with (%s)\n", _buffer, TRACE(("buffer %p passed to endpoint with (%s)\n", _buffer,
AddressString(sDomain, (sockaddr *)&socket->address, true).Data())); AddressString(sDomain, (sockaddr *)&socket->address, true).Data()));
net_buffer *buffer = sBufferModule->clone(_buffer, false); net_buffer *buffer = gBufferModule->clone(_buffer, false);
if (buffer == NULL) if (buffer == NULL)
return B_NO_MEMORY; return B_NO_MEMORY;
@ -874,7 +874,7 @@ UdpEndpoint::StoreData(net_buffer *_buffer)
if (status >= B_OK) if (status >= B_OK)
sStackModule->notify_socket(socket, B_SELECT_READ, BytesAvailable()); sStackModule->notify_socket(socket, B_SELECT_READ, BytesAvailable());
else else
sBufferModule->free(buffer); gBufferModule->free(buffer);
return status; return status;
} }
@ -1098,7 +1098,7 @@ init_udp()
status = get_module(NET_STACK_MODULE_NAME, (module_info **)&sStackModule); status = get_module(NET_STACK_MODULE_NAME, (module_info **)&sStackModule);
if (status < B_OK) if (status < B_OK)
return status; return status;
status = get_module(NET_BUFFER_MODULE_NAME, (module_info **)&sBufferModule); status = get_module(NET_BUFFER_MODULE_NAME, (module_info **)&gBufferModule);
if (status < B_OK) if (status < B_OK)
goto err1; goto err1;
status = get_module(NET_DATALINK_MODULE_NAME, (module_info **)&sDatalinkModule); status = get_module(NET_DATALINK_MODULE_NAME, (module_info **)&sDatalinkModule);