Make pointer arithmetic work

This commit is contained in:
Rui Ueyama 2020-09-04 13:39:06 +09:00
parent 863e2b8de2
commit a6bc4ab101
4 changed files with 147 additions and 7 deletions

View File

@ -7,6 +7,7 @@
#include <stdlib.h>
#include <string.h>
typedef struct Type Type;
typedef struct Node Node;
//
@ -86,6 +87,7 @@ typedef enum {
struct Node {
NodeKind kind; // Node kind
Node *next; // Next node
Type *ty; // Type, e.g. int or pointer to int
Token *tok; // Representative token
Node *lhs; // Left-hand side
@ -107,6 +109,25 @@ struct Node {
Function *parse(Token *tok);
//
// type.c
//
typedef enum {
TY_INT,
TY_PTR,
} TypeKind;
struct Type {
TypeKind kind;
Type *base;
};
extern Type *ty_int;
bool is_integer(Type *ty);
void add_type(Node *node);
//
// codegen.c
//

65
parse.c
View File

@ -23,6 +23,7 @@
Obj *locals;
static Node *compound_stmt(Token **rest, Token *tok);
static Node *stmt(Token **rest, Token *tok);
static Node *expr_stmt(Token **rest, Token *tok);
static Node *expr(Token **rest, Token *tok);
static Node *assign(Token **rest, Token *tok);
@ -146,8 +147,10 @@ static Node *compound_stmt(Token **rest, Token *tok) {
Node head = {};
Node *cur = &head;
while (!equal(tok, "}"))
while (!equal(tok, "}")) {
cur = cur->next = stmt(&tok, tok);
add_type(cur);
}
node->body = head.next;
*rest = tok->next;
@ -237,6 +240,62 @@ static Node *relational(Token **rest, Token *tok) {
}
}
// In C, `+` operator is overloaded to perform the pointer arithmetic.
// If p is a pointer, p+n adds not n but sizeof(*p)*n to the value of p,
// so that p+n points to the location n elements (not bytes) ahead of p.
// In other words, we need to scale an integer value before adding to a
// pointer value. This function takes care of the scaling.
static Node *new_add(Node *lhs, Node *rhs, Token *tok) {
add_type(lhs);
add_type(rhs);
// num + num
if (is_integer(lhs->ty) && is_integer(rhs->ty))
return new_binary(ND_ADD, lhs, rhs, tok);
if (lhs->ty->base && rhs->ty->base)
error_tok(tok, "invalid operands");
// Canonicalize `num + ptr` to `ptr + num`.
if (!lhs->ty->base && rhs->ty->base) {
Node *tmp = lhs;
lhs = rhs;
rhs = tmp;
}
// ptr + num
rhs = new_binary(ND_MUL, rhs, new_num(8, tok), tok);
return new_binary(ND_ADD, lhs, rhs, tok);
}
// Like `+`, `-` is overloaded for the pointer type.
static Node *new_sub(Node *lhs, Node *rhs, Token *tok) {
add_type(lhs);
add_type(rhs);
// num - num
if (is_integer(lhs->ty) && is_integer(rhs->ty))
return new_binary(ND_SUB, lhs, rhs, tok);
// ptr - num
if (lhs->ty->base && is_integer(rhs->ty)) {
rhs = new_binary(ND_MUL, rhs, new_num(8, tok), tok);
add_type(rhs);
Node *node = new_binary(ND_SUB, lhs, rhs, tok);
node->ty = lhs->ty;
return node;
}
// ptr - ptr, which returns how many elements are between the two.
if (lhs->ty->base && rhs->ty->base) {
Node *node = new_binary(ND_SUB, lhs, rhs, tok);
node->ty = ty_int;
return new_binary(ND_DIV, node, new_num(8, tok), tok);
}
error_tok(tok, "invalid operands");
}
// add = mul ("+" mul | "-" mul)*
static Node *add(Token **rest, Token *tok) {
Node *node = mul(&tok, tok);
@ -245,12 +304,12 @@ static Node *add(Token **rest, Token *tok) {
Token *start = tok;
if (equal(tok, "+")) {
node = new_binary(ND_ADD, node, mul(&tok, tok->next), start);
node = new_add(node, mul(&tok, tok->next), start);
continue;
}
if (equal(tok, "-")) {
node = new_binary(ND_SUB, node, mul(&tok, tok->next), start);
node = new_sub(node, mul(&tok, tok->next), start);
continue;
}

10
test.sh
View File

@ -76,10 +76,12 @@ assert 10 '{ i=0; while(i<10) { i=i+1; } return i; }'
assert 3 '{ x=3; return *&x; }'
assert 3 '{ x=3; y=&x; z=&y; return **z; }'
assert 5 '{ x=3; y=5; return *(&x+8); }'
assert 3 '{ x=3; y=5; return *(&y-8); }'
assert 5 '{ x=3; y=5; return *(&x+1); }'
assert 3 '{ x=3; y=5; return *(&y-1); }'
assert 5 '{ x=3; y=5; return *(&x-(-1)); }'
assert 5 '{ x=3; y=&x; *y=5; return x; }'
assert 7 '{ x=3; y=5; *(&x+8)=7; return y; }'
assert 7 '{ x=3; y=5; *(&y-8)=7; return x; }'
assert 7 '{ x=3; y=5; *(&x+1)=7; return y; }'
assert 7 '{ x=3; y=5; *(&y-2+1)=7; return x; }'
assert 5 '{ x=3; return (&x+2)-&x+3; }'
echo OK

58
type.c Normal file
View File

@ -0,0 +1,58 @@
#include "chibicc.h"
Type *ty_int = &(Type){TY_INT};
bool is_integer(Type *ty) {
return ty->kind == TY_INT;
}
Type *pointer_to(Type *base) {
Type *ty = calloc(1, sizeof(Type));
ty->kind = TY_PTR;
ty->base = base;
return ty;
}
void add_type(Node *node) {
if (!node || node->ty)
return;
add_type(node->lhs);
add_type(node->rhs);
add_type(node->cond);
add_type(node->then);
add_type(node->els);
add_type(node->init);
add_type(node->inc);
for (Node *n = node->body; n; n = n->next)
add_type(n);
switch (node->kind) {
case ND_ADD:
case ND_SUB:
case ND_MUL:
case ND_DIV:
case ND_NEG:
case ND_ASSIGN:
node->ty = node->lhs->ty;
return;
case ND_EQ:
case ND_NE:
case ND_LT:
case ND_LE:
case ND_VAR:
case ND_NUM:
node->ty = ty_int;
return;
case ND_ADDR:
node->ty = pointer_to(node->lhs->ty);
return;
case ND_DEREF:
if (node->lhs->ty->kind == TY_PTR)
node->ty = node->lhs->ty->base;
else
node->ty = ty_int;
return;
}
}