diff --git a/chibicc.h b/chibicc.h index 30d07cb..c580b6d 100644 --- a/chibicc.h +++ b/chibicc.h @@ -75,7 +75,7 @@ struct Token { TokenKind kind; // Token kind Token *next; // Next token int64_t val; // If kind is TK_NUM, its value - double fval; // If kind is TK_NUM, its value + long double fval; // If kind is TK_NUM, its value char *loc; // Token location int len; // Token length Type *ty; // Used if TK_NUM or TK_STR @@ -271,7 +271,7 @@ struct Node { // Numeric literal int64_t val; - double fval; + long double fval; }; Node *new_cast(Node *expr, Type *ty); @@ -291,6 +291,7 @@ typedef enum { TY_LONG, TY_FLOAT, TY_DOUBLE, + TY_LDOUBLE, TY_ENUM, TY_PTR, TY_FUNC, @@ -370,6 +371,7 @@ extern Type *ty_ulong; extern Type *ty_float; extern Type *ty_double; +extern Type *ty_ldouble; bool is_integer(Type *ty); bool is_flonum(Type *ty); diff --git a/codegen.c b/codegen.c index df00eb2..ce38b04 100644 --- a/codegen.c +++ b/codegen.c @@ -161,6 +161,9 @@ static void load(Type *ty) { case TY_DOUBLE: println(" movsd (%%rax), %%xmm0"); return; + case TY_LDOUBLE: + println(" fldt (%%rax)"); + return; } char *insn = ty->is_unsigned ? "movz" : "movs"; @@ -198,6 +201,9 @@ static void store(Type *ty) { case TY_DOUBLE: println(" movsd %%xmm0, (%%rdi)"); return; + case TY_LDOUBLE: + println(" fstpt (%%rdi)"); + return; } if (ty->size == 1) @@ -220,6 +226,11 @@ static void cmp_zero(Type *ty) { println(" xorpd %%xmm1, %%xmm1"); println(" ucomisd %%xmm1, %%xmm0"); return; + case TY_LDOUBLE: + println(" fldz"); + println(" fucomip"); + println(" fstp %%st(0)"); + return; } if (is_integer(ty) && ty->size <= 4) @@ -228,7 +239,7 @@ static void cmp_zero(Type *ty) { println(" cmp $0, %%rax"); } -enum { I8, I16, I32, I64, U8, U16, U32, U64, F32, F64 }; +enum { I8, I16, I32, I64, U8, U16, U32, U64, F32, F64, F80 }; static int getTypeId(Type *ty) { switch (ty->kind) { @@ -244,6 +255,8 @@ static int getTypeId(Type *ty) { return F32; case TY_DOUBLE: return F64; + case TY_LDOUBLE: + return F80; } return U64; } @@ -256,19 +269,25 @@ static char i32u16[] = "movzwl %ax, %eax"; static char i32f32[] = "cvtsi2ssl %eax, %xmm0"; static char i32i64[] = "movsxd %eax, %rax"; static char i32f64[] = "cvtsi2sdl %eax, %xmm0"; +static char i32f80[] = "mov %eax, -4(%rsp); fildl -4(%rsp)"; static char u32f32[] = "mov %eax, %eax; cvtsi2ssq %rax, %xmm0"; static char u32i64[] = "mov %eax, %eax"; static char u32f64[] = "mov %eax, %eax; cvtsi2sdq %rax, %xmm0"; +static char u32f80[] = "mov %eax, %eax; mov %rax, -8(%rsp); fildll -8(%rsp)"; static char i64f32[] = "cvtsi2ssq %rax, %xmm0"; static char i64f64[] = "cvtsi2sdq %rax, %xmm0"; +static char i64f80[] = "movq %rax, -8(%rsp); fildll -8(%rsp)"; static char u64f32[] = "cvtsi2ssq %rax, %xmm0"; static char u64f64[] = "test %rax,%rax; js 1f; pxor %xmm0,%xmm0; cvtsi2sd %rax,%xmm0; jmp 2f; " "1: mov %rax,%rdi; and $1,%eax; pxor %xmm0,%xmm0; shr %rdi; " "or %rax,%rdi; cvtsi2sd %rdi,%xmm0; addsd %xmm0,%xmm0; 2:"; +static char u64f80[] = + "mov %rax, -8(%rsp); fildq -8(%rsp); test %rax, %rax; jns 1f;" + "mov $1602224128, %eax; mov %eax, -4(%rsp); fadds -4(%rsp); 1:"; static char f32i8[] = "cvttss2sil %xmm0, %eax; movsbl %al, %eax"; static char f32u8[] = "cvttss2sil %xmm0, %eax; movzbl %al, %eax"; @@ -279,6 +298,7 @@ static char f32u32[] = "cvttss2siq %xmm0, %rax"; static char f32i64[] = "cvttss2siq %xmm0, %rax"; static char f32u64[] = "cvttss2siq %xmm0, %rax"; static char f32f64[] = "cvtss2sd %xmm0, %xmm0"; +static char f32f80[] = "movss %xmm0, -4(%rsp); flds -4(%rsp)"; static char f64i8[] = "cvttsd2sil %xmm0, %eax; movsbl %al, %eax"; static char f64u8[] = "cvttsd2sil %xmm0, %eax; movzbl %al, %eax"; @@ -286,24 +306,43 @@ static char f64i16[] = "cvttsd2sil %xmm0, %eax; movswl %ax, %eax"; static char f64u16[] = "cvttsd2sil %xmm0, %eax; movzwl %ax, %eax"; static char f64i32[] = "cvttsd2sil %xmm0, %eax"; static char f64u32[] = "cvttsd2siq %xmm0, %rax"; -static char f64f32[] = "cvtsd2ss %xmm0, %xmm0"; static char f64i64[] = "cvttsd2siq %xmm0, %rax"; static char f64u64[] = "cvttsd2siq %xmm0, %rax"; +static char f64f32[] = "cvtsd2ss %xmm0, %xmm0"; +static char f64f80[] = "movsd %xmm0, -8(%rsp); fldl -8(%rsp)"; -static char *cast_table[][10] = { - // i8 i16 i32 i64 u8 u16 u32 u64 f32 f64 - {NULL, NULL, NULL, i32i64, i32u8, i32u16, NULL, i32i64, i32f32, i32f64}, // i8 - {i32i8, NULL, NULL, i32i64, i32u8, i32u16, NULL, i32i64, i32f32, i32f64}, // i16 - {i32i8, i32i16, NULL, i32i64, i32u8, i32u16, NULL, i32i64, i32f32, i32f64}, // i32 - {i32i8, i32i16, NULL, NULL, i32u8, i32u16, NULL, NULL, i64f32, i64f64}, // i64 +#define FROM_F80_1 \ + "fnstcw -10(%rsp); movzwl -10(%rsp), %eax; or $12, %ah; " \ + "mov %ax, -12(%rsp); fldcw -12(%rsp); " - {i32i8, NULL, NULL, i32i64, NULL, NULL, NULL, i32i64, i32f32, i32f64}, // u8 - {i32i8, i32i16, NULL, i32i64, i32u8, NULL, NULL, i32i64, i32f32, i32f64}, // u16 - {i32i8, i32i16, NULL, u32i64, i32u8, i32u16, NULL, u32i64, u32f32, u32f64}, // u32 - {i32i8, i32i16, NULL, NULL, i32u8, i32u16, NULL, NULL, u64f32, u64f64}, // u64 +#define FROM_F80_2 " -24(%rsp); fldcw -10(%rsp); " - {f32i8, f32i16, f32i32, f32i64, f32u8, f32u16, f32u32, f32u64, NULL, f32f64}, // f32 - {f64i8, f64i16, f64i32, f64i64, f64u8, f64u16, f64u32, f64u64, f64f32, NULL}, // f64 +static char f80i8[] = FROM_F80_1 "fistps" FROM_F80_2 "movsbl -24(%rsp), %eax"; +static char f80u8[] = FROM_F80_1 "fistps" FROM_F80_2 "movzbl -24(%rsp), %eax"; +static char f80i16[] = FROM_F80_1 "fistps" FROM_F80_2 "movzbl -24(%rsp), %eax"; +static char f80u16[] = FROM_F80_1 "fistpl" FROM_F80_2 "movswl -24(%rsp), %eax"; +static char f80i32[] = FROM_F80_1 "fistpl" FROM_F80_2 "mov -24(%rsp), %eax"; +static char f80u32[] = FROM_F80_1 "fistpl" FROM_F80_2 "mov -24(%rsp), %eax"; +static char f80i64[] = FROM_F80_1 "fistpq" FROM_F80_2 "mov -24(%rsp), %rax"; +static char f80u64[] = FROM_F80_1 "fistpq" FROM_F80_2 "mov -24(%rsp), %rax"; +static char f80f32[] = "fstps -8(%rsp); movss -8(%rsp), %xmm0"; +static char f80f64[] = "fstpl -8(%rsp); movsd -8(%rsp), %xmm0"; + +static char *cast_table[][11] = { + // i8 i16 i32 i64 u8 u16 u32 u64 f32 f64 f80 + {NULL, NULL, NULL, i32i64, i32u8, i32u16, NULL, i32i64, i32f32, i32f64, i32f80}, // i8 + {i32i8, NULL, NULL, i32i64, i32u8, i32u16, NULL, i32i64, i32f32, i32f64, i32f80}, // i16 + {i32i8, i32i16, NULL, i32i64, i32u8, i32u16, NULL, i32i64, i32f32, i32f64, i32f80}, // i32 + {i32i8, i32i16, NULL, NULL, i32u8, i32u16, NULL, NULL, i64f32, i64f64, i64f80}, // i64 + + {i32i8, NULL, NULL, i32i64, NULL, NULL, NULL, i32i64, i32f32, i32f64, i32f80}, // u8 + {i32i8, i32i16, NULL, i32i64, i32u8, NULL, NULL, i32i64, i32f32, i32f64, i32f80}, // u16 + {i32i8, i32i16, NULL, u32i64, i32u8, i32u16, NULL, u32i64, u32f32, u32f64, u32f80}, // u32 + {i32i8, i32i16, NULL, NULL, i32u8, i32u16, NULL, NULL, u64f32, u64f64, u64f80}, // u64 + + {f32i8, f32i16, f32i32, f32i64, f32u8, f32u16, f32u32, f32u64, NULL, f32f64, f32f80}, // f32 + {f64i8, f64i16, f64i32, f64i64, f64u8, f64u16, f64u32, f64u64, f64f32, NULL, f64f80}, // f64 + {f80i8, f80i16, f80i32, f80i64, f80u8, f80u16, f80u32, f80u64, f80f32, f80f64, NULL}, // f80 }; static void cast(Type *from, Type *to) { @@ -350,7 +389,7 @@ static bool has_flonum(Type *ty, int lo, int hi, int offset) { return true; } - return offset < lo || hi <= offset || is_flonum(ty); + return offset < lo || hi <= offset || ty->kind == TY_FLOAT || ty->kind == TY_DOUBLE; } static bool has_flonum1(Type *ty) { @@ -391,6 +430,11 @@ static void push_args2(Node *args, bool first_pass) { case TY_DOUBLE: pushf(); break; + case TY_LDOUBLE: + println(" sub $16, %%rsp"); + println(" fstpt (%%rsp)"); + depth += 2; + break; default: push(); } @@ -453,6 +497,10 @@ static int push_args(Node *node) { stack++; } break; + case TY_LDOUBLE: + arg->pass_by_stack = true; + stack += 2; + break; default: if (gp++ >= GP_MAX) { arg->pass_by_stack = true; @@ -606,20 +654,31 @@ static void gen_expr(Node *node) { case ND_NULL_EXPR: return; case ND_NUM: { - union { float f32; double f64; uint32_t u32; uint64_t u64; } u; - switch (node->ty->kind) { - case TY_FLOAT: - u.f32 = node->fval; - println(" mov $%u, %%eax # float %f", u.u32, node->fval); + case TY_FLOAT: { + union { float f32; uint32_t u32; } u = { node->fval }; + println(" mov $%u, %%eax # float %Lf", u.u32, node->fval); println(" movq %%rax, %%xmm0"); return; - case TY_DOUBLE: - u.f64 = node->fval; - println(" mov $%lu, %%rax # double %f", u.u64, node->fval); + } + case TY_DOUBLE: { + union { double f64; uint64_t u64; } u = { node->fval }; + println(" mov $%lu, %%rax # double %Lf", u.u64, node->fval); println(" movq %%rax, %%xmm0"); return; } + case TY_LDOUBLE: { + union { long double f80; uint64_t u64[2]; } u; + memset(&u, 0, sizeof(u)); + u.f80 = node->fval; + println(" mov $%lu, %%rax # long double %Lf", u.u64[0], node->fval); + println(" mov %%rax, -16(%%rsp)"); + println(" mov $%lu, %%rax", u.u64[1]); + println(" mov %%rax, -8(%%rsp)"); + println(" fldt -16(%%rsp)"); + return; + } + } println(" mov $%ld, %%rax", node->val); return; @@ -640,6 +699,9 @@ static void gen_expr(Node *node) { println(" movq %%rax, %%xmm1"); println(" xorpd %%xmm1, %%xmm0"); return; + case TY_LDOUBLE: + println(" fchs"); + return; } println(" neg %%rax"); @@ -818,6 +880,8 @@ static void gen_expr(Node *node) { if (fp < FP_MAX) popf(fp++); break; + case TY_LDOUBLE: + break; default: if (gp < GP_MAX) pop(argreg64[gp++]); @@ -863,7 +927,9 @@ static void gen_expr(Node *node) { } } - if (is_flonum(node->lhs->ty)) { + switch (node->lhs->ty->kind) { + case TY_FLOAT: + case TY_DOUBLE: { gen_expr(node->rhs); pushf(); gen_expr(node->lhs); @@ -911,6 +977,46 @@ static void gen_expr(Node *node) { error_tok(node->tok, "invalid expression"); } + case TY_LDOUBLE: { + gen_expr(node->lhs); + gen_expr(node->rhs); + + switch (node->kind) { + case ND_ADD: + println(" faddp"); + return; + case ND_SUB: + println(" fsubrp"); + return; + case ND_MUL: + println(" fmulp"); + return; + case ND_DIV: + println(" fdivrp"); + return; + case ND_EQ: + case ND_NE: + case ND_LT: + case ND_LE: + println(" fcomip"); + println(" fstp %%st(0)"); + + if (node->kind == ND_EQ) + println(" sete %%al"); + else if (node->kind == ND_NE) + println(" setne %%al"); + else if (node->kind == ND_LT) + println(" seta %%al"); + else + println(" setae %%al"); + + println(" movzb %%al, %%rax"); + return; + } + + error_tok(node->tok, "invalid expression"); + } + } gen_expr(node->rhs); push(); @@ -1084,13 +1190,16 @@ static void gen_stmt(Node *node) { case ND_RETURN: if (node->lhs) { gen_expr(node->lhs); - Type *ty = node->lhs->ty; - if (ty->kind == TY_STRUCT || ty->kind == TY_UNION) { + + switch (ty->kind) { + case TY_STRUCT: + case TY_UNION: if (ty->size <= 16) copy_struct_reg(); else copy_struct_mem(); + break; } } @@ -1143,6 +1252,8 @@ static void assign_lvar_offsets(Obj *prog) { if (fp++ < FP_MAX) continue; break; + case TY_LDOUBLE: + break; default: if (gp++ < GP_MAX) continue; diff --git a/parse.c b/parse.c index 1aae37a..792072c 100644 --- a/parse.c +++ b/parse.c @@ -562,9 +562,11 @@ static Type *declspec(Token **rest, Token *tok, VarAttr *attr) { ty = ty_float; break; case DOUBLE: - case LONG + DOUBLE: ty = ty_double; break; + case LONG + DOUBLE: + ty = ty_ldouble; + break; default: error_tok(tok, "invalid type"); } diff --git a/test/arith.c b/test/arith.c index f9988af..f98b680 100644 --- a/test/arith.c +++ b/test/arith.c @@ -134,6 +134,11 @@ int main() { ASSERT(5, 0?:5); ASSERT(4, ({ int i = 3; ++i?:10; })); + ASSERT(3, (long double)3); + ASSERT(5, (long double)3+2); + ASSERT(6, (long double)3*2); + ASSERT(5, (long double)3+2.0); + printf("OK\n"); return 0; } diff --git a/test/function.c b/test/function.c index 586f10c..eb6ad3a 100644 --- a/test/function.c +++ b/test/function.c @@ -201,6 +201,14 @@ inline int inline_fn(void) { return 3; } +double to_double(long double x) { + return x; +} + +long double to_ldouble(int x) { + return x; +} + int main() { ASSERT(3, ret3()); ASSERT(8, add2(3, 5)); @@ -371,5 +379,16 @@ int main() { ASSERT(3, inline_fn()); + ASSERT(0, ({ char buf[100]; sprintf(buf, "%Lf", (long double)12.3); strncmp(buf, "12.3", 4); })); + + ASSERT(1, to_double(3.5) == 3.5); + ASSERT(0, to_double(3.5) == 3); + + ASSERT(1, (long double)5.0 == (long double)5.0); + ASSERT(0, (long double)5.0 == (long double)5.2); + + ASSERT(1, to_ldouble(5.0) == 5.0); + ASSERT(0, to_ldouble(5.0) == 5.2); + printf("OK\n"); } diff --git a/test/literal.c b/test/literal.c index 472b72d..45d06c9 100644 --- a/test/literal.c +++ b/test/literal.c @@ -88,8 +88,8 @@ int main() { ASSERT(4, sizeof(0.3F)); ASSERT(8, sizeof(0.)); ASSERT(8, sizeof(.0)); - ASSERT(8, sizeof(5.l)); - ASSERT(8, sizeof(2.0L)); + ASSERT(16, sizeof(5.l)); + ASSERT(16, sizeof(2.0L)); assert(1, size\ of(char), \ diff --git a/test/sizeof.c b/test/sizeof.c index 07b3118..5b4398d 100644 --- a/test/sizeof.c +++ b/test/sizeof.c @@ -96,7 +96,7 @@ int main() { ASSERT(4, sizeof(1f/2)); ASSERT(8, sizeof(1.0/2)); - ASSERT(8, sizeof(long double)); + ASSERT(16, sizeof(long double)); ASSERT(1, sizeof(main)); diff --git a/tokenize.c b/tokenize.c index f3671f4..7383dcf 100644 --- a/tokenize.c +++ b/tokenize.c @@ -426,14 +426,14 @@ static void convert_pp_number(Token *tok) { // If it's not an integer, it must be a floating point constant. char *end; - double val = strtod(tok->loc, &end); + long double val = strtold(tok->loc, &end); Type *ty; if (*end == 'f' || *end == 'F') { ty = ty_float; end++; } else if (*end == 'l' || *end == 'L') { - ty = ty_double; + ty = ty_ldouble; end++; } else { ty = ty_double; diff --git a/type.c b/type.c index 781ac48..6c1a6d4 100644 --- a/type.c +++ b/type.c @@ -15,6 +15,7 @@ Type *ty_ulong = &(Type){TY_LONG, 8, 8, true}; Type *ty_float = &(Type){TY_FLOAT, 4, 4}; Type *ty_double = &(Type){TY_DOUBLE, 8, 8}; +Type *ty_ldouble = &(Type){TY_LDOUBLE, 16, 16}; static Type *new_type(TypeKind kind, int size, int align) { Type *ty = calloc(1, sizeof(Type)); @@ -31,7 +32,8 @@ bool is_integer(Type *ty) { } bool is_flonum(Type *ty) { - return ty->kind == TY_FLOAT || ty->kind == TY_DOUBLE; + return ty->kind == TY_FLOAT || ty->kind == TY_DOUBLE || + ty->kind == TY_LDOUBLE; } bool is_numeric(Type *ty) { @@ -59,6 +61,7 @@ bool is_compatible(Type *t1, Type *t2) { return t1->is_unsigned == t2->is_unsigned; case TY_FLOAT: case TY_DOUBLE: + case TY_LDOUBLE: return true; case TY_PTR: return is_compatible(t1->base, t2->base); @@ -137,6 +140,8 @@ static Type *get_common_type(Type *ty1, Type *ty2) { if (ty2->kind == TY_FUNC) return pointer_to(ty2); + if (ty1->kind == TY_LDOUBLE || ty2->kind == TY_LDOUBLE) + return ty_ldouble; if (ty1->kind == TY_DOUBLE || ty2->kind == TY_DOUBLE) return ty_double; if (ty1->kind == TY_FLOAT || ty2->kind == TY_FLOAT)