Commit 50d725f1 authored by Yuri Iozzelli's avatar Yuri Iozzelli Committed by V8 LUCI CQ

Implementation of the branch hinting proposal for WebAssembly.

See https://github.com/WebAssembly/branch-hinting for a description of
the proposal.

Change-Id: Ib6e980fc20aa750decabdeb9e281f502c9fe84ed
Reviewed-on: https://chromium-review.googlesource.com/c/v8/v8/+/2784696
Commit-Queue: Jakob Kummerow <jkummerow@chromium.org>
Reviewed-by: 's avatarJakob Kummerow <jkummerow@chromium.org>
Cr-Commit-Position: refs/heads/master@{#74569}
parent 2c096b53
......@@ -242,6 +242,7 @@ Yong Wang <ccyongwang@tencent.com>
Youfeng Hao <ajihyf@gmail.com>
Yu Yin <xwafish@gmail.com>
Yusif Khudhur <yusif.khudhur@gmail.com>
Yuri Iozzelli <yuri@leaningtech.com>
Zac Hansen <xaxxon@gmail.com>
Zeynep Cankara <zeynepcankara402@gmail.com>
Zhao Jiazhong <kyslie3100@gmail.com>
......
......@@ -1312,6 +1312,11 @@ Node* WasmGraphBuilder::BranchExpectFalse(Node* cond, Node** true_node,
return gasm_->Branch(cond, true_node, false_node, BranchHint::kFalse);
}
Node* WasmGraphBuilder::BranchExpectTrue(Node* cond, Node** true_node,
Node** false_node) {
return gasm_->Branch(cond, true_node, false_node, BranchHint::kTrue);
}
Node* WasmGraphBuilder::Select(Node *cond, Node* true_node,
Node* false_node, wasm::ValueType type) {
MachineOperatorBuilder* m = mcgraph()->machine();
......@@ -7761,7 +7766,7 @@ bool BuildGraphForWasmFunction(AccountingAllocator* allocator,
source_positions);
wasm::VoidResult graph_construction_result = wasm::BuildTFGraph(
allocator, env->enabled_features, env->module, &builder, detected,
func_body, loop_infos, node_origins);
func_body, loop_infos, node_origins, func_index);
if (graph_construction_result.failed()) {
if (FLAG_trace_wasm_compiler) {
StdoutStream{} << "Compilation failed: "
......
......@@ -281,6 +281,7 @@ class WasmGraphBuilder {
//-----------------------------------------------------------------------
Node* BranchNoHint(Node* cond, Node** true_node, Node** false_node);
Node* BranchExpectFalse(Node* cond, Node** true_node, Node** false_node);
Node* BranchExpectTrue(Node* cond, Node** true_node, Node** false_node);
void TrapIfTrue(wasm::TrapReason reason, Node* cond,
wasm::WasmCodePosition position);
......
// Copyright 2021 the V8 project authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#ifndef V8_WASM_BRANCH_HINT_MAP_H_
#define V8_WASM_BRANCH_HINT_MAP_H_
#include <unordered_map>
#include "src/base/macros.h"
namespace v8 {
namespace internal {
namespace wasm {
enum class WasmBranchHint : uint8_t {
kNoHint = 0,
kUnlikely = 1,
kLikely = 2,
};
class V8_EXPORT_PRIVATE BranchHintMap {
public:
void insert(uint32_t offset, WasmBranchHint hint) {
map_.emplace(offset, hint);
}
WasmBranchHint GetHintFor(uint32_t offset) const {
auto it = map_.find(offset);
if (it == map_.end()) {
return WasmBranchHint::kNoHint;
}
return it->second;
}
private:
std::unordered_map<uint32_t, WasmBranchHint> map_;
};
using BranchHintInfo = std::unordered_map<uint32_t, BranchHintMap>;
} // namespace wasm
} // namespace internal
} // namespace v8
#endif // V8_WASM_BRANCH_HINT_MAP_H_
......@@ -2328,7 +2328,12 @@ class WasmFullDecoder : public WasmDecoder<validate> {
}
}
inline uint32_t pc_relative_offset() const {
return this->pc_offset() - first_instruction_offset;
}
private:
uint32_t first_instruction_offset = 0;
Interface interface_;
// The value stack, stored as individual pointers for maximum performance.
......@@ -3479,6 +3484,7 @@ class WasmFullDecoder : public WasmDecoder<validate> {
CALL_INTERFACE_IF_OK_AND_REACHABLE(StartFunctionBody, c);
}
first_instruction_offset = this->pc_offset();
// Decode the function body.
while (this->pc_ < this->end_) {
// Most operations only grow the stack by at least one element (unary and
......
......@@ -9,6 +9,7 @@
#include "src/handles/handles.h"
#include "src/objects/objects-inl.h"
#include "src/utils/ostreams.h"
#include "src/wasm/branch-hint-map.h"
#include "src/wasm/decoder.h"
#include "src/wasm/function-body-decoder-impl.h"
#include "src/wasm/function-body-decoder.h"
......@@ -108,10 +109,18 @@ class WasmGraphBuildingInterface {
: ControlBase(std::forward<Args>(args)...) {}
};
explicit WasmGraphBuildingInterface(compiler::WasmGraphBuilder* builder)
: builder_(builder) {}
explicit WasmGraphBuildingInterface(compiler::WasmGraphBuilder* builder,
int func_index)
: builder_(builder), func_index_(func_index) {}
void StartFunction(FullDecoder* decoder) {
// Get the branch hints map for this function (if available)
if (decoder->module_) {
auto branch_hints_it = decoder->module_->branch_hints.find(func_index_);
if (branch_hints_it != decoder->module_->branch_hints.end()) {
branch_hints_ = &branch_hints_it->second;
}
}
// The first '+ 1' is needed by TF Start node, the second '+ 1' is for the
// instance parameter.
builder_->Start(static_cast<int>(decoder->sig_->parameter_count() + 1 + 1));
......@@ -243,7 +252,21 @@ class WasmGraphBuildingInterface {
void If(FullDecoder* decoder, const Value& cond, Control* if_block) {
TFNode* if_true = nullptr;
TFNode* if_false = nullptr;
builder_->BranchNoHint(cond.node, &if_true, &if_false);
WasmBranchHint hint = WasmBranchHint::kNoHint;
if (branch_hints_) {
hint = branch_hints_->GetHintFor(decoder->pc_relative_offset());
}
switch (hint) {
case WasmBranchHint::kNoHint:
builder_->BranchNoHint(cond.node, &if_true, &if_false);
break;
case WasmBranchHint::kUnlikely:
builder_->BranchExpectFalse(cond.node, &if_true, &if_false);
break;
case WasmBranchHint::kLikely:
builder_->BranchExpectTrue(cond.node, &if_true, &if_false);
break;
}
SsaEnv* merge_env = ssa_env_;
SsaEnv* false_env = Split(decoder->zone(), ssa_env_);
false_env->control = if_false;
......@@ -476,7 +499,21 @@ class WasmGraphBuildingInterface {
SsaEnv* fenv = ssa_env_;
SsaEnv* tenv = Split(decoder->zone(), fenv);
fenv->SetNotMerged();
builder_->BranchNoHint(cond.node, &tenv->control, &fenv->control);
WasmBranchHint hint = WasmBranchHint::kNoHint;
if (branch_hints_) {
hint = branch_hints_->GetHintFor(decoder->pc_relative_offset());
}
switch (hint) {
case WasmBranchHint::kNoHint:
builder_->BranchNoHint(cond.node, &tenv->control, &fenv->control);
break;
case WasmBranchHint::kUnlikely:
builder_->BranchExpectFalse(cond.node, &tenv->control, &fenv->control);
break;
case WasmBranchHint::kLikely:
builder_->BranchExpectTrue(cond.node, &tenv->control, &fenv->control);
break;
}
builder_->SetControl(fenv->control);
SetEnv(tenv);
BrOrRet(decoder, depth, 1);
......@@ -1067,6 +1104,8 @@ class WasmGraphBuildingInterface {
private:
SsaEnv* ssa_env_ = nullptr;
compiler::WasmGraphBuilder* builder_;
int func_index_;
const BranchHintMap* branch_hints_ = nullptr;
// Tracks loop data for loop unrolling.
std::vector<compiler::WasmLoopInfo> loop_infos_;
......@@ -1498,10 +1537,11 @@ DecodeResult BuildTFGraph(AccountingAllocator* allocator,
compiler::WasmGraphBuilder* builder,
WasmFeatures* detected, const FunctionBody& body,
std::vector<compiler::WasmLoopInfo>* loop_infos,
compiler::NodeOriginTable* node_origins) {
compiler::NodeOriginTable* node_origins,
int func_index) {
Zone zone(allocator, ZONE_NAME);
WasmFullDecoder<Decoder::kFullValidation, WasmGraphBuildingInterface> decoder(
&zone, module, enabled, detected, body, builder);
&zone, module, enabled, detected, body, builder, func_index);
if (node_origins) {
builder->AddBytecodePositionDecorator(node_origins, &decoder);
}
......
......@@ -32,7 +32,7 @@ BuildTFGraph(AccountingAllocator* allocator, const WasmFeatures& enabled,
const WasmModule* module, compiler::WasmGraphBuilder* builder,
WasmFeatures* detected, const FunctionBody& body,
std::vector<compiler::WasmLoopInfo>* loop_infos,
compiler::NodeOriginTable* node_origins);
compiler::NodeOriginTable* node_origins, int func_index);
} // namespace wasm
} // namespace internal
......
......@@ -34,6 +34,7 @@ namespace {
constexpr char kNameString[] = "name";
constexpr char kSourceMappingURLString[] = "sourceMappingURL";
constexpr char kCompilationHintsString[] = "compilationHints";
constexpr char kBranchHintsString[] = "branchHints";
constexpr char kDebugInfoString[] = ".debug_info";
constexpr char kExternalDebugInfoString[] = "external_debug_info";
......@@ -95,6 +96,8 @@ const char* SectionName(SectionCode code) {
return kExternalDebugInfoString;
case kCompilationHintsSectionCode:
return kCompilationHintsString;
case kBranchHintsSectionCode:
return kBranchHintsString;
default:
return "<unknown>";
}
......@@ -144,6 +147,7 @@ SectionCode IdentifyUnknownSectionInternal(Decoder* decoder) {
{StaticCharVector(kNameString), kNameSectionCode},
{StaticCharVector(kSourceMappingURLString), kSourceMappingURLSectionCode},
{StaticCharVector(kCompilationHintsString), kCompilationHintsSectionCode},
{StaticCharVector(kBranchHintsString), kBranchHintsSectionCode},
{StaticCharVector(kDebugInfoString), kDebugInfoSectionCode},
{StaticCharVector(kExternalDebugInfoString),
kExternalDebugInfoSectionCode}};
......@@ -432,6 +436,13 @@ class ModuleDecoderImpl : public Decoder {
// first occurrence after function section and before code section are
// ignored.
break;
case kBranchHintsSectionCode:
// TODO(yuri): report out of place branch hints section as a
// warning.
// Be lenient with placement of compilation hints section. All except
// first occurrence after function section and before code section are
// ignored.
break;
default:
next_ordered_section_ = section_code + 1;
break;
......@@ -498,6 +509,15 @@ class ModuleDecoderImpl : public Decoder {
consume_bytes(static_cast<uint32_t>(end_ - start_), nullptr);
}
break;
case kBranchHintsSectionCode:
if (enabled_features_.has_branch_hinting()) {
DecodeBranchHintsSection();
} else {
// Ignore this section when feature was disabled. It is an optional
// custom section anyways.
consume_bytes(static_cast<uint32_t>(end_ - start_), nullptr);
}
break;
case kDataCountSectionCode:
DecodeDataCountSection();
break;
......@@ -1149,6 +1169,82 @@ class ModuleDecoderImpl : public Decoder {
// consume_bytes(static_cast<uint32_t>(end_ - start_), nullptr);
}
void DecodeBranchHintsSection() {
TRACE("DecodeBranchHints module+%d\n", static_cast<int>(pc_ - start_));
if (!has_seen_unordered_section(kBranchHintsSectionCode)) {
set_seen_unordered_section(kBranchHintsSectionCode);
// Use an inner decoder so that errors don't fail the outer decoder.
Decoder inner(start_, pc_, end_, buffer_offset_);
BranchHintInfo branch_hints;
uint32_t func_count = inner.consume_u32v("number of functions");
// Keep track of the previous function index to validate the ordering
int64_t last_func_idx = -1;
for (uint32_t i = 0; i < func_count; i++) {
uint32_t func_idx = inner.consume_u32v("function index");
if (int64_t(func_idx) <= last_func_idx) {
inner.errorf("Invalid function index: %d", func_idx);
break;
}
last_func_idx = func_idx;
uint8_t reserved = inner.consume_u8("reserved byte");
if (reserved != 0x0) {
inner.errorf("Invalid reserved byte: %#x", reserved);
break;
}
uint32_t num_hints = inner.consume_u32v("number of hints");
BranchHintMap func_branch_hints;
TRACE("DecodeBranchHints[%d] module+%d\n", func_idx,
static_cast<int>(inner.pc() - inner.start()));
// Keep track of the previous branch offset to validate the ordering
int64_t last_br_off = -1;
for (uint32_t j = 0; j < num_hints; ++j) {
uint32_t br_dir = inner.consume_u32v("branch direction");
uint32_t br_off = inner.consume_u32v("branch instruction offset");
if (int64_t(br_off) <= last_br_off) {
inner.errorf("Invalid branch offset: %d", br_off);
break;
}
last_br_off = br_off;
TRACE("DecodeBranchHints[%d][%d] module+%d\n", func_idx, br_off,
static_cast<int>(inner.pc() - inner.start()));
WasmBranchHint hint;
switch (br_dir) {
case 0:
hint = WasmBranchHint::kUnlikely;
break;
case 1:
hint = WasmBranchHint::kLikely;
break;
default:
hint = WasmBranchHint::kNoHint;
inner.errorf(inner.pc(), "Invalid branch hint %#x", br_dir);
break;
}
if (!inner.ok()) {
break;
}
func_branch_hints.insert(br_off, hint);
}
if (!inner.ok()) {
break;
}
branch_hints.emplace(func_idx, std::move(func_branch_hints));
}
// Extra unexpected bytes are an error.
if (inner.more()) {
inner.errorf("Unexpected extra bytes: %d\n",
static_cast<int>(inner.pc() - inner.start()));
}
// If everything went well, accept the hints for the module.
if (inner.ok()) {
module_->branch_hints = std::move(branch_hints);
}
}
// Skip the whole branch hints section in the outer decoder.
consume_bytes(static_cast<uint32_t>(end_ - start_), nullptr);
}
void DecodeDataCountSection() {
module_->num_declared_data_segments =
consume_count("data segments count", kV8MaxWasmDataSegments);
......
......@@ -101,10 +101,11 @@ enum SectionCode : int8_t {
kDebugInfoSectionCode, // DWARF section .debug_info
kExternalDebugInfoSectionCode, // Section encoding the external symbol path
kCompilationHintsSectionCode, // Compilation hints section
kBranchHintsSectionCode, // Branch hints section
// Helper values
kFirstSectionInModule = kTypeSectionCode,
kLastKnownModuleSection = kCompilationHintsSectionCode,
kLastKnownModuleSection = kBranchHintsSectionCode,
kFirstUnorderedSection = kDataCountSectionCode,
};
......
......@@ -37,7 +37,12 @@
/* Relaxed SIMD proposal. */ \
/* https://github.com/WebAssembly/relaxed-simd */ \
/* V8 side owner: zhin */ \
V(relaxed_simd, "relaxed simd", false)
V(relaxed_simd, "relaxed simd", false) \
\
/* Branch Hinting proposal. */ \
/* https://github.com/WebAssembly/branch-hinting */ \
/* V8 side owner: jkummerow */ \
V(branch_hinting, "branch hinting", false)
// #############################################################################
// Staged features (disabled by default, but enabled via --wasm-staging (also
......
......@@ -16,6 +16,7 @@
#include "src/common/globals.h"
#include "src/handles/handles.h"
#include "src/utils/vector.h"
#include "src/wasm/branch-hint-map.h"
#include "src/wasm/signature-map.h"
#include "src/wasm/struct-types.h"
#include "src/wasm/wasm-constants.h"
......@@ -338,6 +339,7 @@ struct V8_EXPORT_PRIVATE WasmModule {
std::vector<WasmException> exceptions;
std::vector<WasmElemSegment> elem_segments;
std::vector<WasmCompilationHint> compilation_hints;
BranchHintInfo branch_hints;
SignatureMap signature_map; // canonicalizing map for signature indexes.
ModuleOrigin origin = kWasmOrigin; // origin of the module
......
......@@ -363,7 +363,7 @@ void TestBuildingGraphWithBuilder(compiler::WasmGraphBuilder* builder,
std::vector<compiler::WasmLoopInfo> loops;
DecodeResult result =
BuildTFGraph(zone->allocator(), WasmFeatures::All(), nullptr, builder,
&unused_detected_features, body, &loops, nullptr);
&unused_detected_features, body, &loops, nullptr, 0);
if (result.failed()) {
#ifdef DEBUG
if (!FLAG_trace_wasm_decoder) {
......@@ -371,7 +371,7 @@ void TestBuildingGraphWithBuilder(compiler::WasmGraphBuilder* builder,
FLAG_trace_wasm_decoder = true;
result =
BuildTFGraph(zone->allocator(), WasmFeatures::All(), nullptr, builder,
&unused_detected_features, body, &loops, nullptr);
&unused_detected_features, body, &loops, nullptr, 0);
}
#endif
......
......@@ -6,6 +6,7 @@
#include "src/handles/handles.h"
#include "src/objects/objects-inl.h"
#include "src/wasm/branch-hint-map.h"
#include "src/wasm/wasm-engine.h"
#include "src/wasm/wasm-features.h"
#include "src/wasm/wasm-limits.h"
......@@ -68,6 +69,11 @@ namespace module_decoder_unittest {
'H', 'i', 'n', 't', 's'), \
ADD_COUNT(__VA_ARGS__))
#define SECTION_BRANCH_HINTS(...) \
SECTION(Unknown, \
ADD_COUNT('b', 'r', 'a', 'n', 'c', 'h', 'H', 'i', 'n', 't', 's'), \
__VA_ARGS__)
#define FAIL_IF_NO_EXPERIMENTAL_EH(data) \
do { \
ModuleResult result = DecodeModule((data), (data) + sizeof((data))); \
......@@ -2098,6 +2104,30 @@ TEST_F(WasmModuleVerifyTest, TieringCompilationHints) {
result.value()->compilation_hints[2].top_tier);
}
TEST_F(WasmModuleVerifyTest, BranchHinting) {
WASM_FEATURE_SCOPE(branch_hinting);
static const byte data[] = {
TYPE_SECTION(1, SIG_ENTRY_v_v), FUNCTION_SECTION(2, 0, 0),
SECTION_BRANCH_HINTS(ENTRY_COUNT(2), 0 /*func_index*/, 0 /*reserved*/,
ENTRY_COUNT(1), 1 /*likely*/, 2 /* if offset*/,
1 /*func_index*/, 0 /*reserved*/, ENTRY_COUNT(1),
0 /*unlikely*/, 4 /* br_if offset*/),
SECTION(Code, ENTRY_COUNT(2),
ADD_COUNT(0, /*no locals*/
WASM_IF(WASM_I32V_1(1), WASM_NOP), WASM_END),
ADD_COUNT(0, /*no locals*/
WASM_BLOCK(WASM_BR_IF(0, WASM_I32V_1(1))), WASM_END))};
ModuleResult result = DecodeModule(data, data + sizeof(data));
EXPECT_OK(result);
EXPECT_EQ(2u, result.value()->branch_hints.size());
EXPECT_EQ(WasmBranchHint::kLikely,
result.value()->branch_hints[0].GetHintFor(2));
EXPECT_EQ(WasmBranchHint::kUnlikely,
result.value()->branch_hints[1].GetHintFor(4));
}
class WasmSignatureDecodeTest : public TestWithZone {
public:
WasmFeatures enabled_features_ = WasmFeatures::None();
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment