// Copyright 2021 the V8 project authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "src/bigint/bigint-internal.h"
#include "src/bigint/digit-arithmetic.h"
#include "src/bigint/util.h"
#include "src/bigint/vector-arithmetic.h"

namespace v8 {
namespace bigint {

// Z := X & Y, computed digit by digit on the two magnitudes.
// Only positions present in both X and Y are ANDed together; every other
// digit of Z is cleared. Z must have room for at least min(|X|, |Y|) digits.
void BitwiseAnd_PosPos(RWDigits Z, Digits X, Digits Y) {
  const int common = std::min(X.len(), Y.len());
  DCHECK(Z.len() >= common);
  int pos = 0;
  while (pos < common) {
    Z[pos] = X[pos] & Y[pos];
    pos++;
  }
  // Anything above the shorter operand ANDs with implicit zero digits.
  while (pos < Z.len()) {
    Z[pos] = 0;
    pos++;
  }
}

// Writes to Z the magnitude of (-X) & (-Y), with X and Y given as magnitudes:
//   (-x) & (-y) == ~(x-1) & ~(y-1)
//               == ~((x-1) | (y-1))
//               == -(((x-1) | (y-1)) + 1)
// i.e. Z := ((X-1) | (Y-1)) + 1.
void BitwiseAnd_NegNeg(RWDigits Z, Digits X, Digits Y) {
  const int common = std::min(X.len(), Y.len());
  // Both borrows start at 1 so that the digit_sub chains produce X-1 and Y-1
  // (see the identity above).
  digit_t borrow_x = 1;
  digit_t borrow_y = 1;
  int pos = 0;
  while (pos < common) {
    const digit_t x_minus_one = digit_sub(X[pos], borrow_x, &borrow_x);
    const digit_t y_minus_one = digit_sub(Y[pos], borrow_y, &borrow_y);
    Z[pos] = x_minus_one | y_minus_one;
    pos++;
  }
  // At most one operand still has digits left, so at most one of the
  // following two loops does any work.
  while (pos < X.len()) {
    Z[pos] = digit_sub(X[pos], borrow_x, &borrow_x);
    pos++;
  }
  while (pos < Y.len()) {
    Z[pos] = digit_sub(Y[pos], borrow_y, &borrow_y);
    pos++;
  }
  DCHECK(borrow_x == 0);
  DCHECK(borrow_y == 0);
  // Clear the unused high digits of Z before applying the final "+ 1".
  while (pos < Z.len()) {
    Z[pos] = 0;
    pos++;
  }
  Add(Z, 1);
}

// Z := X & (-Y), with X and Y given as magnitudes:
//   x & (-y) == x & ~(y-1)
void BitwiseAnd_PosNeg(RWDigits Z, Digits X, Digits Y) {
  const int common = std::min(X.len(), Y.len());
  // Starting the borrow at 1 makes the digit_sub chain yield Y-1.
  digit_t y_borrow = 1;
  int pos = 0;
  while (pos < common) {
    const digit_t y_minus_one = digit_sub(Y[pos], y_borrow, &y_borrow);
    Z[pos] = X[pos] & ~y_minus_one;
    pos++;
  }
  // Digits of X beyond the end of Y are copied through unchanged.
  while (pos < X.len()) {
    Z[pos] = X[pos];
    pos++;
  }
  // Whatever remains of Z is cleared.
  while (pos < Z.len()) {
    Z[pos] = 0;
    pos++;
  }
}

// Z := X | Y, computed digit by digit on the two magnitudes.
// Digits of the longer operand beyond the shorter one are copied as-is, and
// any remaining digits of Z are cleared.
void BitwiseOr_PosPos(RWDigits Z, Digits X, Digits Y) {
  const int common = std::min(X.len(), Y.len());
  int pos = 0;
  while (pos < common) {
    Z[pos] = X[pos] | Y[pos];
    pos++;
  }
  // At most one operand still has digits left, so at most one of the
  // following two loops does any work.
  while (pos < X.len()) {
    Z[pos] = X[pos];
    pos++;
  }
  while (pos < Y.len()) {
    Z[pos] = Y[pos];
    pos++;
  }
  while (pos < Z.len()) {
    Z[pos] = 0;
    pos++;
  }
}

// Writes to Z the magnitude of (-X) | (-Y), with X and Y given as magnitudes:
//   (-x) | (-y) == ~(x-1) | ~(y-1)
//               == ~((x-1) & (y-1))
//               == -(((x-1) & (y-1)) + 1)
// i.e. Z := ((X-1) & (Y-1)) + 1.
void BitwiseOr_NegNeg(RWDigits Z, Digits X, Digits Y) {
  const int common = std::min(X.len(), Y.len());
  // Both borrows start at 1 so that the digit_sub chains produce X-1 and Y-1.
  digit_t borrow_x = 1;
  digit_t borrow_y = 1;
  int pos = 0;
  while (pos < common) {
    const digit_t x_minus_one = digit_sub(X[pos], borrow_x, &borrow_x);
    const digit_t y_minus_one = digit_sub(Y[pos], borrow_y, &borrow_y);
    Z[pos] = x_minus_one & y_minus_one;
    pos++;
  }
  // No need to propagate a pending borrow into the longer operand: those
  // digits are ANDed with the shorter operand's implicit zeros anyway.
  while (pos < Z.len()) {
    Z[pos] = 0;
    pos++;
  }
  Add(Z, 1);
}

// Writes to Z the magnitude of X | (-Y), with X and Y given as magnitudes:
//   x | (-y) == x | ~(y-1) == ~((y-1) &~ x) == -(((y-1) &~ x) + 1)
// i.e. Z := ((Y-1) & ~X) + 1.
void BitwiseOr_PosNeg(RWDigits Z, Digits X, Digits Y) {
  const int common = std::min(X.len(), Y.len());
  // Starting the borrow at 1 makes the digit_sub chain yield Y-1.
  digit_t y_borrow = 1;
  int pos = 0;
  while (pos < common) {
    const digit_t y_minus_one = digit_sub(Y[pos], y_borrow, &y_borrow);
    Z[pos] = y_minus_one & ~X[pos];
    pos++;
  }
  // Beyond the end of X, ~X is all ones, so the digits of Y-1 pass through.
  while (pos < Y.len()) {
    Z[pos] = digit_sub(Y[pos], y_borrow, &y_borrow);
    pos++;
  }
  DCHECK(y_borrow == 0);
  while (pos < Z.len()) {
    Z[pos] = 0;
    pos++;
  }
  Add(Z, 1);
}

// Z := X ^ Y, computed digit by digit on the two magnitudes.
// The operands are reordered locally so that X is the shorter one; the extra
// digits of the longer operand are copied, and the rest of Z is cleared.
void BitwiseXor_PosPos(RWDigits Z, Digits X, Digits Y) {
  if (X.len() > Y.len()) std::swap(X, Y);
  const int shorter = X.len();
  DCHECK(shorter <= Y.len());
  int pos = 0;
  while (pos < shorter) {
    Z[pos] = X[pos] ^ Y[pos];
    pos++;
  }
  // XOR with the shorter operand's implicit zero digits is the identity.
  while (pos < Y.len()) {
    Z[pos] = Y[pos];
    pos++;
  }
  while (pos < Z.len()) {
    Z[pos] = 0;
    pos++;
  }
}

// Z := (-X) ^ (-Y), with X and Y given as magnitudes:
//   (-x) ^ (-y) == ~(x-1) ^ ~(y-1) == (x-1) ^ (y-1)
// The two complements cancel, so no final "+ 1" is needed here.
void BitwiseXor_NegNeg(RWDigits Z, Digits X, Digits Y) {
  const int common = std::min(X.len(), Y.len());
  // Both borrows start at 1 so that the digit_sub chains produce X-1 and Y-1.
  digit_t borrow_x = 1;
  digit_t borrow_y = 1;
  int pos = 0;
  while (pos < common) {
    const digit_t x_minus_one = digit_sub(X[pos], borrow_x, &borrow_x);
    const digit_t y_minus_one = digit_sub(Y[pos], borrow_y, &borrow_y);
    Z[pos] = x_minus_one ^ y_minus_one;
    pos++;
  }
  // At most one operand still has digits left, so at most one of the
  // following two loops does any work.
  while (pos < X.len()) {
    Z[pos] = digit_sub(X[pos], borrow_x, &borrow_x);
    pos++;
  }
  while (pos < Y.len()) {
    Z[pos] = digit_sub(Y[pos], borrow_y, &borrow_y);
    pos++;
  }
  DCHECK(borrow_x == 0);
  DCHECK(borrow_y == 0);
  while (pos < Z.len()) {
    Z[pos] = 0;
    pos++;
  }
}

// Writes to Z the magnitude of X ^ (-Y), with X and Y given as magnitudes:
//   x ^ (-y) == x ^ ~(y-1) == ~(x ^ (y-1)) == -((x ^ (y-1)) + 1)
// i.e. Z := (X ^ (Y-1)) + 1.
void BitwiseXor_PosNeg(RWDigits Z, Digits X, Digits Y) {
  const int common = std::min(X.len(), Y.len());
  // Starting the borrow at 1 makes the digit_sub chain yield Y-1.
  digit_t y_borrow = 1;
  int pos = 0;
  while (pos < common) {
    const digit_t y_minus_one = digit_sub(Y[pos], y_borrow, &y_borrow);
    Z[pos] = X[pos] ^ y_minus_one;
    pos++;
  }
  // At most one operand still has digits left, so at most one of the
  // following two loops does any work.
  while (pos < X.len()) {
    Z[pos] = X[pos];
    pos++;
  }
  while (pos < Y.len()) {
    Z[pos] = digit_sub(Y[pos], y_borrow, &y_borrow);
    pos++;
  }
  DCHECK(y_borrow == 0);
  while (pos < Z.len()) {
    Z[pos] = 0;
    pos++;
  }
  Add(Z, 1);
}

namespace {

// Z := (least significant n bits of X).
// Writes exactly DIV_CEIL(n, kDigitBits) digits of Z; higher digits of Z are
// left untouched. X is read up to and including that top digit.
void TruncateToNBits(RWDigits Z, Digits X, int n) {
  const int top = DIV_CEIL(n, kDigitBits) - 1;
  const int partial_bits = n % kDigitBits;
  // Every digit below the top one is taken over verbatim.
  int pos = 0;
  while (pos < top) {
    Z[pos] = X[pos];
    pos++;
  }
  // Only the top digit can carry bits beyond position n; shift them out.
  digit_t top_digit = X[top];
  if (partial_bits != 0) {
    const int excess = kDigitBits - partial_bits;
    top_digit = (top_digit << excess) >> excess;
  }
  Z[top] = top_digit;
}

// Z := 2**n - (least significant n bits of X).
// X may have fewer digits than the n-bit result needs; the missing digits are
// treated as zero.
void TruncateAndSubFromPowerOfTwo(RWDigits Z, Digits X, int n) {
  const int top = DIV_CEIL(n, kDigitBits) - 1;
  const int partial_bits = n % kDigitBits;
  // Below the top digit: subtract X's digits from zero, then keep propagating
  // the borrow through positions where X has no digit.
  const int x_below_top = std::min(top, X.len());
  digit_t borrow = 0;
  int pos = 0;
  while (pos < x_below_top) {
    Z[pos] = digit_sub2(0, X[pos], borrow, &borrow);
    pos++;
  }
  while (pos < top) {
    Z[pos] = digit_sub(0, borrow, &borrow);
    pos++;
  }

  // Top digit of X, or zero if X is too short to have one.
  digit_t x_top = 0;
  if (top < X.len()) x_top = X[top];

  if (partial_bits == 0) {
    // n is a multiple of the digit size: the 2**n minuend lies just above
    // the top digit, so this is a plain subtraction from zero.
    Z[top] = digit_sub2(0, x_top, borrow, &borrow);
    return;
  }

  // Otherwise strip the bits of X at or above position n, then subtract from
  // the explicit 2**n bit inside the top digit.
  const int excess = kDigitBits - partial_bits;
  x_top = (x_top << excess) >> excess;
  const digit_t power_bit = static_cast<digit_t>(1) << partial_bits;
  const digit_t difference = digit_sub2(power_bit, x_top, borrow, &borrow);
  DCHECK(borrow == 0);  // result < 2^n.
  // When nothing was actually subtracted, the materialized 2**n bit is still
  // set in the difference; mask it away again.
  Z[top] = difference & (power_bit - 1);
}

}  // namespace

// Returns the number of digits needed for the asIntN result, or -1 when the
// operation would return X unchanged.
int AsIntNResultLength(Digits X, bool x_negative, int n) {
  constexpr int kUnchanged = -1;
  const int result_digits = DIV_CEIL(n, kDigitBits);
  const int x_digits = X.len();
  // First criterion: digit count. A shorter X already fits; a longer X
  // always needs truncating.
  if (x_digits != result_digits) {
    return x_digits < result_digits ? kUnchanged : result_digits;
  }
  // Same digit count: compare the top digit against the (n-1)th bit.
  const digit_t msd = X[result_digits - 1];
  const digit_t sign_bit = digit_t{1} << ((n - 1) % kDigitBits);
  if (msd != sign_bit) {
    return msd < sign_bit ? kUnchanged : result_digits;
  }
  // The top digit is exactly the (n-1)th bit. The only input of that shape
  // that stays unchanged is X == -2**(n-1): negative, with all lower digits
  // zero.
  if (!x_negative) return result_digits;
  for (int below = result_digits - 1; below > 0; below--) {
    if (X[below - 1] != 0) return result_digits;
  }
  return kUnchanged;
}

// Writes the magnitude of asIntN(n, X) to Z and returns whether the result is
// negative. X is a magnitude whose sign is given by x_negative. The caller
// must already have ruled out the "unchanged" case (see the DCHECKs).
bool AsIntN(RWDigits Z, Digits X, bool x_negative, int n) {
  DCHECK(X.len() > 0);
  DCHECK(n > 0);
  DCHECK(AsIntNResultLength(X, x_negative, n) > 0);
  const int msd_index = DIV_CEIL(n, kDigitBits) - 1;
  const digit_t msd = X[msd_index];
  const digit_t sign_bit = digit_t{1} << ((n - 1) % kDigitBits);
  // The textbook approach is: convert to two's complement, truncate, convert
  // back to sign+magnitude. Instead of converting, predict the outcome from
  // the (n-1)th bit of the magnitude:
  //  - bit clear: truncate the magnitude and keep the sign.
  //  - bit set:   the magnitude becomes 2**n - (truncated magnitude), which
  //               mimics two's complement, and the sign flips -- except for
  //               a negative input whose result is the minimum n-bit
  //               integer, e.g. asIntN(3, -12) => -4.
  if ((msd & sign_bit) == 0) {
    TruncateToNBits(Z, X, n);
    return x_negative;
  }
  TruncateAndSubFromPowerOfTwo(Z, X, n);
  if (!x_negative) return true;  // Positive input flipped to negative.
  // Negative input: the result stays negative only in the special case above,
  // i.e. when every bit below the (n-1)th is zero.
  bool low_bits_clear = (msd & (sign_bit - 1)) == 0;
  for (int idx = msd_index - 1; low_bits_clear && idx >= 0; idx--) {
    low_bits_clear = (X[idx] == 0);
  }
  return low_bits_clear;
}

// Returns the number of digits needed for the asUintN result of a
// non-negative X, or -1 when the operation would return X unchanged.
int AsUintN_Pos_ResultLength(Digits X, int n) {
  constexpr int kUnchanged = -1;
  const int result_digits = DIV_CEIL(n, kDigitBits);
  const int x_digits = X.len();
  // A shorter X already fits in n bits; a longer X always needs truncating.
  if (x_digits != result_digits) {
    return x_digits < result_digits ? kUnchanged : result_digits;
  }
  // Same digit count: X changes only if its top digit has bits at or above
  // position n. When n is a whole number of digits there are no such bits.
  const int top_digit_bits = n % kDigitBits;
  if (top_digit_bits == 0) return kUnchanged;
  const digit_t msd = X[result_digits - 1];
  const bool has_excess_bits = (msd >> top_digit_bits) != 0;
  return has_excess_bits ? result_digits : kUnchanged;
}

// Z := (least significant n bits of X), for the asUintN case handled by
// AsUintN_Pos_ResultLength. The DCHECK requires that function to report a
// positive result length for (X, n), i.e. that truncation really changes X.
void AsUintN_Pos(RWDigits Z, Digits X, int n) {
  DCHECK(AsUintN_Pos_ResultLength(X, n) > 0);
  TruncateToNBits(Z, X, n);
}

// Z := 2**n - (least significant n bits of X).
// Thin wrapper that forwards to TruncateAndSubFromPowerOfTwo; X is the
// magnitude of the (negative) input.
void AsUintN_Neg(RWDigits Z, Digits X, int n) {
  TruncateAndSubFromPowerOfTwo(Z, X, n);
}

}  // namespace bigint
}  // namespace v8