Commit dd5e5535 authored by Jakob Kummerow's avatar Jakob Kummerow Committed by V8 LUCI CQ

[bigint] Faster parsing from long strings

Combining parts in a balanced-binary-tree like order allows us to
use fast multiplication algorithms.

Bug: v8:11515
Change-Id: I6829929671770f009f10f6f3b383501fede476ab
Reviewed-on: https://chromium-review.googlesource.com/c/v8/v8/+/3049079Reviewed-by: 's avatarMaya Lekova <mslekova@chromium.org>
Commit-Queue: Jakob Kummerow <jkummerow@chromium.org>
Cr-Commit-Position: refs/heads/main@{#76404}
parent 45424f1a
...@@ -22,6 +22,7 @@ constexpr int kNewtonInversionThreshold = 50; ...@@ -22,6 +22,7 @@ constexpr int kNewtonInversionThreshold = 50;
// kBarrettThreshold is defined in bigint.h. // kBarrettThreshold is defined in bigint.h.
constexpr int kToStringFastThreshold = 43; constexpr int kToStringFastThreshold = 43;
constexpr int kFromStringLargeThreshold = 300;
class ProcessorImpl : public Processor { class ProcessorImpl : public Processor {
public: public:
...@@ -69,6 +70,7 @@ class ProcessorImpl : public Processor { ...@@ -69,6 +70,7 @@ class ProcessorImpl : public Processor {
void FromString(RWDigits Z, FromStringAccumulator* accumulator); void FromString(RWDigits Z, FromStringAccumulator* accumulator);
void FromStringClassic(RWDigits Z, FromStringAccumulator* accumulator); void FromStringClassic(RWDigits Z, FromStringAccumulator* accumulator);
void FromStringLarge(RWDigits Z, FromStringAccumulator* accumulator);
bool should_terminate() { return status_ == Status::kInterrupted; } bool should_terminate() { return status_ == Status::kInterrupted; }
......
...@@ -262,6 +262,8 @@ class Processor { ...@@ -262,6 +262,8 @@ class Processor {
// upon return will be set to the actual length of the result string. // upon return will be set to the actual length of the result string.
Status ToString(char* out, int* out_length, Digits X, int radix, bool sign); Status ToString(char* out, int* out_length, Digits X, int radix, bool sign);
// Z := the contents of {accumulator}.
// Assume that this leaves {accumulator} in unusable state.
Status FromString(RWDigits Z, FromStringAccumulator* accumulator); Status FromString(RWDigits Z, FromStringAccumulator* accumulator);
}; };
......
...@@ -40,7 +40,6 @@ void ProcessorImpl::FromStringClassic(RWDigits Z, ...@@ -40,7 +40,6 @@ void ProcessorImpl::FromStringClassic(RWDigits Z,
// Parts are stored on the heap. // Parts are stored on the heap.
for (int i = 1; i < num_heap_parts - 1; i++) { for (int i = 1; i < num_heap_parts - 1; i++) {
MultiplySingle(Z, already_set, max_multiplier); MultiplySingle(Z, already_set, max_multiplier);
if (should_terminate()) return;
Add(Z, accumulator->heap_parts_[i]); Add(Z, accumulator->heap_parts_[i]);
already_set.set_len(already_set.len() + 1); already_set.set_len(already_set.len() + 1);
} }
...@@ -48,6 +47,171 @@ void ProcessorImpl::FromStringClassic(RWDigits Z, ...@@ -48,6 +47,171 @@ void ProcessorImpl::FromStringClassic(RWDigits Z,
Add(Z, accumulator->heap_parts_.back()); Add(Z, accumulator->heap_parts_.back());
} }
// The fast algorithm: combine parts in a balanced-binary-tree like order:
// Multiply-and-add neighboring pairs of parts, then loop, until only one
// part is left. The benefit is that the multiplications will have inputs of
// similar sizes, which makes them amenable to fast multiplication algorithms.
// We have to do more multiplications than the classic algorithm though,
// because we also have to multiply the multipliers.
// Optimizations:
// - We can skip the multiplier for the first part, because we never need it.
// - Most multipliers are the same; we can avoid repeated multiplications and
// just copy the previous result. (In theory we could even de-dupe them, but
// as the parts/multipliers grow, we'll need most of the memory anyway.)
// Copied results are marked with a * below.
// - We can re-use memory using a system of three buffers whose usage rotates:
// - one is considered empty, and is overwritten with the new parts,
// - one holds the multipliers (and will be "empty" in the next round), and
// - one initially holds the parts and is overwritten with the new multipliers
// Parts and multipliers both grow in each iteration, and get fewer, so we
// use the space of two adjacent old chunks for one new chunk.
// Since the {heap_parts_} vectors has the right size, and so does the
// result {Z}, we can use that memory, and only need to allocate one scratch
// vector. If the final result ends up in the wrong bucket, we have to copy it
// to the correct one.
// - We don't have to keep track of the positions and sizes of the chunks,
// because we can deduce their precise placement from the iteration index.
//
// Example, assuming digit_t is 4 bits, fitting one decimal digit:
// Initial state:
// parts_: 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5
// multipliers_: 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10
// After the first iteration of the outer loop:
// parts: 12 34 56 78 90 12 34 5
// multipliers: 100 *100 *100 *100 *100 *100 10
// After the second iteration:
// parts: 1234 5678 9012 345
// multipliers: 10000 *10000 1000
// After the third iteration:
// parts: 12345678 9012345
// multipliers: 10000000
// And then there's an obvious last iteration.
void ProcessorImpl::FromStringLarge(RWDigits Z,
FromStringAccumulator* accumulator) {
int num_parts = static_cast<int>(accumulator->heap_parts_.size());
DCHECK(num_parts >= 2); // NOLINT(readability/check)
DCHECK(Z.len() >= num_parts);
RWDigits parts(accumulator->heap_parts_.data(), num_parts);
Storage multipliers_storage(num_parts);
RWDigits multipliers(multipliers_storage.get(), num_parts);
RWDigits temp(Z, 0, num_parts);
// Unrolled and specialized first iteration: part_len == 1, so instead of
// Digits sub-vectors we have individual digit_t values, and the multipliers
// are known up front.
{
digit_t max_multiplier = accumulator->max_multiplier_;
digit_t last_multiplier = accumulator->last_multiplier_;
RWDigits new_parts = temp;
RWDigits new_multipliers = parts;
int i = 0;
for (; i + 1 < num_parts; i += 2) {
digit_t p_in = parts[i];
digit_t p_in2 = parts[i + 1];
digit_t m_in = max_multiplier;
digit_t m_in2 = i == num_parts - 2 ? last_multiplier : max_multiplier;
// p[j] = p[i] * m[i+1] + p[i+1]
digit_t p_high;
digit_t p_low = digit_mul(p_in, m_in2, &p_high);
digit_t carry;
new_parts[i] = digit_add2(p_low, p_in2, &carry);
new_parts[i + 1] = p_high + carry;
// m[j] = m[i] * m[i+1]
if (i > 0) {
if (i > 2 && m_in2 != last_multiplier) {
new_multipliers[i] = new_multipliers[i - 2];
new_multipliers[i + 1] = new_multipliers[i - 1];
} else {
digit_t m_high;
new_multipliers[i] = digit_mul(m_in, m_in2, &m_high);
new_multipliers[i + 1] = m_high;
}
}
}
// Trailing last part (if {num_parts} was odd).
if (i < num_parts) {
new_parts[i] = parts[i];
new_multipliers[i] = last_multiplier;
i += 2;
}
num_parts = i >> 1;
RWDigits new_temp = multipliers;
parts = new_parts;
multipliers = new_multipliers;
temp = new_temp;
AddWorkEstimate(num_parts);
}
int part_len = 2;
// Remaining iterations.
while (num_parts > 1) {
RWDigits new_parts = temp;
RWDigits new_multipliers = parts;
int new_part_len = part_len * 2;
int i = 0;
for (; i + 1 < num_parts; i += 2) {
int start = i * part_len;
Digits p_in(parts, start, part_len);
Digits p_in2(parts, start + part_len, part_len);
Digits m_in(multipliers, start, part_len);
Digits m_in2(multipliers, start + part_len, part_len);
RWDigits p_out(new_parts, start, new_part_len);
RWDigits m_out(new_multipliers, start, new_part_len);
// p[j] = p[i] * m[i+1] + p[i+1]
Multiply(p_out, p_in, m_in2);
if (should_terminate()) return;
digit_t overflow = AddAndReturnOverflow(p_out, p_in2);
DCHECK(overflow == 0); // NOLINT(readability/check)
USE(overflow);
// m[j] = m[i] * m[i+1]
if (i > 0) {
bool copied = false;
if (i > 2) {
int prev_start = (i - 2) * part_len;
Digits m_in_prev(multipliers, prev_start, part_len);
Digits m_in2_prev(multipliers, prev_start + part_len, part_len);
if (Compare(m_in, m_in_prev) == 0 &&
Compare(m_in2, m_in2_prev) == 0) {
copied = true;
Digits m_out_prev(new_multipliers, prev_start, new_part_len);
for (int k = 0; k < new_part_len; k++) m_out[k] = m_out_prev[k];
}
}
if (!copied) {
Multiply(m_out, m_in, m_in2);
if (should_terminate()) return;
}
}
}
// Trailing last part (if {num_parts} was odd).
if (i < num_parts) {
Digits p_in(parts, i * part_len, part_len);
Digits m_in(multipliers, i * part_len, part_len);
RWDigits p_out(new_parts, i * part_len, new_part_len);
RWDigits m_out(new_multipliers, i * part_len, new_part_len);
int k = 0;
for (; k < p_in.len(); k++) p_out[k] = p_in[k];
for (; k < p_out.len(); k++) p_out[k] = 0;
k = 0;
for (; k < m_in.len(); k++) m_out[k] = m_in[k];
for (; k < m_out.len(); k++) m_out[k] = 0;
i += 2;
}
num_parts = i >> 1;
part_len = new_part_len;
RWDigits new_temp = multipliers;
parts = new_parts;
multipliers = new_multipliers;
temp = new_temp;
}
// Copy the result to Z, if it doesn't happen to be there already.
if (parts.digits() != Z.digits()) {
int i = 0;
for (; i < parts.len(); i++) Z[i] = parts[i];
// Z might be bigger than we requested; be robust towards that.
for (; i < Z.len(); i++) Z[i] = 0;
}
}
void ProcessorImpl::FromString(RWDigits Z, FromStringAccumulator* accumulator) { void ProcessorImpl::FromString(RWDigits Z, FromStringAccumulator* accumulator) {
if (accumulator->inline_everything_) { if (accumulator->inline_everything_) {
int i = 0; int i = 0;
...@@ -57,8 +221,10 @@ void ProcessorImpl::FromString(RWDigits Z, FromStringAccumulator* accumulator) { ...@@ -57,8 +221,10 @@ void ProcessorImpl::FromString(RWDigits Z, FromStringAccumulator* accumulator) {
for (; i < Z.len(); i++) Z[i] = 0; for (; i < Z.len(); i++) Z[i] = 0;
} else if (accumulator->stack_parts_used_ == 0) { } else if (accumulator->stack_parts_used_ == 0) {
for (int i = 0; i < Z.len(); i++) Z[i] = 0; for (int i = 0; i < Z.len(); i++) Z[i] = 0;
} else { } else if (accumulator->ResultLength() < kFromStringLargeThreshold) {
FromStringClassic(Z, accumulator); FromStringClassic(Z, accumulator);
} else {
FromStringLarge(Z, accumulator);
} }
} }
......
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
// Use of this source code is governed by a BSD-style license that can be // Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file. // found in the LICENSE file.
#include <cmath>
#include <memory> #include <memory>
#include <string> #include <string>
...@@ -28,12 +29,13 @@ int PrintHelp(char** argv) { ...@@ -28,12 +29,13 @@ int PrintHelp(char** argv) {
return 1; return 1;
} }
#define TESTS(V) \ #define TESTS(V) \
V(kBarrett, "barrett") \ V(kBarrett, "barrett") \
V(kBurnikel, "burnikel") \ V(kBurnikel, "burnikel") \
V(kFFT, "fft") \ V(kFFT, "fft") \
V(kKaratsuba, "karatsuba") \ V(kFromString, "fromstring") \
V(kToom, "toom") \ V(kKaratsuba, "karatsuba") \
V(kToom, "toom") \
V(kToString, "tostring") V(kToString, "tostring")
enum Operation { kNoOp, kList, kTest }; enum Operation { kNoOp, kList, kTest };
...@@ -86,7 +88,7 @@ class RNG { ...@@ -86,7 +88,7 @@ class RNG {
static constexpr int kCharsPerDigit = kDigitBits / 4; static constexpr int kCharsPerDigit = kDigitBits / 4;
static const char kConversionChars[] = "0123456789abcdef"; static const char kConversionChars[] = "0123456789abcdefghijklmnopqrstuvwxyz";
std::string FormatHex(Digits X) { std::string FormatHex(Digits X) {
X.Normalize(); X.Normalize();
...@@ -173,6 +175,16 @@ class Runner { ...@@ -173,6 +175,16 @@ class Runner {
error_ = true; error_ = true;
} }
void AssertEquals(const char* input, int input_length, int radix,
Digits expected, Digits actual) {
if (Compare(expected, actual) == 0) return;
std::cerr << "Input: " << std::string(input, input_length) << "\n";
std::cerr << "Radix: " << radix << "\n";
std::cerr << "Expected: " << FormatHex(expected) << "\n";
std::cerr << "Actual: " << FormatHex(actual) << "\n";
error_ = true;
}
int RunTest() { int RunTest() {
int count = 0; int count = 0;
if (test_ == kBarrett) { if (test_ == kBarrett) {
...@@ -199,6 +211,10 @@ class Runner { ...@@ -199,6 +211,10 @@ class Runner {
for (int i = 0; i < runs_; i++) { for (int i = 0; i < runs_; i++) {
TestToString(&count); TestToString(&count);
} }
} else if (test_ == kFromString) {
for (int i = 0; i < runs_; i++) {
TestFromString(&count);
}
} else { } else {
DCHECK(false); // Unreachable. DCHECK(false); // Unreachable.
} }
...@@ -391,6 +407,33 @@ class Runner { ...@@ -391,6 +407,33 @@ class Runner {
} }
} }
void TestFromString(int* count) {
constexpr int kMaxDigits = 1 << 20; // Any large-enough value will do.
constexpr int kMin = kFromStringLargeThreshold / 2;
constexpr int kMax = kFromStringLargeThreshold * 2;
for (int size = kMin; size < kMax; size++) {
// To keep test execution times low, test one random radix every time.
// Valid range is 2 <= radix <= 36 (inclusive).
int radix = 2 + (rng_.NextUint64() % 35);
int num_chars = std::round(size * kDigitBits / std::log2(radix));
std::unique_ptr<char[]> chars(new char[num_chars]);
GenerateRandomString(chars.get(), num_chars, radix);
FromStringAccumulator accumulator(kMaxDigits);
FromStringAccumulator ref_accumulator(kMaxDigits);
const char* start = chars.get();
const char* end = chars.get() + num_chars;
accumulator.Parse(start, end, radix);
ref_accumulator.Parse(start, end, radix);
ScratchDigits result(accumulator.ResultLength());
ScratchDigits reference(ref_accumulator.ResultLength());
processor()->FromStringLarge(result, &accumulator);
processor()->FromStringClassic(reference, &ref_accumulator);
AssertEquals(start, num_chars, radix, result, reference);
if (error_) return;
(*count)++;
}
}
int ParseOptions(int argc, char** argv) { int ParseOptions(int argc, char** argv) {
for (int i = 1; i < argc; i++) { for (int i = 1; i < argc; i++) {
if (strcmp(argv[i], "--list") == 0) { if (strcmp(argv[i], "--list") == 0) {
...@@ -447,6 +490,30 @@ class Runner { ...@@ -447,6 +490,30 @@ class Runner {
} }
} }
void GenerateRandomString(char* str, int len, int radix) {
DCHECK(2 <= radix && radix <= 36);
if (len == 0) return;
uint64_t random;
int available_bits = 0;
const int char_bits = BitLength(radix - 1);
const uint64_t char_mask = (1u << char_bits) - 1u;
for (int i = 0; i < len; i++) {
while (true) {
if (available_bits < char_bits) {
random = rng_.NextUint64();
available_bits = 64;
}
int next_char = static_cast<int>(random & char_mask);
random = random >> char_bits;
available_bits -= char_bits;
if (next_char >= radix) continue;
*str = kConversionChars[next_char];
str++;
break;
};
}
}
Operation op_{kNoOp}; Operation op_{kNoOp};
Test test_; Test test_;
bool error_{false}; bool error_{false};
......
...@@ -241,7 +241,7 @@ TEST(TerminateBigIntToString) { ...@@ -241,7 +241,7 @@ TEST(TerminateBigIntToString) {
TEST(TerminateBigIntFromString) { TEST(TerminateBigIntFromString) {
TestTerminatingSlowOperation( TestTerminatingSlowOperation(
"var a = '12344567890'.repeat(10000);\n" "var a = '12344567890'.repeat(100000);\n"
"terminate();\n" "terminate();\n"
"BigInt(a);\n" "BigInt(a);\n"
"fail();\n"); "fail();\n");
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment