From ca27455b92be2ffbfe58c7ffda623cf6ec112632 Mon Sep 17 00:00:00 2001 From: Rui Ueyama Date: Tue, 15 Sep 2020 11:30:56 +0900 Subject: [PATCH] Add atomic_compare_exchange --- chibicc.h | 6 ++++++ codegen.c | 40 ++++++++++++++++++++++++++++++++++++++++ include/stdatomic.h | 10 ++++++++++ parse.c | 12 ++++++++++++ test/atomic.c | 43 +++++++++++++++++++++++++++++++++++++++++++ type.c | 11 +++++++++++ 6 files changed, 122 insertions(+) create mode 100644 include/stdatomic.h create mode 100644 test/atomic.c diff --git a/chibicc.h b/chibicc.h index 65484ec..6f19a5e 100644 --- a/chibicc.h +++ b/chibicc.h @@ -220,6 +220,7 @@ typedef enum { ND_CAST, // Type cast ND_MEMZERO, // Zero-clear a stack variable ND_ASM, // "asm" + ND_CAS, // Atomic compare-and-swap } NodeKind; // AST node type @@ -271,6 +272,11 @@ struct Node { // "asm" string literal char *asm_str; + // Atomic compare-and-swap + Node *cas_addr; + Node *cas_old; + Node *cas_new; + // Variable Obj *var; diff --git a/codegen.c b/codegen.c index 85e15a3..ac296c0 100644 --- a/codegen.c +++ b/codegen.c @@ -56,6 +56,26 @@ int align_to(int n, int align) { return (n + align - 1) / align * align; } +static char *reg_dx(int sz) { + switch (sz) { + case 1: return "%dl"; + case 2: return "%dx"; + case 4: return "%edx"; + case 8: return "%rdx"; + } + unreachable(); +} + +static char *reg_ax(int sz) { + switch (sz) { + case 1: return "%al"; + case 2: return "%ax"; + case 4: return "%eax"; + case 8: return "%rax"; + } + unreachable(); +} + // Compute the absolute address of a given node. // It's an error if a given node does not reside in memory. static void gen_addr(Node *node) { @@ -943,6 +963,26 @@ static void gen_expr(Node *node) { case ND_LABEL_VAL: println(" lea %s(%%rip), %%rax", node->unique_label); return; + case ND_CAS: { + gen_expr(node->cas_addr); + push(); + gen_expr(node->cas_new); + push(); + gen_expr(node->cas_old); + println(" mov %%rax, %%r8"); + load(node->cas_old->ty->base); + pop("%rdx"); // new + pop("%rdi"); // addr + + int sz = node->cas_addr->ty->base->size; + println(" lock cmpxchg %s, (%%rdi)", reg_dx(sz)); + println(" sete %%cl"); + println(" je 1f"); + println(" mov %s, (%%r8)", reg_ax(sz)); + println("1:"); + println(" movzbl %%cl, %%eax"); + return; + } } switch (node->lhs->ty->kind) { diff --git a/include/stdatomic.h b/include/stdatomic.h new file mode 100644 index 0000000..238f10b --- /dev/null +++ b/include/stdatomic.h @@ -0,0 +1,10 @@ +#ifndef __STDATOMIC_H +#define __STDATOMIC_H + +#define atomic_compare_exchange_weak(p, old, new) \ + __builtin_compare_and_swap((p), (old), (new)) + +#define atomic_compare_exchange_strong(p, old, new) \ + __builtin_compare_and_swap((p), (old), (new)) + +#endif diff --git a/parse.c b/parse.c index c0505ff..df86076 100644 --- a/parse.c +++ b/parse.c @@ -2942,6 +2942,18 @@ static Node *primary(Token **rest, Token *tok) { return new_num(2, start); } + if (equal(tok, "__builtin_compare_and_swap")) { + Node *node = new_node(ND_CAS, tok); + tok = skip(tok->next, "("); + node->cas_addr = assign(&tok, tok); + tok = skip(tok, ","); + node->cas_old = assign(&tok, tok); + tok = skip(tok, ","); + node->cas_new = assign(&tok, tok); + *rest = skip(tok, ")"); + return node; + } + if (tok->kind == TK_IDENT) { // Variable or enum constant VarScope *sc = find_var(tok); diff --git a/test/atomic.c b/test/atomic.c new file mode 100644 index 0000000..748db27 --- /dev/null +++ b/test/atomic.c @@ -0,0 +1,43 @@ +#include "test.h" +#include +#include + +static int incr(int *p) { + int oldval = *p; + int newval; + do { + newval = oldval + 1; + } while (!atomic_compare_exchange_weak(p, &oldval, newval)); + return newval; +} + +static int add(void *arg) { + int *x = arg; + for (int i = 0; i < 1000*1000; i++) + incr(x); + return 0; +} + +static int add_millions(void) { + int x = 0; + + pthread_t thr1; + pthread_t thr2; + + pthread_create(&thr1, NULL, add, &x); + pthread_create(&thr2, NULL, add, &x); + + for (int i = 0; i < 1000*1000; i++) + incr(&x); + + pthread_join(thr1, NULL); + pthread_join(thr2, NULL); + return x; +} + +int main() { + ASSERT(3*1000*1000, add_millions()); + + printf("OK\n"); + return 0; +} diff --git a/type.c b/type.c index 1d4400e..440d4aa 100644 --- a/type.c +++ b/type.c @@ -287,5 +287,16 @@ void add_type(Node *node) { case ND_LABEL_VAL: node->ty = pointer_to(ty_void); return; + case ND_CAS: + add_type(node->cas_addr); + add_type(node->cas_old); + add_type(node->cas_new); + node->ty = ty_bool; + + if (node->cas_addr->ty->kind != TY_PTR) + error_tok(node->cas_addr->tok, "pointer expected"); + if (node->cas_old->ty->kind != TY_PTR) + error_tok(node->cas_old->tok, "pointer expected"); + return; } }