From d7bad961146b9f2fd918f05fd59a50f3f65bf325 Mon Sep 17 00:00:00 2001 From: Rui Ueyama Date: Fri, 4 Sep 2020 11:57:15 +0900 Subject: [PATCH] Allow to define a function returning a struct --- codegen.c | 65 ++++++++++++++++++++++++++++++++++++++++++++++++- parse.c | 13 +++++++++- test/function.c | 43 ++++++++++++++++++++++++++++++++ 3 files changed, 119 insertions(+), 2 deletions(-) diff --git a/codegen.c b/codegen.c index 31b4925..58b66c9 100644 --- a/codegen.c +++ b/codegen.c @@ -499,6 +499,59 @@ static void copy_ret_buffer(Obj *var) { } } +static void copy_struct_reg(void) { + Type *ty = current_fn->ty->return_ty; + int gp = 0, fp = 0; + + println(" mov %%rax, %%rdi"); + + if (has_flonum(ty, 0, 8, 0)) { + assert(ty->size == 4 || 8 <= ty->size); + if (ty->size == 4) + println(" movss (%%rdi), %%xmm0"); + else + println(" movsd (%%rdi), %%xmm0"); + fp++; + } else { + println(" mov $0, %%rax"); + for (int i = MIN(8, ty->size) - 1; i >= 0; i--) { + println(" shl $8, %%rax"); + println(" mov %d(%%rdi), %%al", i); + } + gp++; + } + + if (ty->size > 8) { + if (has_flonum(ty, 8, 16, 0)) { + assert(ty->size == 12 || ty->size == 16); + if (ty->size == 4) + println(" movss 8(%%rdi), %%xmm%d", fp); + else + println(" movsd 8(%%rdi), %%xmm%d", fp); + } else { + char *reg1 = (gp == 0) ? "%al" : "%dl"; + char *reg2 = (gp == 0) ? "%rax" : "%rdx"; + println(" mov $0, %s", reg2); + for (int i = MIN(16, ty->size) - 1; i >= 8; i--) { + println(" shl $8, %s", reg2); + println(" mov %d(%%rdi), %s", i, reg1); + } + } + } +} + +static void copy_struct_mem(void) { + Type *ty = current_fn->ty->return_ty; + Obj *var = current_fn->params; + + println(" mov %d(%%rbp), %%rdi", var->offset); + + for (int i = 0; i < ty->size; i++) { + println(" mov %d(%%rax), %%dl", i); + println(" mov %%dl, %d(%%rdi)", i); + } +} + // Generate code for a given node. static void gen_expr(Node *node) { println(" .loc %d %d", node->tok->file->file_no, node->tok->line_no); @@ -940,8 +993,18 @@ static void gen_stmt(Node *node) { gen_stmt(node->lhs); return; case ND_RETURN: - if (node->lhs) + if (node->lhs) { gen_expr(node->lhs); + + Type *ty = node->lhs->ty; + if (ty->kind == TY_STRUCT || ty->kind == TY_UNION) { + if (ty->size <= 16) + copy_struct_reg(); + else + copy_struct_mem(); + } + } + println(" jmp .L.return.%s", current_fn->name); return; case ND_EXPR_STMT: diff --git a/parse.c b/parse.c index 94ff997..4cda47d 100644 --- a/parse.c +++ b/parse.c @@ -1188,7 +1188,11 @@ static Node *stmt(Token **rest, Token *tok) { *rest = skip(tok, ";"); add_type(exp); - node->lhs = new_cast(exp, current_fn->ty->return_ty); + Type *ty = current_fn->ty->return_ty; + if (ty->kind != TY_STRUCT && ty->kind != TY_UNION) + exp = new_cast(exp, current_fn->ty->return_ty); + + node->lhs = exp; return node; } @@ -2410,6 +2414,13 @@ static Token *function(Token *tok, Type *basety, VarAttr *attr) { locals = NULL; enter_scope(); create_param_lvars(ty->params); + + // A buffer for a struct/union return value is passed + // as the hidden first parameter. + Type *rty = ty->return_ty; + if ((rty->kind == TY_STRUCT || rty->kind == TY_UNION) && rty->size > 16) + new_lvar("", pointer_to(rty)); + fn->params = locals; if (ty->is_variadic) fn->va_area = new_lvar("__va_area__", array_of(ty_char, 136)); diff --git a/test/function.c b/test/function.c index 2f6f00a..6160c36 100644 --- a/test/function.c +++ b/test/function.c @@ -177,6 +177,26 @@ Ty6 struct_test26(void); Ty20 struct_test27(void); Ty21 struct_test28(void); +Ty4 struct_test34(void) { + return (Ty4){10, 20, 30, 40}; +} + +Ty5 struct_test35(void) { + return (Ty5){10, 20, 30}; +} + +Ty6 struct_test36(void) { + return (Ty6){10, 20, 30}; +} + +Ty20 struct_test37(void) { + return (Ty20){10, 20, 30, 40, 50, 60, 70, 80, 90, 100}; +} + +Ty21 struct_test38(void) { + return (Ty21){1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20}; +} + int main() { ASSERT(3, ret3()); ASSERT(8, add2(3, 5)); @@ -320,6 +340,29 @@ int main() { ASSERT(15, struct_test28().a[14]); ASSERT(20, struct_test28().a[19]); + ASSERT(10, struct_test34().a); + ASSERT(20, struct_test34().b); + ASSERT(30, struct_test34().c); + ASSERT(40, struct_test34().d); + + ASSERT(10, struct_test35().a); + ASSERT(20, struct_test35().b); + ASSERT(30, struct_test35().c); + + ASSERT(10, struct_test36().a[0]); + ASSERT(20, struct_test36().a[1]); + ASSERT(30, struct_test36().a[2]); + + ASSERT(10, struct_test37().a[0]); + ASSERT(60, struct_test37().a[5]); + ASSERT(100, struct_test37().a[9]); + + ASSERT(1, struct_test38().a[0]); + ASSERT(5, struct_test38().a[4]); + ASSERT(10, struct_test38().a[9]); + ASSERT(15, struct_test38().a[14]); + ASSERT(20, struct_test38().a[19]); + printf("OK\n"); return 0; }