diff --git a/py/emitbc.c b/py/emitbc.c index 8680f49597..269fcdeb7e 100644 --- a/py/emitbc.c +++ b/py/emitbc.c @@ -242,7 +242,14 @@ STATIC void emit_bc_end_pass(emit_t *emit) { emit->code_base = m_new(byte, emit->code_info_size + emit->byte_code_size); } else if (emit->pass == PASS_3) { - rt_assign_byte_code(emit->scope->unique_code_id, emit->code_base, emit->code_info_size + emit->byte_code_size, emit->scope->num_params, emit->scope->num_locals, emit->scope->stack_size, emit->scope->scope_flags); + qstr *arg_names = m_new(qstr, emit->scope->num_params); + for (int i = 0; i < emit->scope->num_params; i++) { + arg_names[i] = emit->scope->id_info[i].qstr; + } + rt_assign_byte_code(emit->scope->unique_code_id, emit->code_base, + emit->code_info_size + emit->byte_code_size, + emit->scope->num_params, emit->scope->num_locals, emit->scope->stack_size, + emit->scope->scope_flags, arg_names); } } diff --git a/py/obj.h b/py/obj.h index 55be768730..c2b127c328 100644 --- a/py/obj.h +++ b/py/obj.h @@ -236,7 +236,7 @@ mp_obj_t mp_obj_new_exception_msg(const mp_obj_type_t *exc_type, const char *msg mp_obj_t mp_obj_new_exception_msg_varg(const mp_obj_type_t *exc_type, const char *fmt, ...); // counts args by number of % symbols in fmt, excluding %%; can only handle void* sizes (ie no float/double!) mp_obj_t mp_obj_new_range(int start, int stop, int step); mp_obj_t mp_obj_new_range_iterator(int cur, int stop, int step); -mp_obj_t mp_obj_new_fun_bc(uint scope_flags, uint n_args, mp_obj_t def_args, uint n_state, const byte *code); +mp_obj_t mp_obj_new_fun_bc(uint scope_flags, qstr *args, uint n_args, mp_obj_t def_args, uint n_state, const byte *code); mp_obj_t mp_obj_new_fun_asm(uint n_args, void *fun); mp_obj_t mp_obj_new_gen_wrap(mp_obj_t fun); mp_obj_t mp_obj_new_gen_instance(const byte *bytecode, uint n_state, int n_args, const mp_obj_t *args); diff --git a/py/objfun.c b/py/objfun.c index b2837be5f0..354d7ff9ca 100644 --- a/py/objfun.c +++ b/py/objfun.c @@ -14,6 +14,12 @@ #include "runtime.h" #include "bc.h" +#if 0 // print debugging info +#define DEBUG_PRINT (1) +#else // don't print debugging info +#define DEBUG_printf(args...) (void)0 +#endif + /******************************************************************************/ /* native functions */ @@ -141,16 +147,30 @@ typedef struct _mp_obj_fun_bc_t { }; uint n_state; // total state size for the executing function (incl args, locals, stack) const byte *bytecode; // bytecode for the function + qstr *args; // argument names (needed to resolve positional args passed as keywords) mp_obj_t extra_args[]; // values of default args (if any), plus a slot at the end for var args and/or kw args (if it takes them) } mp_obj_fun_bc_t; +void dump_args(const mp_obj_t *a, int sz) { +#if DEBUG_PRINT + DEBUG_printf("%p: ", a); + for (int i = 0; i < sz; i++) { + DEBUG_printf("%p ", a[i]); + } + DEBUG_printf("\n"); +#endif +} + STATIC mp_obj_t fun_bc_call(mp_obj_t self_in, uint n_args, uint n_kw, const mp_obj_t *args) { + DEBUG_printf("Input: "); + dump_args(args, n_args); mp_obj_fun_bc_t *self = self_in; const mp_obj_t *kwargs = args + n_args; mp_obj_t *extra_args = self->extra_args + self->n_def_args; uint n_extra_args = 0; + // check positional arguments if (n_args > self->n_args) { @@ -162,31 +182,93 @@ STATIC mp_obj_t fun_bc_call(mp_obj_t self_in, uint n_args, uint n_kw, const mp_o *extra_args = mp_obj_new_tuple(n_args - self->n_args, args + self->n_args); n_extra_args = 1; n_args = self->n_args; - } else if (n_args >= self->n_args - self->n_def_args) { - // given enough arguments, but may need to use some default arguments + } else { if (self->takes_var_args) { + DEBUG_printf("passing empty tuple as *args\n"); *extra_args = mp_const_empty_tuple; n_extra_args = 1; } - extra_args -= self->n_args - n_args; - n_extra_args += self->n_args - n_args; - } else { - goto arg_error; + // Apply processing and check below only if we don't have kwargs, + // otherwise, kw handling code below has own extensive checks. + if (n_kw == 0) { + if (n_args >= self->n_args - self->n_def_args) { + // given enough arguments, but may need to use some default arguments + extra_args -= self->n_args - n_args; + n_extra_args += self->n_args - n_args; + } else { + goto arg_error; + } + } } // check keyword arguments if (n_kw != 0) { - // keyword arguments given - if (!self->takes_kw_args) { - nlr_jump(mp_obj_new_exception_msg(&mp_type_TypeError, "function does not take keyword arguments")); + // We cannot use dynamically-sized array here, because GCC indeed + // deallocates it on leaving defining scope (unlike most static stack allocs). + // So, we have 2 choices: allocate it unconditionally at the top of function + // (wastes stack), or use alloca which is guaranteed to dealloc on func exit. + //mp_obj_t flat_args[self->n_args]; + mp_obj_t *flat_args = alloca(self->n_args * sizeof(mp_obj_t)); + for (int i = self->n_args - 1; i >= 0; i--) { + flat_args[i] = MP_OBJ_NULL; + } + memcpy(flat_args, args, sizeof(*args) * n_args); + DEBUG_printf("Initial args: "); + dump_args(flat_args, self->n_args); + + mp_obj_t dict = MP_OBJ_NULL; + if (self->takes_kw_args) { + dict = mp_obj_new_dict(n_kw); // TODO: better go conservative with 0? } - mp_obj_t dict = mp_obj_new_dict(n_kw); for (uint i = 0; i < n_kw; i++) { + qstr arg_name = MP_OBJ_QSTR_VALUE(kwargs[2 * i]); + for (uint j = 0; j < self->n_args; j++) { + if (arg_name == self->args[j]) { + if (flat_args[j] != MP_OBJ_NULL) { + nlr_jump(mp_obj_new_exception_msg_varg(&mp_type_TypeError, + "function got multiple values for argument '%s'", qstr_str(arg_name))); + } + flat_args[j] = kwargs[2 * i + 1]; + goto continue2; + } + } + // Didn't find name match with positional args + if (!self->takes_kw_args) { + nlr_jump(mp_obj_new_exception_msg(&mp_type_TypeError, "function does not take keyword arguments")); + } mp_obj_dict_store(dict, kwargs[2 * i], kwargs[2 * i + 1]); +continue2:; + } + DEBUG_printf("Args with kws flattened: "); + dump_args(flat_args, self->n_args); + + // Now fill in defaults + mp_obj_t *d = &flat_args[self->n_args - 1]; + mp_obj_t *s = &self->extra_args[self->n_def_args - 1]; + for (int i = self->n_def_args; i > 0; i--) { + if (*d != MP_OBJ_NULL) { + *d-- = *s--; + } + } + DEBUG_printf("Args after filling defaults: "); + dump_args(flat_args, self->n_args); + + // Now check that all mandatory args specified + while (d >= flat_args) { + if (*d-- == MP_OBJ_NULL) { + nlr_jump(mp_obj_new_exception_msg_varg(&mp_type_TypeError, + "function missing required positional argument #%d", d - flat_args)); + } + } + + args = flat_args; + n_args = self->n_args; + + if (self->takes_kw_args) { + extra_args[n_extra_args] = dict; + n_extra_args += 1; } - extra_args[n_extra_args] = dict; - n_extra_args += 1; } else { // no keyword arguments given if (self->takes_kw_args) { @@ -198,6 +280,9 @@ STATIC mp_obj_t fun_bc_call(mp_obj_t self_in, uint n_args, uint n_kw, const mp_o mp_map_t *old_globals = rt_globals_get(); rt_globals_set(self->globals); mp_obj_t result; + DEBUG_printf("Calling: args=%p, n_args=%d, extra_args=%p, n_extra_args=%d\n", args, n_args, extra_args, n_extra_args); + dump_args(args, n_args); + dump_args(extra_args, n_extra_args); mp_vm_return_kind_t vm_return_kind = mp_execute_byte_code(self->bytecode, args, n_args, extra_args, n_extra_args, self->n_state, &result); rt_globals_set(old_globals); @@ -217,7 +302,7 @@ const mp_obj_type_t fun_bc_type = { .call = fun_bc_call, }; -mp_obj_t mp_obj_new_fun_bc(uint scope_flags, uint n_args, mp_obj_t def_args_in, uint n_state, const byte *code) { +mp_obj_t mp_obj_new_fun_bc(uint scope_flags, qstr *args, uint n_args, mp_obj_t def_args_in, uint n_state, const byte *code) { uint n_def_args = 0; uint n_extra_args = 0; mp_obj_tuple_t *def_args = def_args_in; @@ -234,6 +319,7 @@ mp_obj_t mp_obj_new_fun_bc(uint scope_flags, uint n_args, mp_obj_t def_args_in, mp_obj_fun_bc_t *o = m_new_obj_var(mp_obj_fun_bc_t, mp_obj_t, n_extra_args); o->base.type = &fun_bc_type; o->globals = rt_globals_get(); + o->args = args; o->n_args = n_args; o->n_def_args = n_def_args; o->takes_var_args = (scope_flags & MP_SCOPE_FLAG_VARARGS) != 0; diff --git a/py/runtime.c b/py/runtime.c index 9fc0b97088..b08ae3d4e7 100644 --- a/py/runtime.c +++ b/py/runtime.c @@ -64,6 +64,7 @@ typedef struct _mp_code_t { void *fun; } u_inline_asm; }; + qstr *arg_names; } mp_code_t; STATIC uint next_unique_code_id; @@ -242,7 +243,7 @@ STATIC void alloc_unique_codes(void) { } } -void rt_assign_byte_code(uint unique_code_id, byte *code, uint len, int n_args, int n_locals, int n_stack, uint scope_flags) { +void rt_assign_byte_code(uint unique_code_id, byte *code, uint len, int n_args, int n_locals, int n_stack, uint scope_flags, qstr *arg_names) { alloc_unique_codes(); assert(1 <= unique_code_id && unique_code_id < next_unique_code_id && unique_codes[unique_code_id].kind == MP_CODE_NONE); @@ -252,6 +253,7 @@ void rt_assign_byte_code(uint unique_code_id, byte *code, uint len, int n_args, unique_codes[unique_code_id].n_state = n_locals + n_stack; unique_codes[unique_code_id].u_byte.code = code; unique_codes[unique_code_id].u_byte.len = len; + unique_codes[unique_code_id].arg_names = arg_names; //printf("byte code: %d bytes\n", len); @@ -714,7 +716,7 @@ mp_obj_t rt_make_function_from_id(int unique_code_id, mp_obj_t def_args) { mp_obj_t fun; switch (c->kind) { case MP_CODE_BYTE: - fun = mp_obj_new_fun_bc(c->scope_flags, c->n_args, def_args, c->n_state, c->u_byte.code); + fun = mp_obj_new_fun_bc(c->scope_flags, c->arg_names, c->n_args, def_args, c->n_state, c->u_byte.code); break; case MP_CODE_NATIVE: fun = rt_make_function_n(c->n_args, c->u_native.fun); diff --git a/py/runtime0.h b/py/runtime0.h index ca88fc13b1..07fcf0705e 100644 --- a/py/runtime0.h +++ b/py/runtime0.h @@ -97,6 +97,6 @@ extern void *const rt_fun_table[RT_F_NUMBER_OF]; void rt_init(void); void rt_deinit(void); uint rt_get_unique_code_id(void); -void rt_assign_byte_code(uint unique_code_id, byte *code, uint len, int n_args, int n_locals, int n_stack, uint scope_flags); +void rt_assign_byte_code(uint unique_code_id, byte *code, uint len, int n_args, int n_locals, int n_stack, uint scope_flags, qstr *arg_names); void rt_assign_native_code(uint unique_code_id, void *f, uint len, int n_args); void rt_assign_inline_asm_code(uint unique_code_id, void *f, uint len, int n_args); diff --git a/tests/basics/fun-kwargs.py b/tests/basics/fun-kwargs.py new file mode 100644 index 0000000000..9f4f2b7d31 --- /dev/null +++ b/tests/basics/fun-kwargs.py @@ -0,0 +1,35 @@ +def f1(a): + print(a) + +f1(123) +f1(a=123) +try: + f1(b=123) +except TypeError: + print("TypeError") + +def f2(a, b): + print(a, b) + +f2(1, 2) +f2(a=3, b=4) +f2(b=5, a=6) +f2(7, b=8) +try: + f2(9, a=10) +except TypeError: + print("TypeError") + +def f3(a, b, *args): + print(a, b, args) + + +f3(1, b=3) +try: + f3(1, a=3) +except TypeError: + print("TypeError") +try: + f3(1, 2, 3, 4, a=5) +except TypeError: + print("TypeError")