diff --git a/py/mpz.c b/py/mpz.c index d6aeafd102..16198730af 100644 --- a/py/mpz.c +++ b/py/mpz.c @@ -51,6 +51,9 @@ STATIC int mpn_cmp(const mpz_dig_t *idig, uint ilen, const mpz_dig_t *jdig, uint STATIC uint mpn_shl(mpz_dig_t *idig, mpz_dig_t *jdig, uint jlen, uint n) { uint n_whole = (n + DIG_SIZE - 1) / DIG_SIZE; uint n_part = n % DIG_SIZE; + if (n_part == 0) { + n_part = DIG_SIZE; + } // start from the high end of the digit arrays idig += jlen + n_whole - 1; @@ -67,7 +70,7 @@ STATIC uint mpn_shl(mpz_dig_t *idig, mpz_dig_t *jdig, uint jlen, uint n) { // store remaining bits *idig = d >> (DIG_SIZE - n_part); idig -= n_whole - 1; - memset(idig, 0, n_whole - 1); + memset(idig, 0, (n_whole - 1) * sizeof(mpz_dig_t)); // work out length of result jlen += n_whole; @@ -412,6 +415,12 @@ mpz_t *mpz_from_int(machine_int_t val) { return z; } +mpz_t *mpz_from_ll(long long val) { + mpz_t *z = mpz_zero(); + mpz_set_from_ll(z, val); + return z; +} + mpz_t *mpz_from_str(const char *str, uint len, bool neg, uint base) { mpz_t *z = mpz_zero(); mpz_set_from_str(z, str, len, neg, base); @@ -469,17 +478,38 @@ void mpz_set(mpz_t *dest, const mpz_t *src) { void mpz_set_from_int(mpz_t *z, machine_int_t val) { mpz_need_dig(z, MPZ_NUM_DIG_FOR_INT); + machine_uint_t uval; if (val < 0) { z->neg = 1; - val = -val; + uval = -val; } else { z->neg = 0; + uval = val; } z->len = 0; - while (val > 0) { - z->dig[z->len++] = val & DIG_MASK; - val >>= DIG_SIZE; + while (uval > 0) { + z->dig[z->len++] = uval & DIG_MASK; + uval >>= DIG_SIZE; + } +} + +void mpz_set_from_ll(mpz_t *z, long long val) { + mpz_need_dig(z, MPZ_NUM_DIG_FOR_LL); + + unsigned long long uval; + if (val < 0) { + z->neg = 1; + uval = -val; + } else { + z->neg = 0; + uval = val; + } + + z->len = 0; + while (uval > 0) { + z->dig[z->len++] = uval & DIG_MASK; + uval >>= DIG_SIZE; } } diff --git a/py/mpz.h b/py/mpz.h index 0ef1ad10db..afd46cfdea 100644 --- a/py/mpz.h +++ b/py/mpz.h @@ -12,6 +12,7 @@ typedef struct _mpz_t { #define MPZ_DIG_SIZE (15) // see mpn_div for why this needs to be at most 15 #define MPZ_NUM_DIG_FOR_INT (sizeof(machine_int_t) * 8 / MPZ_DIG_SIZE + 1) +#define MPZ_NUM_DIG_FOR_LL (sizeof(long long) * 8 / MPZ_DIG_SIZE + 1) // convenience macro to declare an mpz with a digit array from the stack, initialised by an integer #define MPZ_CONST_INT(z, val) mpz_t z; mpz_dig_t z ## _digits[MPZ_NUM_DIG_FOR_INT]; mpz_init_fixed_from_int(&z, z_digits, MPZ_NUM_DIG_FOR_INT, val); @@ -23,6 +24,7 @@ void mpz_deinit(mpz_t *z); mpz_t *mpz_zero(); mpz_t *mpz_from_int(machine_int_t i); +mpz_t *mpz_from_ll(long long i); mpz_t *mpz_from_str(const char *str, uint len, bool neg, uint base); void mpz_free(mpz_t *z); @@ -30,6 +32,7 @@ mpz_t *mpz_clone(const mpz_t *src); void mpz_set(mpz_t *dest, const mpz_t *src); void mpz_set_from_int(mpz_t *z, machine_int_t src); +void mpz_set_from_ll(mpz_t *z, long long i); uint mpz_set_from_str(mpz_t *z, const char *str, uint len, bool neg, uint base); bool mpz_is_zero(const mpz_t *z);