diff --git a/README.md b/README.md index 1f29857..9a28b4f 100644 --- a/README.md +++ b/README.md @@ -119,7 +119,7 @@ int main(void) { fact = partial_fact; } - printf("%ld! = ", n); + printf("%d! = ", n); bigint_print(fact); printf("\n"); diff --git a/src/bigint.c b/src/bigint.c index a2d6ac3..e2b0c82 100644 --- a/src/bigint.c +++ b/src/bigint.c @@ -11,14 +11,6 @@ #define IS_DIGIT(c) ((c) >= '0') && ((c) <= '9') -#define DESTROY_IF(p) \ - do { \ - if ((p) && (p) != result.value.number) { \ - bigint_destroy((p)); \ - (p) = NULL; \ - } \ - } while (0) - #include #include #include @@ -35,6 +27,8 @@ static bigint_result_t bigint_shift_left(const bigint_t *num, size_t n); static bigint_result_t bigint_split(const bigint_t *num, size_t m, bigint_t **high, bigint_t **low); static bigint_result_t bigint_karatsuba_base(const bigint_t *x, const bigint_t *y); static bigint_result_t bigint_karatsuba(const bigint_t *x, const bigint_t *y); +static bigint_result_t bigint_shift_right(const bigint_t *num, size_t n); +static bigint_result_t bigint_reciprocal(const bigint_t *num, size_t precision); /** * bigint_from_int @@ -882,6 +876,116 @@ bigint_result_t bigint_prod(const bigint_t *x, const bigint_t *y) { return result; } +/** + * bigint_divmod + * @x: a valid non-null big integer + * @y: a valid non-null big integer + * + * Computes division with remainder + * + * Returns a bigint_result_t data type + */ +bigint_result_t bigint_divmod(const bigint_t *x, const bigint_t *y) { + bigint_result_t result = {0}; + bigint_result_t tmp_res = {0}; + + // Intermediate results + bigint_t *quotient = NULL; + bigint_t *y_times_q = NULL; + bigint_t *remainder = NULL; + + if (x == NULL || y == NULL) { + result.status = BIGINT_ERR_INVALID; + SET_MSG(result, "Invalid big numbers"); + + return result; + } + + // Check for division by zero + const size_t y_size = vector_size(y->digits); + if (y_size == 0) { + result.status = BIGINT_ERR_DIV_BY_ZERO; + SET_MSG(result, "Division by zero"); + + return result; + } + + if (y_size == 1) { + vector_result_t y_val_res = vector_get(y->digits, 0); + if (y_val_res.status != VECTOR_OK) { + result.status = BIGINT_ERR_INVALID; + COPY_MSG(result, y_val_res.message); + + return result; + } + + int *y_val = (int*)y_val_res.value.element; + if (*y_val == 0) { + result.status = BIGINT_ERR_DIV_BY_ZERO; + SET_MSG(result, "Division by zero"); + + return result; + } + } + + // |x| < |y| then quotient is 0 and remainder is x + tmp_res = bigint_compare_abs(x, y); + if (tmp_res.status != BIGINT_OK) { result = tmp_res; goto cleanup; } + + if (tmp_res.value.compare_status < 0) { + tmp_res = bigint_from_int(0); + if (tmp_res.status != BIGINT_OK) { result = tmp_res; goto cleanup; } + quotient = tmp_res.value.number; + + tmp_res = bigint_clone(x); + if (tmp_res.status != BIGINT_OK) { result = tmp_res; goto cleanup; } + remainder = tmp_res.value.number; + + result.value.division.quotient = quotient; + result.value.division.remainder = remainder; + result.status = BIGINT_OK; + SET_MSG(result, "Division between big integers was successful"); + + return result; + } + + tmp_res = bigint_div(x, y); + if (tmp_res.status != BIGINT_OK) { result = tmp_res; goto cleanup; } + quotient = tmp_res.value.number; + + // Computed r = x - y * q + tmp_res = bigint_prod(y, quotient); + if (tmp_res.status != BIGINT_OK) { result = tmp_res; goto cleanup; } + y_times_q = tmp_res.value.number; + + tmp_res = bigint_sub(x, y_times_q); + if (tmp_res.status != BIGINT_OK) { result = tmp_res; goto cleanup; } + remainder = tmp_res.value.number; + + // Ensure that remainder has correct sign (i.e., same as dividend x) + // In C-style division, sign(remainder) = sign(dividend) + remainder->is_negative = x->is_negative; + + tmp_res = bigint_trim_zeros(remainder); + if (tmp_res.status != BIGINT_OK) { result = tmp_res; goto cleanup; } + + result.value.division.quotient = quotient; + result.value.division.remainder = remainder; + result.status = BIGINT_OK; + SET_MSG(result, "Division between big integers was successful"); + + bigint_destroy(y_times_q); + + return result; + +cleanup: + if (quotient) { bigint_destroy(quotient); } + if (y_times_q) { bigint_destroy(y_times_q); } + if (remainder) { bigint_destroy(remainder); } + + return result; +} + /** * bigint_shift_left * @num: a non-null big integer @@ -1255,6 +1359,14 @@ bigint_result_t bigint_karatsuba_base(const bigint_t *x, const bigint_t *y) { */ bigint_result_t bigint_karatsuba(const bigint_t *x, const bigint_t *y) { bigint_result_t result = {0}; + bigint_result_t tmp_res = {0}; + + if (x == NULL || y == NULL) { + result.status = BIGINT_ERR_INVALID; + SET_MSG(result, "Invalid big integers"); + + return result; + } const size_t x_size = vector_size(x->digits); const size_t y_size = vector_size(y->digits); @@ -1276,14 +1388,12 @@ bigint_result_t bigint_karatsuba(const bigint_t *x, const bigint_t *y) { bigint_t *z2_shifted = NULL, *z1_shifted = NULL; bigint_t *temp = NULL, *product = NULL; - bigint_result_t tmp_res = {0}; - // Split x = x1 * BASE^pivot + x0 tmp_res = bigint_split(x, pivot, &x1, &x0); if (tmp_res.status != BIGINT_OK) { result = tmp_res; goto cleanup; } // Split y = y1 * BASE^pivot + y0 - tmp_res = bigint_split(x, pivot, &y1, &y0); + tmp_res = bigint_split(y, pivot, &y1, &y0); if (tmp_res.status != BIGINT_OK) { result = tmp_res; goto cleanup; } // Perform karatsuba's trick @@ -1332,18 +1442,251 @@ bigint_result_t bigint_karatsuba(const bigint_t *x, const bigint_t *y) { if (tmp_res.status != BIGINT_OK) { result = tmp_res; goto cleanup; } product = tmp_res.value.number; + // Destroy intermediate allocations except for the product + bigint_destroy(x1); bigint_destroy(x0); + bigint_destroy(y1); bigint_destroy(y0); + bigint_destroy(z0); bigint_destroy(z2); + bigint_destroy(x_sum); bigint_destroy(y_sum); + bigint_destroy(z1_temp); bigint_destroy(z1_sub1); + bigint_destroy(z1); bigint_destroy(z2_shifted); + bigint_destroy(z1_shifted); bigint_destroy(temp); + result.value.number = product; result.status = BIGINT_OK; SET_MSG(result, "Product between big integers was successful"); -cleanup: // Destroy intermediate allocations except for the product - DESTROY_IF(x1); DESTROY_IF(x0); - DESTROY_IF(y1); DESTROY_IF(y0); - DESTROY_IF(z0); DESTROY_IF(z2); - DESTROY_IF(x_sum); DESTROY_IF(y_sum); - DESTROY_IF(z1_temp); DESTROY_IF(z1_sub1); DESTROY_IF(z1); - DESTROY_IF(z2_shifted); DESTROY_IF(z1_shifted); - DESTROY_IF(temp); +cleanup: // Destroy intermediate allocations on error + if (x1) { bigint_destroy(x1); } + if (x0) { bigint_destroy(x0); } + if (y1) { bigint_destroy(y1); } + if (y0) { bigint_destroy(y0); } + if (z0) { bigint_destroy(z0); } + if (z2) { bigint_destroy(z2); } + if (x_sum) { bigint_destroy(x_sum); } + if (y_sum) { bigint_destroy(y_sum); } + if (z1_temp) { bigint_destroy(z1_temp); } + if (z1_sub1) { bigint_destroy(z1_sub1); } + if (z1) { bigint_destroy(z1); } + if (z2_shifted) { bigint_destroy(z2_shifted); } + if (z1_shifted) { bigint_destroy(z1_shifted); } + if (temp) { bigint_destroy(temp); } + if (product) { bigint_destroy(product); } + + return result; +} + +/** + * bigint_shift_right + * @num: a valid non-null big integer + * @n: number of digits to shift + * + * Shifts right by @n digits (i.e., divide by BASE^n) + * + * Returns a bigint_result_t data type + */ +bigint_result_t bigint_shift_right(const bigint_t *num, size_t n) { + bigint_result_t result = {0}; + + const size_t size = vector_size(num->digits); + + if (n >= size) return bigint_from_int(0); + if (n == 0) return bigint_clone(num); + + bigint_t *shifted = malloc(sizeof(bigint_t)); + if (shifted == NULL) { + result.status = BIGINT_ERR_ALLOCATE; + SET_MSG(result, "Failed to allocate memory for big integer"); + + return result; + } + + vector_result_t vec_res = vector_new(size - n, sizeof(int)); + if (vec_res.status != VECTOR_OK) { + free(shifted); + result.status = BIGINT_ERR_INVALID; + COPY_MSG(result, vec_res.message); + + return result; + } + + shifted->digits = vec_res.value.vector; + shifted->is_negative = num->is_negative; + + // Copy digits from position 'n' onwards + for (size_t idx = n; idx < size; idx++) { + vector_result_t vec_res = vector_get(num->digits, idx); + if (vec_res.status != VECTOR_OK) { + vector_destroy(shifted->digits); + free(shifted); + result.status = BIGINT_ERR_INVALID; + COPY_MSG(result, vec_res.message); + + return result; + } + + int *digit = (int*)vec_res.value.element; + + vector_result_t push_res = vector_push(shifted->digits, digit); + if (push_res.status != VECTOR_OK) { + vector_destroy(shifted->digits); + free(shifted); + result.status = BIGINT_ERR_INVALID; + COPY_MSG(result, push_res.message); + + return result; + } + } + + bigint_result_t trim_res = bigint_trim_zeros(shifted); + if (trim_res.status != BIGINT_OK) { + vector_destroy(shifted->digits); + free(shifted); + + return trim_res; + } + + result.value.number = shifted; + result.status = BIGINT_OK; + SET_MSG(result, "Big integer shifted successfully"); + + return result; +} + +/** + * bigint_reciprocal + * @num: a valid non-null big integer + * @precision: the precision of the computation + * + * Compute the reciprocal using Newton-Raphson algorithm. + * It calculates 1/num with precision @precision, returning + * floor(BASE^(2 * @precision) / num) + * + * Returns a bigint_result_t data type + */ +bigint_result_t bigint_reciprocal(const bigint_t *num, size_t precision) { + bigint_result_t result = {0}; + bigint_result_t tmp_res = {0}; + + // Results of each steps + bigint_t *x = NULL; + bigint_t *scale = NULL; + bigint_t *two = NULL; + bigint_t *two_scaled = NULL; + bigint_t *dx = NULL; + bigint_t *two_minus_dx = NULL; + bigint_t *x_new_tmp = NULL; + bigint_t *x_new = NULL; + + if (num == NULL) { + result.status = BIGINT_ERR_INVALID; + SET_MSG(result, "Invalid big integer"); + + return result; + } + + const size_t num_size = vector_size(num->digits); + // Get most significant digit + vector_result_t msd_res = vector_get(num->digits, num_size - 1); + if (msd_res.status != VECTOR_OK) { + result.status = BIGINT_ERR_INVALID; + COPY_MSG(result, msd_res.message); + + return result; + } + + int *msd = (int*)msd_res.value.element; + + // x = floor(BASE^2 / (msd + 1)) + const long long initial_val = ((long long)BIGINT_BASE * (long long)BIGINT_BASE) / ((long long)(*msd) + 1LL); + tmp_res = bigint_from_int(initial_val); + if (tmp_res.status != BIGINT_OK) { result = tmp_res; goto cleanup; } + x = tmp_res.value.number; + + tmp_res = bigint_from_int(1); + if (tmp_res.status != BIGINT_OK) { result = tmp_res; goto cleanup; } + scale = tmp_res.value.number; + + // Scale to proper precision. That is scale x by BASE^(2 * precision - 2) + // in order to reach BASE^(2 * precision) magnitude + if (precision > 1) { + tmp_res = bigint_shift_left(scale, 2 * precision - 2); + if (tmp_res.status != BIGINT_OK) { result = tmp_res; goto cleanup; } + bigint_destroy(scale); + scale = tmp_res.value.number; + + tmp_res = bigint_prod(x, scale); + if (tmp_res.status != BIGINT_OK) { result = tmp_res; goto cleanup; } + bigint_destroy(x); + x = tmp_res.value.number; + } + + // two_scaled = 2 * BASE^(2 * precision) + tmp_res = bigint_from_int(2); + if (tmp_res.status != BIGINT_OK) { result = tmp_res; goto cleanup; } + two = tmp_res.value.number; + + tmp_res = bigint_shift_left(two, 2 * precision); + if (tmp_res.status != BIGINT_OK) { result = tmp_res; goto cleanup; } + + bigint_destroy(two); + two = NULL; + two_scaled = tmp_res.value.number; + + // Determine the number of Newton-Raphson iterations + size_t iterations = 0; + size_t target = precision; + while ((1ULL << iterations) < target) { iterations++; } + iterations += 2; // Add a few more just to be sure + + // x_{n+1} = x_n * (2 * BASE^(2P) - d * x_n) / BASE^(2P) + for (size_t it = 0; it < iterations; it++) { + // dx = d * x + tmp_res = bigint_prod(num, x); + if (tmp_res.status != BIGINT_OK) { result = tmp_res; goto cleanup; } + dx = tmp_res.value.number; + + // two_minus_dx = 2 * BASE^(2P) - dx + tmp_res = bigint_sub(two_scaled, dx); + if (tmp_res.status != BIGINT_OK) { result = tmp_res; goto cleanup; } + two_minus_dx = tmp_res.value.number; + + // x_new_temp = x * (two_minus_dx) + tmp_res = bigint_prod(x, two_minus_dx); + if (tmp_res.status != BIGINT_OK) { result = tmp_res; goto cleanup; } + x_new_tmp = tmp_res.value.number; + + // x_new = x_new_temp >> (2 * precision) + tmp_res = bigint_shift_right(x_new_tmp, 2 * precision); + if (tmp_res.status != BIGINT_OK) { result = tmp_res; goto cleanup; } + x_new = tmp_res.value.number; + + // Rotation pass: replace x with x_new and free intermediates + bigint_destroy(x); + x = x_new; + x_new = NULL; + + bigint_destroy(dx); dx = NULL; + bigint_destroy(two_minus_dx); two_minus_dx = NULL; + bigint_destroy(x_new_tmp); x_new_tmp = NULL; + } + + bigint_destroy(scale); + bigint_destroy(two_scaled); + + result.value.number = x; + result.status = BIGINT_OK; + SET_MSG(result, "Reciprocal computed successfully"); + return result; + +cleanup: + if (x) { bigint_destroy(x); } + if (scale) { bigint_destroy(scale); } + if (two) { bigint_destroy(two); } + if (two_scaled) { bigint_destroy(two_scaled); } + if (dx) { bigint_destroy(dx); } + if (two_minus_dx) { bigint_destroy(two_minus_dx); } + if (x_new_tmp) { bigint_destroy(x_new_tmp); } + if (x_new) { bigint_destroy(x_new); } return result; } diff --git a/src/bigint.h b/src/bigint.h index dc06267..8e60b49 100644 --- a/src/bigint.h +++ b/src/bigint.h @@ -24,11 +24,17 @@ typedef struct { bool is_negative; } bigint_t; +typedef struct { + bigint_t *quotient; + bigint_t *remainder; +} div_result_t; + typedef struct { bigint_status_t status; uint8_t message[RESULT_MSG_SIZE]; union { bigint_t *number; + div_result_t division; int8_t compare_status; char *string_num; } value; @@ -46,6 +52,7 @@ bigint_result_t bigint_compare(const bigint_t *x, const bigint_t *y); bigint_result_t bigint_add(const bigint_t *x, const bigint_t *y); bigint_result_t bigint_sub(const bigint_t *x, const bigint_t *y); bigint_result_t bigint_prod(const bigint_t *x, const bigint_t *y); +bigint_result_t bigint_divmod(const bigint_t *x, const bigint_t *y); bigint_result_t bigint_destroy(bigint_t *number); bigint_result_t bigint_print(const bigint_t *number); diff --git a/tests/test_vector b/tests/test_vector deleted file mode 100755 index c9066fd..0000000 Binary files a/tests/test_vector and /dev/null differ