diff --git a/chibicc.h b/chibicc.h index b22223b..3d59afb 100644 --- a/chibicc.h +++ b/chibicc.h @@ -11,7 +11,8 @@ // typedef enum { - TK_PUNCT, // Keywords or punctuators + TK_IDENT, // Identifiers + TK_PUNCT, // Punctuators TK_NUM, // Numeric literals TK_EOF, // End-of-file markers } TokenKind; @@ -47,7 +48,9 @@ typedef enum { ND_NE, // != ND_LT, // < ND_LE, // <= + ND_ASSIGN, // = ND_EXPR_STMT, // Expression statement + ND_VAR, // Variable ND_NUM, // Integer } NodeKind; @@ -58,6 +61,7 @@ struct Node { Node *next; // Next node Node *lhs; // Left-hand side Node *rhs; // Right-hand side + char name; // Used if kind == ND_VAR int val; // Used if kind == ND_NUM }; diff --git a/codegen.c b/codegen.c index 1fd9542..adc4924 100644 --- a/codegen.c +++ b/codegen.c @@ -12,6 +12,19 @@ static void pop(char *arg) { depth--; } +// Compute the absolute address of a given node. +// It's an error if a given node does not reside in memory. +static void gen_addr(Node *node) { + if (node->kind == ND_VAR) { + int offset = (node->name - 'a' + 1) * 8; + printf(" lea %d(%%rbp), %%rax\n", -offset); + return; + } + + error("not an lvalue"); +} + +// Generate code for a given node. static void gen_expr(Node *node) { switch (node->kind) { case ND_NUM: @@ -21,6 +34,17 @@ static void gen_expr(Node *node) { gen_expr(node->lhs); printf(" neg %%rax\n"); return; + case ND_VAR: + gen_addr(node); + printf(" mov (%%rax), %%rax\n"); + return; + case ND_ASSIGN: + gen_addr(node->lhs); + push(); + gen_expr(node->rhs); + pop("%rdi"); + printf(" mov %%rax, (%%rdi)\n"); + return; } gen_expr(node->rhs); @@ -77,10 +101,17 @@ void codegen(Node *node) { printf(" .globl main\n"); printf("main:\n"); + // Prologue + printf(" push %%rbp\n"); + printf(" mov %%rsp, %%rbp\n"); + printf(" sub $208, %%rsp\n"); + for (Node *n = node; n; n = n->next) { gen_stmt(n); assert(depth == 0); } + printf(" mov %%rbp, %%rsp\n"); + printf(" pop %%rbp\n"); printf(" ret\n"); } diff --git a/parse.c b/parse.c index 012f169..f9b8e58 100644 --- a/parse.c +++ b/parse.c @@ -2,6 +2,7 @@ static Node *expr(Token **rest, Token *tok); static Node *expr_stmt(Token **rest, Token *tok); +static Node *assign(Token **rest, Token *tok); static Node *equality(Token **rest, Token *tok); static Node *relational(Token **rest, Token *tok); static Node *add(Token **rest, Token *tok); @@ -34,6 +35,12 @@ static Node *new_num(int val) { return node; } +static Node *new_var_node(char name) { + Node *node = new_node(ND_VAR); + node->name = name; + return node; +} + // stmt = expr-stmt static Node *stmt(Token **rest, Token *tok) { return expr_stmt(rest, tok); @@ -46,9 +53,18 @@ static Node *expr_stmt(Token **rest, Token *tok) { return node; } -// expr = equality +// expr = assign static Node *expr(Token **rest, Token *tok) { - return equality(rest, tok); + return assign(rest, tok); +} + +// assign = equality ("=" assign)? +static Node *assign(Token **rest, Token *tok) { + Node *node = equality(&tok, tok); + if (equal(tok, "=")) + node = new_binary(ND_ASSIGN, node, assign(&tok, tok->next)); + *rest = tok; + return node; } // equality = relational ("==" relational | "!=" relational)* @@ -153,7 +169,7 @@ static Node *unary(Token **rest, Token *tok) { return primary(rest, tok); } -// primary = "(" expr ")" | num +// primary = "(" expr ")" | ident | num static Node *primary(Token **rest, Token *tok) { if (equal(tok, "(")) { Node *node = expr(&tok, tok->next); @@ -161,6 +177,12 @@ static Node *primary(Token **rest, Token *tok) { return node; } + if (tok->kind == TK_IDENT) { + Node *node = new_var_node(*tok->loc); + *rest = tok->next; + return node; + } + if (tok->kind == TK_NUM) { Node *node = new_num(tok->val); *rest = tok->next; diff --git a/test.sh b/test.sh index a15edd3..fb6888a 100755 --- a/test.sh +++ b/test.sh @@ -48,4 +48,8 @@ assert 0 '1>=2;' assert 3 '1; 2; 3;' +assert 3 'a=3; a;' +assert 8 'a=3; z=5; a+z;' +assert 6 'a=b=3; a+b;' + echo OK diff --git a/tokenize.c b/tokenize.c index cc05255..dfecaf6 100644 --- a/tokenize.c +++ b/tokenize.c @@ -91,6 +91,13 @@ Token *tokenize(char *p) { continue; } + // Identifier + if ('a' <= *p && *p <= 'z') { + cur = cur->next = new_token(TK_IDENT, p, p + 1); + p++; + continue; + } + // Punctuators int punct_len = read_punct(p); if (punct_len) {