diff --git a/kernel/net/ipv4.c b/kernel/net/ipv4.c index 9fbdac72..b4e1f6e0 100644 --- a/kernel/net/ipv4.c +++ b/kernel/net/ipv4.c @@ -112,6 +112,55 @@ uint16_t calculate_ipv4_checksum(struct ipv4_packet * p) { return ~(sum & 0xFFFF) & 0xFFFF; } +uint16_t calculate_tcp_checksum(struct tcp_check_header * p, struct tcp_header * h, void * d, size_t payload_size) { + uint32_t sum = 0; + uint16_t * s = (uint16_t *)p; + + /* TODO: Checksums for options? */ + for (int i = 0; i < 6; ++i) { + sum += ntohs(s[i]); + if (sum > 0xFFFF) { + sum = (sum >> 16) + (sum & 0xFFFF); + } + } + + s = (uint16_t *)h; + for (int i = 0; i < 10; ++i) { + sum += ntohs(s[i]); + if (sum > 0xFFFF) { + sum = (sum >> 16) + (sum & 0xFFFF); + } + } + + uint16_t d_words = payload_size / 2; + + s = (uint16_t *)d; + for (unsigned int i = 0; i < d_words; ++i) { + sum += ntohs(s[i]); + if (sum > 0xFFFF) { + sum = (sum >> 16) + (sum & 0xFFFF); + } + } + + if (d_words * 2 != payload_size) { + uint8_t * t = (uint8_t *)d; + uint8_t tmp[2]; + tmp[0] = t[d_words * sizeof(uint16_t)]; + tmp[1] = 0; + + uint16_t * f = (uint16_t *)tmp; + + sum += ntohs(f[0]); + if (sum > 0xFFFF) { + sum = (sum >> 16) + (sum & 0xFFFF); + } + } + + return ~(sum & 0xFFFF) & 0xFFFF; +} + + + static void icmp_handle(struct ipv4_packet * packet, const char * src, const char * dest, fs_node_t * nic) { struct icmp_header * header = (void*)&packet->payload; if (header->type == 8 && header->code == 0) { @@ -150,6 +199,69 @@ extern void net_sock_add(sock_t * sock, void * frame); static hashmap_t * udp_sockets = NULL; static hashmap_t * tcp_sockets = NULL; +#define TCP_FLAGS_FIN (1 << 0) +#define TCP_FLAGS_SYN (1 << 1) +#define TCP_FLAGS_RES (1 << 2) +#define TCP_FLAGS_PSH (1 << 3) +#define TCP_FLAGS_ACK (1 << 4) +#define TCP_FLAGS_URG (1 << 5) +#define TCP_FLAGS_ECE (1 << 6) +#define TCP_FLAGS_CWR (1 << 7) +#define TCP_FLAGS_NS (1 << 8) +#define DATA_OFFSET_5 (0x5 << 12) + +static void tcp_ack(fs_node_t * nic, sock_t * sock, struct ipv4_packet * packet) { + sock->priv[1] = 2; + + size_t total_length = sizeof(struct ipv4_packet) + sizeof(struct tcp_header); + struct tcp_header * tcp = (struct tcp_header*)&packet->payload; + + struct ipv4_packet * response = malloc(total_length); + response->length = htons(total_length); + response->destination = packet->source; + response->source = ((struct EthernetDevice*)nic->device)->ipv4_addr; + response->ttl = 64; + response->protocol = IPV4_PROT_TCP; + response->ident = 0; + response->flags_fragment = htons(0x4000); + response->version_ihl = 0x45; + response->dscp_ecn = 0; + response->checksum = 0; + response->checksum = htons(calculate_ipv4_checksum(response)); + + sock->priv[2] = 1; + + /* Stick TCP header into payload */ + struct tcp_header * tcp_header = (struct tcp_header*)&response->payload; + tcp_header->source_port = htons(sock->priv[0]); + tcp_header->destination_port = tcp->source_port; + tcp_header->seq_number = htonl(sock->priv[2]); + tcp_header->ack_number = tcp->seq_number; + tcp_header->flags = htons((TCP_FLAGS_ACK) | 0x5000); + tcp_header->window_size = htons(1548-54); + tcp_header->checksum = 0; + tcp_header->urgent = 0; + + sock->priv[2]++; + + /* Calculate checksum */ + struct tcp_check_header check_hd = { + .source = response->source, + .destination = response->destination, + .zeros = 0, + .protocol = IPV4_PROT_TCP, + .tcp_len = htons(sizeof(struct tcp_header)), + }; + + tcp_header->checksum = htons(calculate_tcp_checksum(&check_hd, tcp_header, NULL, 0)); + + net_eth_send((struct EthernetDevice*)nic->device, ntohs(response->length), response, ETHERNET_TYPE_IPV4, ETHERNET_BROADCAST_MAC); + + void * tmp = malloc(ntohs(packet->length)); + memcpy(tmp, packet, ntohs(packet->length)); + net_sock_add(sock, tmp); +} + void net_ipv4_handle(struct ipv4_packet * packet, fs_node_t * nic) { char dest[16]; @@ -164,7 +276,7 @@ void net_ipv4_handle(struct ipv4_packet * packet, fs_node_t * nic) { break; case IPV4_PROT_UDP: { uint16_t dest_port = ntohs(((uint16_t*)&packet->payload)[1]); - printf("net: ipv4: %s: %s -> %s udp %d to %d\n", nic->name, src, dest, ntohs(((uint16_t*)&packet->payload)[0]), ntohs(((uint16_t*)&packet->payload)[1])); + printf("net: ipv4: %s: %s -> %s udp %d to %d\n", nic->name, src, dest, ntohs(((uint16_t*)&packet->payload)[0]), dest_port); if (udp_sockets && hashmap_has(udp_sockets, (void*)(uintptr_t)dest_port)) { printf("net: udp: received and have a waiting endpoint!\n"); void * tmp = malloc(ntohs(packet->length)); @@ -174,9 +286,29 @@ void net_ipv4_handle(struct ipv4_packet * packet, fs_node_t * nic) { } break; } - case IPV4_PROT_TCP: - printf("net: ipv4: %s: %s -> %s tcp %d to %d\n", nic->name, src, dest, ntohs(((uint16_t*)&packet->payload)[0]), ntohs(((uint16_t*)&packet->payload)[1])); + case IPV4_PROT_TCP: { + uint16_t dest_port = ntohs(((uint16_t*)&packet->payload)[1]); + printf("net: ipv4: %s: %s -> %s tcp %d to %d\n", nic->name, src, dest, ntohs(((uint16_t*)&packet->payload)[0]), dest_port); + if (tcp_sockets && hashmap_has(tcp_sockets, (void*)(uintptr_t)dest_port)) { + printf("net: tcp: received and have a waiting endpoint!\n"); + /* What kind of packet is this? Is it something we were expecting? */ + sock_t * sock = hashmap_get(tcp_sockets, (void*)(uintptr_t)dest_port); + struct tcp_header * tcp = (struct tcp_header*)&packet->payload; + + if (sock->priv[1] == 1) { + /* Awaiting SYN ACK, is this one? */ + if ((ntohs(tcp->flags) & (TCP_FLAGS_SYN | TCP_FLAGS_ACK)) == (TCP_FLAGS_SYN | TCP_FLAGS_ACK)) { + printf("tcp: synack\n"); + tcp_ack(nic, sock, packet); + } + } else if (sock->priv[1] == 2) { + void * tmp = malloc(ntohs(packet->length)); + memcpy(tmp, packet, ntohs(packet->length)); + net_sock_add(sock, tmp); + } + } break; + } } } @@ -344,54 +476,6 @@ static long sock_tcp_recv(sock_t * sock, struct msghdr * msg, int flags) { return 0; } -uint16_t calculate_tcp_checksum(struct tcp_check_header * p, struct tcp_header * h, void * d, size_t payload_size) { - uint32_t sum = 0; - uint16_t * s = (uint16_t *)p; - - /* TODO: Checksums for options? */ - for (int i = 0; i < 6; ++i) { - sum += ntohs(s[i]); - if (sum > 0xFFFF) { - sum = (sum >> 16) + (sum & 0xFFFF); - } - } - - s = (uint16_t *)h; - for (int i = 0; i < 10; ++i) { - sum += ntohs(s[i]); - if (sum > 0xFFFF) { - sum = (sum >> 16) + (sum & 0xFFFF); - } - } - - uint16_t d_words = payload_size / 2; - - s = (uint16_t *)d; - for (unsigned int i = 0; i < d_words; ++i) { - sum += ntohs(s[i]); - if (sum > 0xFFFF) { - sum = (sum >> 16) + (sum & 0xFFFF); - } - } - - if (d_words * 2 != payload_size) { - uint8_t * t = (uint8_t *)d; - uint8_t tmp[2]; - tmp[0] = t[d_words * sizeof(uint16_t)]; - tmp[1] = 0; - - uint16_t * f = (uint16_t *)tmp; - - sum += ntohs(f[0]); - if (sum > 0xFFFF) { - sum = (sum >> 16) + (sum & 0xFFFF); - } - } - - return ~(sum & 0xFFFF) & 0xFFFF; -} - - static long sock_tcp_connect(sock_t * sock, const struct sockaddr *addr, socklen_t addrlen) { const struct sockaddr_in * dest = (const struct sockaddr_in *)addr; char deststr[16]; @@ -455,8 +539,22 @@ static long sock_tcp_connect(sock_t * sock, const struct sockaddr *addr, socklen free(response); - /* wait for signal that we connected or timed out */ + printf("tcp: waiting for connect to finish\n"); + /* wait for signal that we connected or timed out */ + struct ipv4_packet * data = net_sock_get(sock); + printf("tcp: connect complete\n"); + free(data); + + return 0; +} + +ssize_t sock_tcp_read(fs_node_t *node, off_t offset, size_t size, uint8_t *buffer) { + printf("tcp: read into buffer of %zu bytes\n", size); + return 0; +} +ssize_t sock_tcp_write(fs_node_t *node, off_t offset, size_t size, uint8_t *buffer) { + printf("tcp: write of %zu bytes\n", size); return 0; } @@ -467,7 +565,11 @@ static int tcp_socket(void) { sock->sock_send = sock_tcp_send; sock->sock_close = sock_tcp_close; sock->sock_connect = sock_tcp_connect; - return process_append_fd((process_t *)this_core->current_process, (fs_node_t *)sock); + sock->_fnode.read = sock_tcp_read; + sock->_fnode.write = sock_tcp_write; + int fd = process_append_fd((process_t *)this_core->current_process, (fs_node_t *)sock); + FD_MODE(fd) = 03; + return fd; } long net_ipv4_socket(int type, int protocol) {