Add +=, -=, *= and /=

This commit is contained in:
Rui Ueyama 2020-10-07 20:17:30 +09:00
parent a4fea2ba3e
commit 01a94c04aa
3 changed files with 52 additions and 2 deletions

41
parse.c
View File

@ -80,6 +80,8 @@ static Node *assign(Token **rest, Token *tok);
static Node *equality(Token **rest, Token *tok);
static Node *relational(Token **rest, Token *tok);
static Node *add(Token **rest, Token *tok);
static Node *new_add(Node *lhs, Node *rhs, Token *tok);
static Node *new_sub(Node *lhs, Node *rhs, Token *tok);
static Node *mul(Token **rest, Token *tok);
static Node *cast(Token **rest, Token *tok);
static Type *struct_decl(Token **rest, Token *tok);
@ -670,13 +672,50 @@ static Node *expr(Token **rest, Token *tok) {
return node;
}
// assign = equality ("=" assign)?
// Convert `A op= B` to `tmp = &A, *tmp = *tmp op B`
// where tmp is a fresh pointer variable.
static Node *to_assign(Node *binary) {
add_type(binary->lhs);
add_type(binary->rhs);
Token *tok = binary->tok;
Obj *var = new_lvar("", pointer_to(binary->lhs->ty));
Node *expr1 = new_binary(ND_ASSIGN, new_var_node(var, tok),
new_unary(ND_ADDR, binary->lhs, tok), tok);
Node *expr2 =
new_binary(ND_ASSIGN,
new_unary(ND_DEREF, new_var_node(var, tok), tok),
new_binary(binary->kind,
new_unary(ND_DEREF, new_var_node(var, tok), tok),
binary->rhs,
tok),
tok);
return new_binary(ND_COMMA, expr1, expr2, tok);
}
// assign = equality (assign-op assign)?
// assign-op = "=" | "+=" | "-=" | "*=" | "/="
static Node *assign(Token **rest, Token *tok) {
Node *node = equality(&tok, tok);
if (equal(tok, "="))
return new_binary(ND_ASSIGN, node, assign(rest, tok->next), tok);
if (equal(tok, "+="))
return to_assign(new_add(node, assign(rest, tok->next), tok));
if (equal(tok, "-="))
return to_assign(new_sub(node, assign(rest, tok->next), tok));
if (equal(tok, "*="))
return to_assign(new_binary(ND_MUL, node, assign(rest, tok->next), tok));
if (equal(tok, "/="))
return to_assign(new_binary(ND_DIV, node, assign(rest, tok->next), tok));
*rest = tok;
return node;
}

View File

@ -33,6 +33,15 @@ int main() {
ASSERT(0, 1073741824 * 100 / 100);
ASSERT(7, ({ int i=2; i+=5; i; }));
ASSERT(7, ({ int i=2; i+=5; }));
ASSERT(3, ({ int i=5; i-=2; i; }));
ASSERT(3, ({ int i=5; i-=2; }));
ASSERT(6, ({ int i=3; i*=2; i; }));
ASSERT(6, ({ int i=3; i*=2; }));
ASSERT(3, ({ int i=6; i/=2; i; }));
ASSERT(3, ({ int i=6; i/=2; }));
printf("OK\n");
return 0;
}

View File

@ -115,7 +115,9 @@ static int from_hex(char c) {
// Read a punctuator token from p and returns its length.
static int read_punct(char *p) {
static char *kw[] = {"==", "!=", "<=", ">=", "->"};
static char *kw[] = {
"==", "!=", "<=", ">=", "->", "+=", "-=", "*=", "/=",
};
for (int i = 0; i < sizeof(kw) / sizeof(*kw); i++)
if (startswith(p, kw[i]))