net: try to actually track received packet sizes...

This commit is contained in:
K. Lange 2021-10-25 18:27:08 +09:00
parent 073e43384c
commit 49342b7996
6 changed files with 68 additions and 41 deletions

View File

@ -10,7 +10,7 @@
#define MAC_FORMAT "%02x:%02x:%02x:%02x:%02x:%02x"
#define FORMAT_MAC(m) (m)[0], (m)[1], (m)[2], (m)[3], (m)[4], (m)[5]
void net_eth_handle(struct ethernet_packet * frame, fs_node_t * nic);
void net_eth_handle(struct ethernet_packet * frame, fs_node_t * nic, size_t size);
struct EthernetDevice {
char if_name[32];

View File

@ -27,21 +27,27 @@ struct ethernet_packet {
extern spin_lock_t net_raw_sockets_lock;
extern list_t * net_raw_sockets_list;
extern void net_ipv4_handle(void * packet, fs_node_t * nic);
extern void net_ipv4_handle(void * packet, fs_node_t * nic, size_t);
extern void net_arp_handle(void * packet, fs_node_t * nic);
void net_eth_handle(struct ethernet_packet * frame, fs_node_t * nic) {
void net_eth_handle(struct ethernet_packet * frame, fs_node_t * nic, size_t size) {
struct EthernetDevice * nic_eth = nic->device;
if (size < sizeof(struct ethernet_packet)) {
dprintf("eth: %s: invalid ethernet frame (too small)\n",
nic_eth->if_name);
return;
}
spin_lock(net_raw_sockets_lock);
foreach(node, net_raw_sockets_list) {
sock_t * sock = node->value;
if (!sock->_fnode.device || sock->_fnode.device == nic) {
net_sock_add(sock, frame, 4096);
net_sock_add(sock, frame, size);
}
}
spin_unlock(net_raw_sockets_lock);
struct EthernetDevice * nic_eth = nic->device;
if (!memcmp(frame->destination, nic_eth->mac, 6) || !memcmp(frame->destination, ETHERNET_BROADCAST_MAC, 6)) {
/* Now pass the frame to the appropriate handler... */
switch (ntohs(frame->type)) {
@ -54,7 +60,7 @@ void net_eth_handle(struct ethernet_packet * frame, fs_node_t * nic) {
if (packet->source != 0xFFFFFFFF) {
net_arp_cache_add(nic->device, packet->source, frame->source, 0);
}
net_ipv4_handle(packet, nic);
net_ipv4_handle(packet, nic, size - sizeof(struct ethernet_packet));
break;
}
}

View File

@ -15,6 +15,7 @@
#include <kernel/vfs.h>
#include <kernel/time.h>
#include <kernel/misc.h>
#include <kernel/assert.h>
#include <kernel/net/netif.h>
#include <kernel/net/eth.h>
@ -158,7 +159,9 @@ static void icmp_handle(struct ipv4_packet * packet, const char * src, const cha
/* Is this a PING request? */
if (header->type == 8 && header->code == 0) {
printf("net: ping with %d bytes of payload\n", ntohs(packet->length));
if (ntohs(packet->length) & 1) packet->length++;
if (ntohs(packet->length) & 1) {
packet->length = htons(ntohs(packet->length) + 1);
}
struct ipv4_packet * response = malloc(ntohs(packet->length));
memcpy(response, packet, ntohs(packet->length));
@ -203,15 +206,19 @@ static long sock_icmp_recv(sock_t * sock, struct msghdr * msg, int flags) {
}
if (msg->msg_iovlen == 0) return 0;
struct ipv4_packet * data = net_sock_get(sock);
if (!data) {
return -EINTR;
char * packet = net_sock_get(sock);
if (!packet) return -EINTR;
size_t packet_size = *(size_t*)packet;
struct ipv4_packet * data = (struct ipv4_packet*)(packet + sizeof(size_t));
if (packet_size > msg->msg_iov[0].iov_len) {
dprintf("ICMP recv too big for vector\n");
packet_size = msg->msg_iov[0].iov_len;
}
long resp = ntohs(data->length);
memcpy(msg->msg_iov[0].iov_base, data, resp);
free(data);
return resp;
memcpy(msg->msg_iov[0].iov_base, data, packet_size);
free(packet);
return packet_size;
}
static long sock_icmp_send(sock_t * sock, const struct msghdr *msg, int flags) {
@ -372,7 +379,11 @@ static int tcp_ack(fs_node_t * nic, sock_t * sock, struct ipv4_packet * packet,
return retval;
}
void net_ipv4_handle(struct ipv4_packet * packet, fs_node_t * nic) {
void net_ipv4_handle(struct ipv4_packet * packet, fs_node_t * nic, size_t size) {
if (size < sizeof(struct ipv4_packet)) {
dprintf("ipv4: Incoming packet is too small.\n");
}
char dest[16];
char src[16];
@ -514,12 +525,9 @@ static long sock_udp_recv(sock_t * sock, struct msghdr * msg, int flags) {
}
if (msg->msg_iovlen == 0) return 0;
struct ipv4_packet * data = net_sock_get(sock);
if (!data) {
/* We need to figure out why, but that's... complicated for now. */
return -EINTR;
}
char * packet = net_sock_get(sock);
if (!packet) return -EINTR;
struct ipv4_packet * data = (struct ipv4_packet*)(packet + sizeof(size_t));
printf("udp: got response, size is %u - sizeof(ipv4) - sizeof(udp) = %lu\n",
ntohs(data->length), ntohs(data->length) - sizeof(struct ipv4_packet) - sizeof(struct udp_packet));
@ -528,7 +536,7 @@ static long sock_udp_recv(sock_t * sock, struct msghdr * msg, int flags) {
printf("udp: data copied to iov 0, return length?\n");
long resp = ntohs(data->length) - sizeof(struct ipv4_packet) - sizeof(struct udp_packet);
free(data);
free(packet);
return resp;
}
@ -674,11 +682,23 @@ static long sock_tcp_recv(sock_t * sock, struct msghdr * msg, int flags) {
process_wait_nodes((process_t *)this_core->current_process, (fs_node_t*[]){(fs_node_t*)sock,NULL}, 200);
}
struct ipv4_packet * data = net_sock_get(sock);
if (!data) {
return -EINTR;
char * packet = net_sock_get(sock);
if (!packet) return -EINTR;
struct ipv4_packet * data = (struct ipv4_packet*)(packet + sizeof(size_t));
size_t packet_size = *(size_t*)packet;
unsigned long resp = ntohs(data->length);
if (resp != packet_size) {
dprintf("packet size does not match: %zu %zu\n", resp, packet_size);
resp = packet_size;
}
unsigned long resp = ntohs(data->length) - sizeof(struct ipv4_packet) - sizeof(struct tcp_header);
if (resp < sizeof(struct ipv4_packet) + sizeof(struct tcp_header)) {
dprintf("Invalid receive data?\n");
assert(0);
}
resp -= sizeof(struct ipv4_packet) + sizeof(struct tcp_header);
if (resp > (unsigned long)msg->msg_iov[0].iov_len) {
memcpy(msg->msg_iov[0].iov_base, data->payload + sizeof(struct tcp_header),msg->msg_iov[0].iov_len);
@ -687,12 +707,12 @@ static long sock_tcp_recv(sock_t * sock, struct msghdr * msg, int flags) {
sock->unread = resp;
sock->buf = malloc(resp);
memcpy(sock->buf, data->payload + sizeof(struct tcp_header) + msg->msg_iov[0].iov_len, resp);
free(data);
free(packet);
return msg->msg_iov[0].iov_len;
}
memcpy(msg->msg_iov[0].iov_base, data->payload + sizeof(struct tcp_header), resp);
free(data);
free(packet);
return resp;
}
@ -793,12 +813,11 @@ static long sock_tcp_connect(sock_t * sock, const struct sockaddr *addr, socklen
printf("tcp: queue should have data now (len = %lu), trying to read\n", sock->rx_queue->length);
/* wait for signal that we connected or timed out */
struct ipv4_packet * data = net_sock_get(sock);
if (!data) {
return -EINTR;
}
char * packet = net_sock_get(sock);
if (!packet) return -EINTR;
//struct ipv4_packet * data = packet + sizeof(size_t);
printf("tcp: connect complete\n");
free(data);
free(packet);
return 0;
}

View File

@ -92,7 +92,7 @@ static ssize_t write_loop(fs_node_t *node, off_t offset, size_t size, uint8_t *b
nic->counts.rx_bytes += size;
nic->counts.tx_bytes += size;
net_eth_handle((void*)buffer, node);
net_eth_handle((void*)buffer, node, size);
return size;
}

View File

@ -46,8 +46,9 @@ void net_sock_alert(sock_t * sock) {
void net_sock_add(sock_t * sock, void * frame, size_t size) {
spin_lock(sock->rx_lock);
char * bleh = malloc(size);
memcpy(bleh, frame, size);
char * bleh = malloc(size + sizeof(size_t));
*(size_t*)bleh = size;
memcpy(bleh + sizeof(size_t), frame, size);
list_insert(sock->rx_queue, bleh);
wakeup_queue(sock->rx_wait);
net_sock_alert(sock);
@ -126,10 +127,11 @@ static long sock_raw_recv(sock_t * sock, struct msghdr * msg, int flags) {
return -ENOTSUP;
}
if (msg->msg_iovlen == 0) return 0;
if (msg->msg_iov[0].iov_len != 4096) return -EINVAL;
void * data = net_sock_get(sock);
char * data = net_sock_get(sock);
if (!data) return -EINTR;
memcpy(msg->msg_iov[0].iov_base, data, 4096);
size_t packet_size = *(size_t*)data;
if (msg->msg_iov[0].iov_len < packet_size) return -EINVAL;
memcpy(msg->msg_iov[0].iov_base, data + sizeof(size_t), packet_size);
free(data);
return 4096;
}

View File

@ -175,7 +175,7 @@ static void e1000_queuer(void * data) {
if (!(nic->rx[i].errors & (0x97))) {
nic->counts.rx_count++;
nic->counts.rx_bytes += nic->rx[i].length;
net_eth_handle((void*)nic->rx_virt[i], nic->eth.device_node);
net_eth_handle((void*)nic->rx_virt[i], nic->eth.device_node, nic->rx[i].length);
} else {
printf("error bits set in packet: %x\n", nic->rx[i].errors);
}