#include "chibicc.h"

// Stream that all generated assembly is written to (set by codegen()).
static FILE *output_file;

// Number of 8-byte slots currently pushed by push()/pushf() and not yet
// popped. Function calls use its parity to keep %rsp 16-byte aligned.
static int depth;

// Integer argument registers in calling-convention order, by operand width.
static char *argreg8[] = {"%dil", "%sil", "%dl", "%cl", "%r8b", "%r9b"};
static char *argreg16[] = {"%di", "%si", "%dx", "%cx", "%r8w", "%r9w"};
static char *argreg32[] = {"%edi", "%esi", "%edx", "%ecx", "%r8d", "%r9d"};
static char *argreg64[] = {"%rdi", "%rsi", "%rdx", "%rcx", "%r8", "%r9"};

// The function whose body is currently being emitted.
static Obj *current_fn;

static void gen_expr(Node *node);
static void gen_stmt(Node *node);

// printf-style helper that writes one line of assembly to output_file.
static void println(char *fmt, ...) {
  va_list ap;
  va_start(ap, fmt);
  vfprintf(output_file, fmt, ap);
  va_end(ap);
  fprintf(output_file, "\n");
}

// Return a fresh number (1, 2, 3, ...) used to make local labels unique.
static int count(void) {
  static int i = 1;
  return i++;
}

// Push %rax onto the stack.
static void push(void) {
  println(" push %%rax");
  depth++;
}

// Pop the stack top into register `arg` (e.g. "%rdi").
static void pop(char *arg) {
  println(" pop %s", arg);
  depth--;
}

// Push the low 8 bytes of %xmm0 onto the stack.
static void pushf(void) {
  println(" sub $8, %%rsp");
  println(" movsd %%xmm0, (%%rsp)");
  depth++;
}

// Pop the stack top into %xmm<reg>.
static void popf(int reg) {
  println(" movsd (%%rsp), %%xmm%d", reg);
  println(" add $8, %%rsp");
  depth--;
}

// Round up `n` to the nearest multiple of `align`. For instance,
// align_to(5, 8) returns 8 and align_to(11, 8) returns 16.
int align_to(int n, int align) {
  return (n + align - 1) / align * align;
}

// Compute the absolute address of a given node and leave it in %rax.
// It's an error if a given node does not reside in memory.
static void gen_addr(Node *node) {
  switch (node->kind) {
  case ND_VAR:
    // Local variable
    if (node->var->is_local) {
      println(" lea %d(%%rbp), %%rax", node->var->offset);
      return;
    }

    // Here, we generate an absolute address of a function or a global
    // variable. Even though they exist at a certain address at runtime,
    // their addresses are not known at link-time for the following
    // two reasons.
    //
    //  - Address randomization: Executables are loaded to memory as a
    //    whole but it is not known what address they are loaded to.
    //    Therefore, at link-time, relative address in the same
    //    executable (i.e. the distance between two functions in the
    //    same executable) is known, but the absolute address is not
    //    known.
    //
    //  - Dynamic linking: Dynamic shared objects (DSOs) or .so files
    //    are loaded to memory alongside an executable at runtime and
    //    linked by the runtime loader in memory. We know nothing
    //    about addresses of global stuff that may be defined by DSOs
    //    until the runtime relocation is complete.
    //
    // In order to deal with the former case, we use RIP-relative
    // addressing, denoted by `(%rip)`. For the latter, we obtain the
    // address of a symbol that may live in a shared object file from the
    // Global Offset Table using `@GOTPCREL(%rip)` notation.

    // Function
    if (node->ty->kind == TY_FUNC) {
      if (node->var->is_definition)
        println(" lea %s(%%rip), %%rax", node->var->name);
      else
        println(" mov %s@GOTPCREL(%%rip), %%rax", node->var->name);
      return;
    }

    // Global variable
    println(" lea %s(%%rip), %%rax", node->var->name);
    return;
  case ND_DEREF:
    gen_expr(node->lhs);
    return;
  case ND_COMMA:
    gen_expr(node->lhs);
    gen_addr(node->rhs);
    return;
  case ND_MEMBER:
    gen_addr(node->lhs);
    println(" add $%d, %%rax", node->member->offset);
    return;
  }

  error_tok(node->tok, "not an lvalue");
}

// Load a value from where %rax is pointing to. Floating-point values are
// loaded into %xmm0; everything else into %rax.
static void load(Type *ty) {
  switch (ty->kind) {
  case TY_ARRAY:
  case TY_STRUCT:
  case TY_UNION:
  case TY_FUNC:
    // If it is an array, do not attempt to load a value to the
    // register because in general we can't load an entire array to a
    // register. As a result, the result of an evaluation of an array
    // becomes not the array itself but the address of the array.
    // This is where "array is automatically converted to a pointer to
    // the first element of the array in C" occurs.
    return;
  case TY_FLOAT:
    println(" movss (%%rax), %%xmm0");
    return;
  case TY_DOUBLE:
    println(" movsd (%%rax), %%xmm0");
    return;
  }

  char *insn = ty->is_unsigned ? "movz" : "movs";

  // When we load a char or a short value to a register, we always
  // extend them to the size of int, so we can assume the lower half of
  // a register always contains a valid value. The upper half of a
  // register for char, short and int may contain garbage. When we load
  // a long value to a register, it simply occupies the entire register.
  if (ty->size == 1)
    println(" %sbl (%%rax), %%eax", insn);
  else if (ty->size == 2)
    println(" %swl (%%rax), %%eax", insn);
  else if (ty->size == 4)
    println(" movsxd (%%rax), %%rax");
  else
    println(" mov (%%rax), %%rax");
}

// Store %rax (or %xmm0 for floating-point types) to the address that the
// stack top is pointing to. The address is popped into %rdi.
static void store(Type *ty) {
  pop("%rdi");

  switch (ty->kind) {
  case TY_STRUCT:
  case TY_UNION:
    // Aggregates: %rax holds the source address; copy byte by byte.
    for (int i = 0; i < ty->size; i++) {
      println(" mov %d(%%rax), %%r8b", i);
      println(" mov %%r8b, %d(%%rdi)", i);
    }
    return;
  case TY_FLOAT:
    println(" movss %%xmm0, (%%rdi)");
    return;
  case TY_DOUBLE:
    println(" movsd %%xmm0, (%%rdi)");
    return;
  }

  if (ty->size == 1)
    println(" mov %%al, (%%rdi)");
  else if (ty->size == 2)
    println(" mov %%ax, (%%rdi)");
  else if (ty->size == 4)
    println(" mov %%eax, (%%rdi)");
  else
    println(" mov %%rax, (%%rdi)");
}

// Set the flags so that a following `je`/`sete` fires iff the current
// value (in %rax, or %xmm0 for float/double) compares equal to zero, and
// `jne`/`setne` fires otherwise.
//
// For floating-point types `ucomiss`/`ucomisd` alone is not enough: an
// unordered comparison (the value is NaN) also sets ZF, which would make
// NaN look like zero even though NaN != 0 in C. We therefore compute
// "equal AND ordered" into %al and compare it with 1. This clobbers %al,
// %dl and %xmm1 for floating-point types.
static void cmp_zero(Type *ty) {
  switch (ty->kind) {
  case TY_FLOAT:
    println(" xorps %%xmm1, %%xmm1");
    println(" ucomiss %%xmm1, %%xmm0");
    break;
  case TY_DOUBLE:
    println(" xorpd %%xmm1, %%xmm1");
    println(" ucomisd %%xmm1, %%xmm0");
    break;
  default:
    if (is_integer(ty) && ty->size <= 4)
      println(" cmp $0, %%eax");
    else
      println(" cmp $0, %%rax");
    return;
  }

  println(" sete %%al");
  println(" setnp %%dl");
  println(" and %%dl, %%al");
  println(" cmp $1, %%al");
}

enum { I8, I16, I32, I64, U8, U16, U32, U64, F32, F64 };

// Map a type to its row/column index in cast_table. Anything that is not
// an arithmetic type listed here (e.g. pointers) is treated as U64.
static int getTypeId(Type *ty) {
  switch (ty->kind) {
  case TY_CHAR:
    return ty->is_unsigned ? U8 : I8;
  case TY_SHORT:
    return ty->is_unsigned ? U16 : I16;
  case TY_INT:
    return ty->is_unsigned ? U32 : I32;
  case TY_LONG:
    return ty->is_unsigned ? U64 : I64;
  case TY_FLOAT:
    return F32;
  case TY_DOUBLE:
    return F64;
  }
  return U64;
}

// The table for type casts. Several entries are multi-instruction
// sequences separated by `;` and may use the GNU as numeric local labels
// `1:`/`2:`. They may clobber %rdi and %xmm1 as scratch registers.
static char i32i8[] = "movsbl %al, %eax";
static char i32u8[] = "movzbl %al, %eax";
static char i32i16[] = "movswl %ax, %eax";
static char i32u16[] = "movzwl %ax, %eax";
static char i32f32[] = "cvtsi2ssl %eax, %xmm0";
static char i32i64[] = "movsxd %eax, %rax";
static char i32f64[] = "cvtsi2sdl %eax, %xmm0";

static char u32f32[] = "mov %eax, %eax; cvtsi2ssq %rax, %xmm0";
static char u32i64[] = "mov %eax, %eax";
static char u32f64[] = "mov %eax, %eax; cvtsi2sdq %rax, %xmm0";

static char i64f32[] = "cvtsi2ssq %rax, %xmm0";
static char i64f64[] = "cvtsi2sdq %rax, %xmm0";

// uint64 -> float/double: cvtsi2s* only understands signed operands. For
// values with the top bit set, halve the value (OR-ing the lost low bit
// back in so rounding stays correct), convert, then double the result.
static char u64f32[] =
  "test %rax,%rax; js 1f; pxor %xmm0,%xmm0; cvtsi2ssq %rax,%xmm0; jmp 2f; "
  "1: mov %rax,%rdi; and $1,%eax; pxor %xmm0,%xmm0; shr %rdi; "
  "or %rax,%rdi; cvtsi2ssq %rdi,%xmm0; addss %xmm0,%xmm0; 2:";
static char u64f64[] =
  "test %rax,%rax; js 1f; pxor %xmm0,%xmm0; cvtsi2sd %rax,%xmm0; jmp 2f; "
  "1: mov %rax,%rdi; and $1,%eax; pxor %xmm0,%xmm0; shr %rdi; "
  "or %rax,%rdi; cvtsi2sd %rdi,%xmm0; addsd %xmm0,%xmm0; 2:";

static char f32i8[] = "cvttss2sil %xmm0, %eax; movsbl %al, %eax";
static char f32u8[] = "cvttss2sil %xmm0, %eax; movzbl %al, %eax";
static char f32i16[] = "cvttss2sil %xmm0, %eax; movswl %ax, %eax";
static char f32u16[] = "cvttss2sil %xmm0, %eax; movzwl %ax, %eax";
static char f32i32[] = "cvttss2sil %xmm0, %eax";
static char f32u32[] = "cvttss2siq %xmm0, %rax";
static char f32i64[] = "cvttss2siq %xmm0, %rax";
// float -> uint64: cvttss2siq only covers the signed range, so values
// >= 2^63 (0x5f000000 as a float) are reduced by 2^63 before conversion
// and bit 63 is set again afterwards.
static char f32u64[] =
  "mov $0x5f000000, %eax; movq %rax, %xmm1; ucomiss %xmm1, %xmm0; jae 1f; "
  "cvttss2siq %xmm0, %rax; jmp 2f; "
  "1: subss %xmm1, %xmm0; cvttss2siq %xmm0, %rax; btc $63, %rax; 2:";
static char f32f64[] = "cvtss2sd %xmm0, %xmm0";

static char f64i8[] = "cvttsd2sil %xmm0, %eax; movsbl %al, %eax";
static char f64u8[] = "cvttsd2sil %xmm0, %eax; movzbl %al, %eax";
static char f64i16[] = "cvttsd2sil %xmm0, %eax; movswl %ax, %eax";
static char f64u16[] = "cvttsd2sil %xmm0, %eax; movzwl %ax, %eax";
static char f64i32[] = "cvttsd2sil %xmm0, %eax";
static char f64u32[] = "cvttsd2siq %xmm0, %rax";
static char f64f32[] = "cvtsd2ss %xmm0, %xmm0";
static char f64i64[] = "cvttsd2siq %xmm0, %rax";
// double -> uint64: same idea as f32u64; 0x43e0000000000000 is 2^63.
static char f64u64[] =
  "mov $0x43e0000000000000, %rax; movq %rax, %xmm1; ucomisd %xmm1, %xmm0; jae 1f; "
  "cvttsd2siq %xmm0, %rax; jmp 2f; "
  "1: subsd %xmm1, %xmm0; cvttsd2siq %xmm0, %rax; btc $63, %rax; 2:";

// cast_table[from][to]; NULL means no instruction is needed.
static char *cast_table[][10] = {
  // i8   i16     i32     i64     u8     u16     u32     u64     f32     f64
  {NULL,  NULL,   NULL,   i32i64, i32u8, i32u16, NULL,   i32i64, i32f32, i32f64}, // i8
  {i32i8, NULL,   NULL,   i32i64, i32u8, i32u16, NULL,   i32i64, i32f32, i32f64}, // i16
  {i32i8, i32i16, NULL,   i32i64, i32u8, i32u16, NULL,   i32i64, i32f32, i32f64}, // i32
  {i32i8, i32i16, NULL,   NULL,   i32u8, i32u16, NULL,   NULL,   i64f32, i64f64}, // i64

  {i32i8, NULL,   NULL,   i32i64, NULL,  NULL,   NULL,   i32i64, i32f32, i32f64}, // u8
  {i32i8, i32i16, NULL,   i32i64, i32u8, NULL,   NULL,   i32i64, i32f32, i32f64}, // u16
  {i32i8, i32i16, NULL,   u32i64, i32u8, i32u16, NULL,   u32i64, u32f32, u32f64}, // u32
  {i32i8, i32i16, NULL,   NULL,   i32u8, i32u16, NULL,   NULL,   u64f32, u64f64}, // u64

  {f32i8, f32i16, f32i32, f32i64, f32u8, f32u16, f32u32, f32u64, NULL,   f32f64}, // f32
  {f64i8, f64i16, f64i32, f64i64, f64u8, f64u16, f64u32, f64u64, f64f32, NULL},   // f64
};

// Convert the current value (in %rax or %xmm0) from type `from` to `to`.
static void cast(Type *from, Type *to) {
  if (to->kind == TY_VOID)
    return;

  if (to->kind == TY_BOOL) {
    cmp_zero(from);
    println(" setne %%al");
    println(" movzx %%al, %%eax");
    return;
  }

  int t1 = getTypeId(from);
  int t2 = getTypeId(to);
  if (cast_table[t1][t2])
    println(" %s", cast_table[t1][t2]);
}

// Evaluate call arguments and push them onto the stack, last argument
// first, so that the first argument ends up on top of the stack.
static void push_args(Node *args) {
  if (!args)
    return;

  push_args(args->next);

  gen_expr(args);
  if (is_flonum(args->ty))
    pushf();
  else
    push();
}

// Generate code for a given node.
static void gen_expr(Node *node) { println(" .loc %d %d", node->tok->file->file_no, node->tok->line_no); switch (node->kind) { case ND_NULL_EXPR: return; case ND_NUM: { union { float f32; double f64; uint32_t u32; uint64_t u64; } u; switch (node->ty->kind) { case TY_FLOAT: u.f32 = node->fval; println(" mov $%u, %%eax # float %f", u.u32, node->fval); println(" movq %%rax, %%xmm0"); return; case TY_DOUBLE: u.f64 = node->fval; println(" mov $%lu, %%rax # double %f", u.u64, node->fval); println(" movq %%rax, %%xmm0"); return; } println(" mov $%ld, %%rax", node->val); return; } case ND_NEG: gen_expr(node->lhs); switch (node->ty->kind) { case TY_FLOAT: println(" mov $1, %%rax"); println(" shl $31, %%rax"); println(" movq %%rax, %%xmm1"); println(" xorps %%xmm1, %%xmm0"); return; case TY_DOUBLE: println(" mov $1, %%rax"); println(" shl $63, %%rax"); println(" movq %%rax, %%xmm1"); println(" xorpd %%xmm1, %%xmm0"); return; } println(" neg %%rax"); return; case ND_VAR: case ND_MEMBER: gen_addr(node); load(node->ty); return; case ND_DEREF: gen_expr(node->lhs); load(node->ty); return; case ND_ADDR: gen_addr(node->lhs); return; case ND_ASSIGN: gen_addr(node->lhs); push(); gen_expr(node->rhs); store(node->ty); return; case ND_STMT_EXPR: for (Node *n = node->body; n; n = n->next) gen_stmt(n); return; case ND_COMMA: gen_expr(node->lhs); gen_expr(node->rhs); return; case ND_CAST: gen_expr(node->lhs); cast(node->lhs->ty, node->ty); return; case ND_MEMZERO: // `rep stosb` is equivalent to `memset(%rdi, %al, %rcx)`. 
println(" mov $%d, %%rcx", node->var->ty->size); println(" lea %d(%%rbp), %%rdi", node->var->offset); println(" mov $0, %%al"); println(" rep stosb"); return; case ND_COND: { int c = count(); gen_expr(node->cond); cmp_zero(node->cond->ty); println(" je .L.else.%d", c); gen_expr(node->then); println(" jmp .L.end.%d", c); println(".L.else.%d:", c); gen_expr(node->els); println(".L.end.%d:", c); return; } case ND_NOT: gen_expr(node->lhs); cmp_zero(node->lhs->ty); println(" sete %%al"); println(" movzx %%al, %%rax"); return; case ND_BITNOT: gen_expr(node->lhs); println(" not %%rax"); return; case ND_LOGAND: { int c = count(); gen_expr(node->lhs); cmp_zero(node->lhs->ty); println(" je .L.false.%d", c); gen_expr(node->rhs); cmp_zero(node->rhs->ty); println(" je .L.false.%d", c); println(" mov $1, %%rax"); println(" jmp .L.end.%d", c); println(".L.false.%d:", c); println(" mov $0, %%rax"); println(".L.end.%d:", c); return; } case ND_LOGOR: { int c = count(); gen_expr(node->lhs); cmp_zero(node->lhs->ty); println(" jne .L.true.%d", c); gen_expr(node->rhs); cmp_zero(node->rhs->ty); println(" jne .L.true.%d", c); println(" mov $0, %%rax"); println(" jmp .L.end.%d", c); println(".L.true.%d:", c); println(" mov $1, %%rax"); println(".L.end.%d:", c); return; } case ND_FUNCALL: { push_args(node->args); gen_expr(node->lhs); int gp = 0, fp = 0; for (Node *arg = node->args; arg; arg = arg->next) { if (is_flonum(arg->ty)) popf(fp++); else pop(argreg64[gp++]); } if (depth % 2 == 0) { println(" call *%%rax"); } else { println(" sub $8, %%rsp"); println(" call *%%rax"); println(" add $8, %%rsp"); } // It looks like the most significant 48 or 56 bits in RAX may // contain garbage if a function return type is short or bool/char, // respectively. We clear the upper bits here. 
switch (node->ty->kind) { case TY_BOOL: println(" movzx %%al, %%eax"); return; case TY_CHAR: if (node->ty->is_unsigned) println(" movzbl %%al, %%eax"); else println(" movsbl %%al, %%eax"); return; case TY_SHORT: if (node->ty->is_unsigned) println(" movzwl %%ax, %%eax"); else println(" movswl %%ax, %%eax"); return; } return; } } if (is_flonum(node->lhs->ty)) { gen_expr(node->rhs); pushf(); gen_expr(node->lhs); popf(1); char *sz = (node->lhs->ty->kind == TY_FLOAT) ? "ss" : "sd"; switch (node->kind) { case ND_ADD: println(" add%s %%xmm1, %%xmm0", sz); return; case ND_SUB: println(" sub%s %%xmm1, %%xmm0", sz); return; case ND_MUL: println(" mul%s %%xmm1, %%xmm0", sz); return; case ND_DIV: println(" div%s %%xmm1, %%xmm0", sz); return; case ND_EQ: case ND_NE: case ND_LT: case ND_LE: println(" ucomi%s %%xmm0, %%xmm1", sz); if (node->kind == ND_EQ) { println(" sete %%al"); println(" setnp %%dl"); println(" and %%dl, %%al"); } else if (node->kind == ND_NE) { println(" setne %%al"); println(" setp %%dl"); println(" or %%dl, %%al"); } else if (node->kind == ND_LT) { println(" seta %%al"); } else { println(" setae %%al"); } println(" and $1, %%al"); println(" movzb %%al, %%rax"); return; } error_tok(node->tok, "invalid expression"); } gen_expr(node->rhs); push(); gen_expr(node->lhs); pop("%rdi"); char *ax, *di, *dx; if (node->lhs->ty->kind == TY_LONG || node->lhs->ty->base) { ax = "%rax"; di = "%rdi"; dx = "%rdx"; } else { ax = "%eax"; di = "%edi"; dx = "%edx"; } switch (node->kind) { case ND_ADD: println(" add %s, %s", di, ax); return; case ND_SUB: println(" sub %s, %s", di, ax); return; case ND_MUL: println(" imul %s, %s", di, ax); return; case ND_DIV: case ND_MOD: if (node->ty->is_unsigned) { println(" mov $0, %s", dx); println(" div %s", di); } else { if (node->lhs->ty->size == 8) println(" cqo"); else println(" cdq"); println(" idiv %s", di); } if (node->kind == ND_MOD) println(" mov %%rdx, %%rax"); return; case ND_BITAND: println(" and %s, %s", di, ax); return; case 
ND_BITOR: println(" or %s, %s", di, ax); return; case ND_BITXOR: println(" xor %s, %s", di, ax); return; case ND_EQ: case ND_NE: case ND_LT: case ND_LE: println(" cmp %s, %s", di, ax); if (node->kind == ND_EQ) { println(" sete %%al"); } else if (node->kind == ND_NE) { println(" setne %%al"); } else if (node->kind == ND_LT) { if (node->lhs->ty->is_unsigned) println(" setb %%al"); else println(" setl %%al"); } else if (node->kind == ND_LE) { if (node->lhs->ty->is_unsigned) println(" setbe %%al"); else println(" setle %%al"); } println(" movzb %%al, %%rax"); return; case ND_SHL: println(" mov %%rdi, %%rcx"); println(" shl %%cl, %s", ax); return; case ND_SHR: println(" mov %%rdi, %%rcx"); if (node->lhs->ty->is_unsigned) println(" shr %%cl, %s", ax); else println(" sar %%cl, %s", ax); return; } error_tok(node->tok, "invalid expression"); } static void gen_stmt(Node *node) { println(" .loc %d %d", node->tok->file->file_no, node->tok->line_no); switch (node->kind) { case ND_IF: { int c = count(); gen_expr(node->cond); cmp_zero(node->cond->ty); println(" je .L.else.%d", c); gen_stmt(node->then); println(" jmp .L.end.%d", c); println(".L.else.%d:", c); if (node->els) gen_stmt(node->els); println(".L.end.%d:", c); return; } case ND_FOR: { int c = count(); if (node->init) gen_stmt(node->init); println(".L.begin.%d:", c); if (node->cond) { gen_expr(node->cond); cmp_zero(node->cond->ty); println(" je %s", node->brk_label); } gen_stmt(node->then); println("%s:", node->cont_label); if (node->inc) gen_expr(node->inc); println(" jmp .L.begin.%d", c); println("%s:", node->brk_label); return; } case ND_DO: { int c = count(); println(".L.begin.%d:", c); gen_stmt(node->then); println("%s:", node->cont_label); gen_expr(node->cond); cmp_zero(node->cond->ty); println(" jne .L.begin.%d", c); println("%s:", node->brk_label); return; } case ND_SWITCH: gen_expr(node->cond); for (Node *n = node->case_next; n; n = n->case_next) { char *reg = (node->cond->ty->size == 8) ? 
"%rax" : "%eax"; println(" cmp $%ld, %s", n->val, reg); println(" je %s", n->label); } if (node->default_case) println(" jmp %s", node->default_case->label); println(" jmp %s", node->brk_label); gen_stmt(node->then); println("%s:", node->brk_label); return; case ND_CASE: println("%s:", node->label); gen_stmt(node->lhs); return; case ND_BLOCK: for (Node *n = node->body; n; n = n->next) gen_stmt(n); return; case ND_GOTO: println(" jmp %s", node->unique_label); return; case ND_LABEL: println("%s:", node->unique_label); gen_stmt(node->lhs); return; case ND_RETURN: if (node->lhs) gen_expr(node->lhs); println(" jmp .L.return.%s", current_fn->name); return; case ND_EXPR_STMT: gen_expr(node->lhs); return; } error_tok(node->tok, "invalid statement"); } // Assign offsets to local variables. static void assign_lvar_offsets(Obj *prog) { for (Obj *fn = prog; fn; fn = fn->next) { if (!fn->is_function) continue; int offset = 0; for (Obj *var = fn->locals; var; var = var->next) { offset += var->ty->size; offset = align_to(offset, var->align); var->offset = -offset; } fn->stack_size = align_to(offset, 16); } } static void emit_data(Obj *prog) { for (Obj *var = prog; var; var = var->next) { if (var->is_function || !var->is_definition) continue; if (var->is_static) println(" .local %s", var->name); else println(" .globl %s", var->name); println(" .align %d", var->align); if (var->init_data) { println(" .data"); println("%s:", var->name); Relocation *rel = var->rel; int pos = 0; while (pos < var->ty->size) { if (rel && rel->offset == pos) { println(" .quad %s%+ld", rel->label, rel->addend); rel = rel->next; pos += 8; } else { println(" .byte %d", var->init_data[pos++]); } } continue; } println(" .bss"); println("%s:", var->name); println(" .zero %d", var->ty->size); } } static void store_fp(int r, int offset, int sz) { switch (sz) { case 4: println(" movss %%xmm%d, %d(%%rbp)", r, offset); return; case 8: println(" movsd %%xmm%d, %d(%%rbp)", r, offset); return; } unreachable(); } 
static void store_gp(int r, int offset, int sz) { switch (sz) { case 1: println(" mov %s, %d(%%rbp)", argreg8[r], offset); return; case 2: println(" mov %s, %d(%%rbp)", argreg16[r], offset); return; case 4: println(" mov %s, %d(%%rbp)", argreg32[r], offset); return; case 8: println(" mov %s, %d(%%rbp)", argreg64[r], offset); return; } unreachable(); } static void emit_text(Obj *prog) { for (Obj *fn = prog; fn; fn = fn->next) { if (!fn->is_function || !fn->is_definition) continue; if (fn->is_static) println(" .local %s", fn->name); else println(" .globl %s", fn->name); println(" .text"); println("%s:", fn->name); current_fn = fn; // Prologue println(" push %%rbp"); println(" mov %%rsp, %%rbp"); println(" sub $%d, %%rsp", fn->stack_size); // Save arg registers if function is variadic if (fn->va_area) { int gp = 0, fp = 0; for (Obj *var = fn->params; var; var = var->next) { if (is_flonum(var->ty)) fp++; else gp++; } int off = fn->va_area->offset; // va_elem println(" movl $%d, %d(%%rbp)", gp * 8, off); println(" movl $%d, %d(%%rbp)", fp * 8 + 48, off + 4); println(" movq %%rbp, %d(%%rbp)", off + 16); println(" addq $%d, %d(%%rbp)", off + 24, off + 16); // __reg_save_area__ println(" movq %%rdi, %d(%%rbp)", off + 24); println(" movq %%rsi, %d(%%rbp)", off + 32); println(" movq %%rdx, %d(%%rbp)", off + 40); println(" movq %%rcx, %d(%%rbp)", off + 48); println(" movq %%r8, %d(%%rbp)", off + 56); println(" movq %%r9, %d(%%rbp)", off + 64); println(" movsd %%xmm0, %d(%%rbp)", off + 72); println(" movsd %%xmm1, %d(%%rbp)", off + 80); println(" movsd %%xmm2, %d(%%rbp)", off + 88); println(" movsd %%xmm3, %d(%%rbp)", off + 96); println(" movsd %%xmm4, %d(%%rbp)", off + 104); println(" movsd %%xmm5, %d(%%rbp)", off + 112); println(" movsd %%xmm6, %d(%%rbp)", off + 120); println(" movsd %%xmm7, %d(%%rbp)", off + 128); } // Save passed-by-register arguments to the stack int gp = 0, fp = 0; for (Obj *var = fn->params; var; var = var->next) { if (is_flonum(var->ty)) store_fp(fp++, 
var->offset, var->ty->size); else store_gp(gp++, var->offset, var->ty->size); } // Emit code gen_stmt(fn->body); assert(depth == 0); // Epilogue println(".L.return.%s:", fn->name); println(" mov %%rbp, %%rsp"); println(" pop %%rbp"); println(" ret"); } } void codegen(Obj *prog, FILE *out) { output_file = out; File **files = get_input_files(); for (int i = 0; files[i]; i++) println(" .file %d \"%s\"", files[i]->file_no, files[i]->name); assign_lvar_offsets(prog); emit_data(prog); emit_text(prog); }