diff --git a/main.c b/main.c index ae300a0..0be4717 100644 --- a/main.c +++ b/main.c @@ -89,6 +89,19 @@ static Token *new_token(TokenKind kind, char *start, char *end) { return tok; } +static bool startswith(char *p, char *q) { + return strncmp(p, q, strlen(q)) == 0; +} + +// Read a punctuator token from p and returns its length. +static int read_punct(char *p) { + if (startswith(p, "==") || startswith(p, "!=") || + startswith(p, "<=") || startswith(p, ">=")) + return 2; + + return ispunct(*p) ? 1 : 0; +} + // Tokenize `current_input` and returns new tokens. static Token *tokenize(void) { char *p = current_input; @@ -112,9 +125,10 @@ static Token *tokenize(void) { } // Punctuators - if (ispunct(*p)) { - cur = cur->next = new_token(TK_PUNCT, p, p + 1); - p++; + int punct_len = read_punct(p); + if (punct_len) { + cur = cur->next = new_token(TK_PUNCT, p, p + punct_len); + p += cur->len; continue; } @@ -135,6 +149,10 @@ typedef enum { ND_MUL, // * ND_DIV, // / ND_NEG, // unary - + ND_EQ, // == + ND_NE, // != + ND_LT, // < + ND_LE, // <= ND_NUM, // Integer } NodeKind; @@ -173,12 +191,70 @@ static Node *new_num(int val) { } static Node *expr(Token **rest, Token *tok); +static Node *equality(Token **rest, Token *tok); +static Node *relational(Token **rest, Token *tok); +static Node *add(Token **rest, Token *tok); static Node *mul(Token **rest, Token *tok); static Node *unary(Token **rest, Token *tok); static Node *primary(Token **rest, Token *tok); -// expr = mul ("+" mul | "-" mul)* +// expr = equality static Node *expr(Token **rest, Token *tok) { + return equality(rest, tok); +} + +// equality = relational ("==" relational | "!=" relational)* +static Node *equality(Token **rest, Token *tok) { + Node *node = relational(&tok, tok); + + for (;;) { + if (equal(tok, "==")) { + node = new_binary(ND_EQ, node, relational(&tok, tok->next)); + continue; + } + + if (equal(tok, "!=")) { + node = new_binary(ND_NE, node, relational(&tok, tok->next)); + continue; + } + + *rest = tok; + return node; + } +} + +// relational = add ("<" add | "<=" add | ">" add | ">=" add)* +static Node *relational(Token **rest, Token *tok) { + Node *node = add(&tok, tok); + + for (;;) { + if (equal(tok, "<")) { + node = new_binary(ND_LT, node, add(&tok, tok->next)); + continue; + } + + if (equal(tok, "<=")) { + node = new_binary(ND_LE, node, add(&tok, tok->next)); + continue; + } + + if (equal(tok, ">")) { + node = new_binary(ND_LT, add(&tok, tok->next), node); + continue; + } + + if (equal(tok, ">=")) { + node = new_binary(ND_LE, add(&tok, tok->next), node); + continue; + } + + *rest = tok; + return node; + } +} + +// add = mul ("+" mul | "-" mul)* +static Node *add(Token **rest, Token *tok) { Node *node = mul(&tok, tok); for (;;) { @@ -292,6 +368,23 @@ static void gen_expr(Node *node) { printf(" cqo\n"); printf(" idiv %%rdi\n"); return; + case ND_EQ: + case ND_NE: + case ND_LT: + case ND_LE: + printf(" cmp %%rdi, %%rax\n"); + + if (node->kind == ND_EQ) + printf(" sete %%al\n"); + else if (node->kind == ND_NE) + printf(" setne %%al\n"); + else if (node->kind == ND_LT) + printf(" setl %%al\n"); + else if (node->kind == ND_LE) + printf(" setle %%al\n"); + + printf(" movzb %%al, %%rax\n"); + return; } error("invalid expression"); diff --git a/test.sh b/test.sh index 3011b04..d15a4b9 100755 --- a/test.sh +++ b/test.sh @@ -27,4 +27,23 @@ assert 10 '-10+20' assert 10 '- -10' assert 10 '- - +10' +assert 0 '0==1' +assert 1 '42==42' +assert 1 '0!=1' +assert 0 '42!=42' + +assert 1 '0<1' +assert 0 '1<1' +assert 0 '2<1' +assert 1 '0<=1' +assert 1 '1<=1' +assert 0 '2<=1' + +assert 1 '1>0' +assert 0 '1>1' +assert 0 '1>2' +assert 1 '1>=0' +assert 1 '1>=1' +assert 0 '1>=2' + echo OK