Add constant expression

This commit is contained in:
Rui Ueyama 2019-08-17 10:38:20 +09:00
parent 447ee098c5
commit 79f5de21eb
2 changed files with 117 additions and 16 deletions

93
parse.c
View File

@ -91,6 +91,7 @@ static Node *expr_stmt(Token **rest, Token *tok);
static Node *expr(Token **rest, Token *tok);
static Node *assign(Token **rest, Token *tok);
static Node *logor(Token **rest, Token *tok);
static int64_t const_expr(Token **rest, Token *tok);
static Node *conditional(Token **rest, Token *tok);
static Node *logand(Token **rest, Token *tok);
static Node *bitor(Token **rest, Token *tok);
@ -249,12 +250,6 @@ static Type *find_typedef(Token *tok) {
return NULL;
}
static long get_number(Token *tok) {
if (tok->kind != TK_NUM)
error_tok(tok, "expected a number");
return tok->val;
}
static void push_tag_scope(Token *tok, Type *ty) {
TagScope *sc = calloc(1, sizeof(TagScope));
sc->name = strndup(tok->loc, tok->len);
@ -415,15 +410,15 @@ static Type *func_params(Token **rest, Token *tok, Type *ty) {
return ty;
}
// array-dimensions = num? "]" type-suffix
// array-dimensions = const-expr? "]" type-suffix
static Type *array_dimensions(Token **rest, Token *tok, Type *ty) {
if (equal(tok, "]")) {
ty = type_suffix(rest, tok->next, ty);
return array_of(ty, -1);
}
int sz = get_number(tok);
tok = skip(tok->next, "]");
int sz = const_expr(&tok, tok);
tok = skip(tok, "]");
ty = type_suffix(rest, tok, ty);
return array_of(ty, sz);
}
@ -524,10 +519,8 @@ static Type *enum_specifier(Token **rest, Token *tok) {
char *name = get_ident(tok);
tok = tok->next;
if (equal(tok, "=")) {
val = get_number(tok->next);
tok = tok->next->next;
}
if (equal(tok, "="))
val = const_expr(&tok, tok->next);
VarScope *sc = push_scope(name);
sc->enum_ty = ty;
@ -590,7 +583,7 @@ static bool is_typename(Token *tok) {
// stmt = "return" expr ";"
// | "if" "(" expr ")" stmt ("else" stmt)?
// | "switch" "(" expr ")" stmt
// | "case" num ":" stmt
// | "case" const-expr ":" stmt
// | "default" ":" stmt
// | "for" "(" expr-stmt expr? ";" expr? ")" stmt
// | "while" "(" expr ")" stmt
@ -645,10 +638,10 @@ static Node *stmt(Token **rest, Token *tok) {
if (equal(tok, "case")) {
if (!current_switch)
error_tok(tok, "stray case");
int val = get_number(tok->next);
Node *node = new_node(ND_CASE, tok);
tok = skip(tok->next->next, ":");
int val = const_expr(&tok, tok->next);
tok = skip(tok, ":");
node->label = new_unique_name();
node->lhs = stmt(rest, tok);
node->val = val;
@ -820,6 +813,74 @@ static Node *expr(Token **rest, Token *tok) {
return node;
}
// Evaluate a given node as a constant expression.
static int64_t eval(Node *node) {
add_type(node);
switch (node->kind) {
case ND_ADD:
return eval(node->lhs) + eval(node->rhs);
case ND_SUB:
return eval(node->lhs) - eval(node->rhs);
case ND_MUL:
return eval(node->lhs) * eval(node->rhs);
case ND_DIV:
return eval(node->lhs) / eval(node->rhs);
case ND_NEG:
return -eval(node->lhs);
case ND_MOD:
return eval(node->lhs) % eval(node->rhs);
case ND_BITAND:
return eval(node->lhs) & eval(node->rhs);
case ND_BITOR:
return eval(node->lhs) | eval(node->rhs);
case ND_BITXOR:
return eval(node->lhs) ^ eval(node->rhs);
case ND_SHL:
return eval(node->lhs) << eval(node->rhs);
case ND_SHR:
return eval(node->lhs) >> eval(node->rhs);
case ND_EQ:
return eval(node->lhs) == eval(node->rhs);
case ND_NE:
return eval(node->lhs) != eval(node->rhs);
case ND_LT:
return eval(node->lhs) < eval(node->rhs);
case ND_LE:
return eval(node->lhs) <= eval(node->rhs);
case ND_COND:
return eval(node->cond) ? eval(node->then) : eval(node->els);
case ND_COMMA:
return eval(node->rhs);
case ND_NOT:
return !eval(node->lhs);
case ND_BITNOT:
return ~eval(node->lhs);
case ND_LOGAND:
return eval(node->lhs) && eval(node->rhs);
case ND_LOGOR:
return eval(node->lhs) || eval(node->rhs);
case ND_CAST:
if (is_integer(node->ty)) {
switch (node->ty->size) {
case 1: return (uint8_t)eval(node->lhs);
case 2: return (uint16_t)eval(node->lhs);
case 4: return (uint32_t)eval(node->lhs);
}
}
return eval(node->lhs);
case ND_NUM:
return node->val;
}
error_tok(node->tok, "not a compile-time constant");
}
static int64_t const_expr(Token **rest, Token *tok) {
Node *node = conditional(rest, tok);
return eval(node);
}
// Convert `A op= B` to `tmp = &A, *tmp = *tmp op B`
// where tmp is a fresh pointer variable.
static Node *to_assign(Node *binary) {

40
test/constexpr.c Normal file
View File

@ -0,0 +1,40 @@
#include "test.h"
int main() {
ASSERT(10, ({ enum { ten=1+2+3+4 }; ten; }));
ASSERT(1, ({ int i=0; switch(3) { case 5-2+0*3: i++; } i; }));
ASSERT(8, ({ int x[1+1]; sizeof(x); }));
ASSERT(6, ({ char x[8-2]; sizeof(x); }));
ASSERT(6, ({ char x[2*3]; sizeof(x); }));
ASSERT(3, ({ char x[12/4]; sizeof(x); }));
ASSERT(2, ({ char x[12%10]; sizeof(x); }));
ASSERT(0b100, ({ char x[0b110&0b101]; sizeof(x); }));
ASSERT(0b111, ({ char x[0b110|0b101]; sizeof(x); }));
ASSERT(0b110, ({ char x[0b111^0b001]; sizeof(x); }));
ASSERT(4, ({ char x[1<<2]; sizeof(x); }));
ASSERT(2, ({ char x[4>>1]; sizeof(x); }));
ASSERT(2, ({ char x[(1==1)+1]; sizeof(x); }));
ASSERT(1, ({ char x[(1!=1)+1]; sizeof(x); }));
ASSERT(1, ({ char x[(1<1)+1]; sizeof(x); }));
ASSERT(2, ({ char x[(1<=1)+1]; sizeof(x); }));
ASSERT(2, ({ char x[1?2:3]; sizeof(x); }));
ASSERT(3, ({ char x[0?2:3]; sizeof(x); }));
ASSERT(3, ({ char x[(1,3)]; sizeof(x); }));
ASSERT(2, ({ char x[!0+1]; sizeof(x); }));
ASSERT(1, ({ char x[!1+1]; sizeof(x); }));
ASSERT(2, ({ char x[~-3]; sizeof(x); }));
ASSERT(2, ({ char x[(5||6)+1]; sizeof(x); }));
ASSERT(1, ({ char x[(0||0)+1]; sizeof(x); }));
ASSERT(2, ({ char x[(1&&1)+1]; sizeof(x); }));
ASSERT(1, ({ char x[(1&&0)+1]; sizeof(x); }));
ASSERT(3, ({ char x[(int)3]; sizeof(x); }));
ASSERT(15, ({ char x[(char)0xffffff0f]; sizeof(x); }));
ASSERT(0x10f, ({ char x[(short)0xffff010f]; sizeof(x); }));
ASSERT(4, ({ char x[(int)0xfffffffffff+5]; sizeof(x); }));
ASSERT(8, ({ char x[(int*)0+2]; sizeof(x); }));
ASSERT(12, ({ char x[(int*)16-1]; sizeof(x); }));
ASSERT(3, ({ char x[(int*)16-(int*)4]; sizeof(x); }));
printf("OK\n");
return 0;
}