diff --git a/src/add-ons/kernel/network/protocols/tcp/TCPConnection.h b/src/add-ons/kernel/network/protocols/tcp/TCPConnection.h index b9cc67b39a..ac7a90f588 100644 --- a/src/add-ons/kernel/network/protocols/tcp/TCPConnection.h +++ b/src/add-ons/kernel/network/protocols/tcp/TCPConnection.h @@ -61,17 +61,21 @@ private: class TCPSegment { public: - TCPSegment(uint32 sequenceNumber, uint32 size, bigtime_t timeout); + + struct list_link link; + + TCPSegment(net_buffer *buffer, uint32 sequenceNumber, bigtime_t timeout); ~TCPSegment(); + net_buffer *fBuffer; bigtime_t fTime; uint32 fSequenceNumber; - uint32 fAcknowledgementNumber; net_timer fTimer; bool fTimedOut; }; - status_t Send(uint16 flags, bool empty); + status_t SendQueuedData(uint16 flags, bool empty); + status_t EnqueueReceivedData(net_buffer *buffer, uint32 sequenceNumber); status_t Reset(uint32 sequenceNum, uint32 acknowledgeNum); static void TimeWait(struct net_timer *timer, void *data); @@ -88,6 +92,9 @@ private: net_buffer *fSendBuffer; net_buffer *fReceiveBuffer; + struct list fReorderQueue; + struct list fWaitQueue; + TCPConnection *fHashLink; tcp_state fState; status_t fError; @@ -98,15 +105,17 @@ private: }; -TCPConnection::TCPSegment::TCPSegment(uint32 sequenceNumber, uint32 size, bigtime_t timeout) +TCPConnection::TCPSegment::TCPSegment(net_buffer *buffer, uint32 sequenceNumber, bigtime_t timeout) : + fBuffer(buffer), fTime(system_time()), fSequenceNumber(sequenceNumber), - fAcknowledgementNumber(sequenceNumber+size), fTimedOut(false) { - sStackModule->init_timer(&fTimer, &TCPConnection::ResendSegment, this); - sStackModule->set_timer(&fTimer, timeout); + if (timeout > 0) { + sStackModule->init_timer(&fTimer, &TCPConnection::ResendSegment, this); + sStackModule->set_timer(&fTimer, timeout); + } } @@ -133,6 +142,8 @@ TCPConnection::TCPConnection(net_socket *socket) { benaphore_init(&fLock, "TCPConnection"); sStackModule->init_timer(&fTimer, TimeWait, this); + list_init(&fReorderQueue); + list_init(&fWaitQueue); } @@ -173,12 +184,12 @@ TCPConnection::Close() nextState = FIN_WAIT1; if (fState == CLOSE_WAIT) nextState = LAST_ACK; - status_t status = Send(TCP_FLG_FIN | TCP_FLG_ACK, false); + status_t status = SendQueuedData(TCP_FLG_FIN | TCP_FLG_ACK, false); if (status != B_OK) return status; fState = nextState; TRACE(("TCP: %p.Close(): Entering state %d\n", this, fState)); - // need to wait until fState returns to CLOSED + //do i need to wait until fState returns to CLOSED? return B_OK; } @@ -265,7 +276,7 @@ TCPConnection::Connect(const struct sockaddr *address) TRACE(("TCP: Connect(): starting 3-way handshake...\n")); // send SYN - status = Send(TCP_FLG_SYN, false); + status = SendQueuedData(TCP_FLG_SYN, false); if (status != B_OK) return status; fState = SYN_SENT; @@ -370,13 +381,13 @@ TCPConnection::SendData(net_buffer *buffer) if (fSendBuffer == NULL) { fSendBuffer = buffer; fNextByteToWrite += bufSize; - return Send(TCP_FLG_ACK, false); + return SendQueuedData(TCP_FLG_ACK, false); } else { status_t status = sBufferModule->merge(fSendBuffer, buffer, true); if (status != B_OK) return status; fNextByteToWrite += bufSize; - return Send(TCP_FLG_ACK, false); + return SendQueuedData(TCP_FLG_ACK, false); } } @@ -406,10 +417,28 @@ TCPConnection::SendAvailable() status_t -TCPConnection::ReadData(size_t numBytes, uint32 flags, net_buffer **_buffer) +TCPConnection::ReadData(size_t numBytes, uint32 flags, net_buffer **buffer) { TRACE(("TCP:%p.ReadData()\n", this)); - return B_ERROR; + + BenaphoreLocker lock(&fLock); + + // must be in a synchronous state + if (fState != ESTABLISHED || fState != FIN_WAIT1 || fState != FIN_WAIT2) { + // is this correct semantics? + return B_ERROR; + } + + if (fError != B_OK) + return fError; + + if (fReceiveBuffer->size < numBytes) + numBytes = fReceiveBuffer->size; + *buffer = sBufferModule->split(fReceiveBuffer, numBytes); + if (*buffer != NULL) + return B_OK; + else + return B_ERROR; } @@ -425,6 +454,75 @@ TCPConnection::ReadAvailable() } +status_t +TCPConnection::EnqueueReceivedData(net_buffer *buffer, uint32 sequenceNumber) +{ + TRACE(("TCP:%p.EnqueueReceivedData(%p, %u)\n", this, buffer, sequenceNumber)); + status_t status; + + if (sequenceNumber == fNextByteExpected) { + // first check if the received buffer meets up with the first + // segment in the ReorderQueue + TCPSegment *next; + while ((next = (TCPSegment *)list_get_first_item(&fReorderQueue)) != NULL) { + if (sequenceNumber + buffer->size >= next->fSequenceNumber) { + if (sequenceNumber + buffer->size > next->fSequenceNumber) { + status = sBufferModule->trim(buffer, sequenceNumber - next->fSequenceNumber); + if (status != B_OK) + return status; + } + status = sBufferModule->merge(buffer, next->fBuffer, true); + if (status != B_OK) + return status; + list_remove_item(&fReorderQueue, next); + delete next; + } else + break; + } + + fNextByteExpected += buffer->size; + if (fReceiveBuffer == NULL) + fReceiveBuffer = buffer; + else { + status = sBufferModule->merge(fReceiveBuffer, buffer, true); + if (status < B_OK) { + fNextByteExpected -= buffer->size; + return status; + } + } + } else { + // add this buffer into the ReorderQueue in the correct place + // creating a new TCPSegment if necessary + TCPSegment *next = NULL; + do { + next = (TCPSegment *)list_get_next_item(&fReorderQueue, next); + if (next != NULL && next->fSequenceNumber < sequenceNumber) + continue; + if (next != NULL && sequenceNumber + buffer->size >= next->fSequenceNumber) { + // merge the new buffer with the next buffer + if (sequenceNumber + buffer->size > next->fSequenceNumber) { + status = sBufferModule->trim(buffer, sequenceNumber - next->fSequenceNumber); + if (status != B_OK) + return status; + } + status = sBufferModule->merge(buffer, next->fBuffer, true); + if (status != B_OK) + return status; + next->fBuffer = buffer; + next->fSequenceNumber = sequenceNumber; + break; + } + TCPSegment *segment = new(std::nothrow) TCPSegment(buffer, sequenceNumber, -1); + if (next == NULL) + list_add_item(&fReorderQueue, segment); + else + list_insert_item_before(&fReorderQueue, next, segment); + } while (next != NULL); + } + return B_OK; +} + + status_t TCPConnection::ReceiveData(net_buffer *buffer) { @@ -442,7 +540,8 @@ TCPConnection::ReceiveData(net_buffer *buffer) status_t status = B_OK; uint32 byteAckd = ntohl(header.acknowledge_num); uint32 byteRcvd = ntohl(header.sequence_num); - uint32 payloadLength = buffer->size - ((uint32)header.header_length << 2); + uint32 headerLength = (uint32)header.header_length << 2; + uint32 payloadLength = buffer->size - headerLength; TRACE(("TCP: ReceiveData(): Connection in state %d received packet %p with flags %X!\n", fState, buffer, header.flags)); switch (fState) { @@ -483,11 +582,6 @@ TCPConnection::ReceiveData(net_buffer *buffer) nextState = SYN_RCVD; flags |= TCP_FLG_SYN; } - status = Send(flags, false); - if (status == B_OK) - fState = nextState; - else - return status; } break; case SYN_RCVD: @@ -553,18 +647,35 @@ TCPConnection::ReceiveData(net_buffer *buffer) break; } } - if (fState != CLOSING && fState != LAST_ACK) - status = Send(flags | TCP_FLG_ACK, false); - if (status != B_OK) - return status; - fState = nextState; + flags |= TCP_FLG_ACK; } else { // out-of-order packet received. remind the other side of where we are - return Send(TCP_FLG_ACK, true); + return SendQueuedData(TCP_FLG_ACK, true); } break; } TRACE(("TCP %p.ReceiveData():Entering state %d\n", this, fState)); + + // state machine is done switching states and the data is good. + // put it in the receive buffer + // TODO: This isn't the most efficient way to do it, and will need to be changed + // to deal with Silly Window Syndrome + bufferHeader.Remove(headerLength); + if (buffer->size > 0) { + status = EnqueueReceivedData(buffer, byteRcvd); + if (status != B_OK) + return status; + } else + sBufferModule->free(buffer); + + if (fState != CLOSING && fState != LAST_ACK) { + status = SendQueuedData(flags, false); + if (status != B_OK) + return status; + } + + fState = nextState; + return B_OK; } @@ -611,9 +722,9 @@ TCPConnection::ResendSegment(struct net_timer *timer, void *data) The fLock benaphore must be held before calling. */ status_t -TCPConnection::Send(uint16 flags, bool empty) +TCPConnection::SendQueuedData(uint16 flags, bool empty) { - TRACE(("TCP:%p.Send(%X,%s)\n", this, flags, empty ? "1" : "0")); + TRACE(("TCP:%p.SendQueuedData(%X,%s)\n", this, flags, empty ? "1" : "0")); if (fRoute == NULL) return B_ERROR; @@ -635,9 +746,9 @@ TCPConnection::Send(uint16 flags, bool empty) sAddressModule->set_to((sockaddr *)&buffer->source, (sockaddr *)&socket->address); sAddressModule->set_to((sockaddr *)&buffer->destination, (sockaddr *)&socket->peer); - TRACE(("TCP:%p.Send() to address %s\n", this, + TRACE(("TCP:%p.SendQueuedData() to address %s\n", this, AddressString(sDomain, (sockaddr *)&buffer->destination, true).Data())); - TRACE(("TCP:%p.Send() from address %s\n", this, + TRACE(("TCP:%p.SendQueuedData() from address %s\n", this, AddressString(sDomain, (sockaddr *)&buffer->source, true).Data())); uint16 advWin = TCP_MAX_RECV_BUF - (fNextByteExpected - fNextByteToRead);