Add ==, !=, <= and >= operators

This commit is contained in:
Rui Ueyama 2019-08-03 16:15:23 +09:00
parent bf9ab52860
commit 25b4b85b88
2 changed files with 116 additions and 4 deletions

101
main.c
View File

@ -89,6 +89,19 @@ static Token *new_token(TokenKind kind, char *start, char *end) {
return tok;
}
static bool startswith(char *p, char *q) {
return strncmp(p, q, strlen(q)) == 0;
}
// Read a punctuator token from p and returns its length.
static int read_punct(char *p) {
if (startswith(p, "==") || startswith(p, "!=") ||
startswith(p, "<=") || startswith(p, ">="))
return 2;
return ispunct(*p) ? 1 : 0;
}
// Tokenize `current_input` and returns new tokens.
static Token *tokenize(void) {
char *p = current_input;
@ -112,9 +125,10 @@ static Token *tokenize(void) {
}
// Punctuators
if (ispunct(*p)) {
cur = cur->next = new_token(TK_PUNCT, p, p + 1);
p++;
int punct_len = read_punct(p);
if (punct_len) {
cur = cur->next = new_token(TK_PUNCT, p, p + punct_len);
p += cur->len;
continue;
}
@ -135,6 +149,10 @@ typedef enum {
ND_MUL, // *
ND_DIV, // /
ND_NEG, // unary -
ND_EQ, // ==
ND_NE, // !=
ND_LT, // <
ND_LE, // <=
ND_NUM, // Integer
} NodeKind;
@ -173,12 +191,70 @@ static Node *new_num(int val) {
}
static Node *expr(Token **rest, Token *tok);
static Node *equality(Token **rest, Token *tok);
static Node *relational(Token **rest, Token *tok);
static Node *add(Token **rest, Token *tok);
static Node *mul(Token **rest, Token *tok);
static Node *unary(Token **rest, Token *tok);
static Node *primary(Token **rest, Token *tok);
// expr = mul ("+" mul | "-" mul)*
// expr = equality
static Node *expr(Token **rest, Token *tok) {
return equality(rest, tok);
}
// equality = relational ("==" relational | "!=" relational)*
static Node *equality(Token **rest, Token *tok) {
Node *node = relational(&tok, tok);
for (;;) {
if (equal(tok, "==")) {
node = new_binary(ND_EQ, node, relational(&tok, tok->next));
continue;
}
if (equal(tok, "!=")) {
node = new_binary(ND_NE, node, relational(&tok, tok->next));
continue;
}
*rest = tok;
return node;
}
}
// relational = add ("<" add | "<=" add | ">" add | ">=" add)*
static Node *relational(Token **rest, Token *tok) {
Node *node = add(&tok, tok);
for (;;) {
if (equal(tok, "<")) {
node = new_binary(ND_LT, node, add(&tok, tok->next));
continue;
}
if (equal(tok, "<=")) {
node = new_binary(ND_LE, node, add(&tok, tok->next));
continue;
}
if (equal(tok, ">")) {
node = new_binary(ND_LT, add(&tok, tok->next), node);
continue;
}
if (equal(tok, ">=")) {
node = new_binary(ND_LE, add(&tok, tok->next), node);
continue;
}
*rest = tok;
return node;
}
}
// add = mul ("+" mul | "-" mul)*
static Node *add(Token **rest, Token *tok) {
Node *node = mul(&tok, tok);
for (;;) {
@ -292,6 +368,23 @@ static void gen_expr(Node *node) {
printf(" cqo\n");
printf(" idiv %%rdi\n");
return;
case ND_EQ:
case ND_NE:
case ND_LT:
case ND_LE:
printf(" cmp %%rdi, %%rax\n");
if (node->kind == ND_EQ)
printf(" sete %%al\n");
else if (node->kind == ND_NE)
printf(" setne %%al\n");
else if (node->kind == ND_LT)
printf(" setl %%al\n");
else if (node->kind == ND_LE)
printf(" setle %%al\n");
printf(" movzb %%al, %%rax\n");
return;
}
error("invalid expression");

19
test.sh
View File

@ -27,4 +27,23 @@ assert 10 '-10+20'
assert 10 '- -10'
assert 10 '- - +10'
assert 0 '0==1'
assert 1 '42==42'
assert 1 '0!=1'
assert 0 '42!=42'
assert 1 '0<1'
assert 0 '1<1'
assert 0 '2<1'
assert 1 '0<=1'
assert 1 '1<=1'
assert 0 '2<=1'
assert 1 '1>0'
assert 0 '1>1'
assert 0 '1>2'
assert 1 '1>=0'
assert 1 '1>=1'
assert 0 '1>=2'
echo OK