Commit 24fa74e7 authored by Manos Koukoutos's avatar Manos Koukoutos Committed by Commit Bot

[wasm-gc] Extend type checks to allow picking branch hints

Bug: v8:7748
Change-Id: I32c87d4e3b98ab44699c1b7bf952aedef3e27002
Reviewed-on: https://chromium-review.googlesource.com/c/v8/v8/+/2704667
Commit-Queue: Manos Koukoutos <manoskouk@chromium.org>
Reviewed-by: 's avatarJakob Kummerow <jkummerow@chromium.org>
Reviewed-by: 's avatarJakob Gruber <jgruber@chromium.org>
Cr-Commit-Position: refs/heads/master@{#72890}
parent 8b15e84b
......@@ -370,6 +370,22 @@ class V8_EXPORT_PRIVATE GraphAssembler {
BranchHint hint, Vars...);
// Control helpers.
// {GotoIf(c, l, h)} is equivalent to {BranchWithHint(c, l, templ, h);
// Bind(templ)}.
template <typename... Vars>
void GotoIf(Node* condition, GraphAssemblerLabel<sizeof...(Vars)>* label,
BranchHint hint, Vars...);
// {GotoIfNot(c, l, h)} is equivalent to {BranchWithHint(c, templ, l, h);
// Bind(templ)}.
// The branch hint refers to the expected outcome of the provided condition,
// so {GotoIfNot(..., BranchHint::kTrue)} means "optimize for the case where
// the branch is *not* taken".
template <typename... Vars>
void GotoIfNot(Node* condition, GraphAssemblerLabel<sizeof...(Vars)>* label,
BranchHint hint, Vars...);
// {GotoIf(c, l)} is equivalent to {Branch(c, l, templ);Bind(templ)}.
template <typename... Vars>
void GotoIf(Node* condition, GraphAssemblerLabel<sizeof...(Vars)>* label,
......@@ -748,9 +764,7 @@ void GraphAssembler::Goto(GraphAssemblerLabel<sizeof...(Vars)>* label,
template <typename... Vars>
void GraphAssembler::GotoIf(Node* condition,
GraphAssemblerLabel<sizeof...(Vars)>* label,
Vars... vars) {
BranchHint hint =
label->IsDeferred() ? BranchHint::kFalse : BranchHint::kNone;
BranchHint hint, Vars... vars) {
Node* branch = graph()->NewNode(common()->Branch(hint), condition, control());
control_ = graph()->NewNode(common()->IfTrue(), branch);
......@@ -763,8 +777,7 @@ void GraphAssembler::GotoIf(Node* condition,
template <typename... Vars>
void GraphAssembler::GotoIfNot(Node* condition,
GraphAssemblerLabel<sizeof...(Vars)>* label,
Vars... vars) {
BranchHint hint = label->IsDeferred() ? BranchHint::kTrue : BranchHint::kNone;
BranchHint hint, Vars... vars) {
Node* branch = graph()->NewNode(common()->Branch(hint), condition, control());
control_ = graph()->NewNode(common()->IfFalse(), branch);
......@@ -774,6 +787,23 @@ void GraphAssembler::GotoIfNot(Node* condition,
control_ = AddNode(graph()->NewNode(common()->IfTrue(), branch));
}
template <typename... Vars>
void GraphAssembler::GotoIf(Node* condition,
GraphAssemblerLabel<sizeof...(Vars)>* label,
Vars... vars) {
BranchHint hint =
label->IsDeferred() ? BranchHint::kFalse : BranchHint::kNone;
return GotoIf(condition, label, hint, vars...);
}
template <typename... Vars>
void GraphAssembler::GotoIfNot(Node* condition,
GraphAssemblerLabel<sizeof...(Vars)>* label,
Vars... vars) {
BranchHint hint = label->IsDeferred() ? BranchHint::kTrue : BranchHint::kNone;
return GotoIfNot(condition, label, hint, vars...);
}
template <typename... Args>
TNode<Object> GraphAssembler::Call(const CallDescriptor* call_descriptor,
Node* first_arg, Args... args) {
......
......@@ -5754,29 +5754,31 @@ void AssertFalse(MachineGraph* mcgraph, GraphAssembler* gasm, Node* condition) {
WasmGraphBuilder::Callbacks WasmGraphBuilder::TestCallbacks(
GraphAssemblerLabel<1>* label) {
return {// succeed_if
[=](Node* condition) -> void {
gasm_->GotoIf(condition, label, gasm_->Int32Constant(1));
[=](Node* condition, BranchHint hint) -> void {
gasm_->GotoIf(condition, label, hint, gasm_->Int32Constant(1));
},
// fail_if
[=](Node* condition) -> void {
gasm_->GotoIf(condition, label, gasm_->Int32Constant(0));
[=](Node* condition, BranchHint hint) -> void {
gasm_->GotoIf(condition, label, hint, gasm_->Int32Constant(0));
},
// fail_if_not
[=](Node* condition) -> void {
gasm_->GotoIfNot(condition, label, gasm_->Int32Constant(0));
[=](Node* condition, BranchHint hint) -> void {
gasm_->GotoIfNot(condition, label, hint, gasm_->Int32Constant(0));
}};
}
WasmGraphBuilder::Callbacks WasmGraphBuilder::CastCallbacks(
GraphAssemblerLabel<0>* label, wasm::WasmCodePosition position) {
return {// succeed_if
[=](Node* condition) -> void { gasm_->GotoIf(condition, label); },
[=](Node* condition, BranchHint hint) -> void {
gasm_->GotoIf(condition, label, hint);
},
// fail_if
[=](Node* condition) -> void {
[=](Node* condition, BranchHint hint) -> void {
TrapIfTrue(wasm::kTrapIllegalCast, condition, position);
},
// fail_if_not
[=](Node* condition) -> void {
[=](Node* condition, BranchHint hint) -> void {
TrapIfFalse(wasm::kTrapIllegalCast, condition, position);
}};
}
......@@ -5786,30 +5788,27 @@ WasmGraphBuilder::Callbacks WasmGraphBuilder::BranchCallbacks(
SmallNodeVector& match_controls, SmallNodeVector& match_effects) {
return {
// succeed_if
[&](Node* condition) -> void {
Node* branch =
graph()->NewNode(mcgraph()->common()->Branch(BranchHint::kTrue),
condition, control());
[&](Node* condition, BranchHint hint) -> void {
Node* branch = graph()->NewNode(mcgraph()->common()->Branch(hint),
condition, control());
match_controls.emplace_back(
graph()->NewNode(mcgraph()->common()->IfTrue(), branch));
match_effects.emplace_back(effect());
SetControl(graph()->NewNode(mcgraph()->common()->IfFalse(), branch));
},
// fail_if
[&](Node* condition) -> void {
Node* branch =
graph()->NewNode(mcgraph()->common()->Branch(BranchHint::kFalse),
condition, control());
[&](Node* condition, BranchHint hint) -> void {
Node* branch = graph()->NewNode(mcgraph()->common()->Branch(hint),
condition, control());
no_match_controls.emplace_back(
graph()->NewNode(mcgraph()->common()->IfTrue(), branch));
no_match_effects.emplace_back(effect());
SetControl(graph()->NewNode(mcgraph()->common()->IfFalse(), branch));
},
// fail_if_not
[&](Node* condition) -> void {
Node* branch =
graph()->NewNode(mcgraph()->common()->Branch(BranchHint::kTrue),
condition, control());
[&](Node* condition, BranchHint hint) -> void {
Node* branch = graph()->NewNode(mcgraph()->common()->Branch(hint),
condition, control());
no_match_controls.emplace_back(
graph()->NewNode(mcgraph()->common()->IfFalse(), branch));
no_match_effects.emplace_back(effect());
......@@ -5821,8 +5820,8 @@ void WasmGraphBuilder::TypeCheck(
Node* object, Node* rtt, WasmGraphBuilder::ObjectReferenceKnowledge config,
bool null_succeeds, Callbacks callbacks) {
if (config.object_can_be_null) {
(null_succeeds ? callbacks.succeed_if
: callbacks.fail_if)(gasm_->WordEqual(object, RefNull()));
(null_succeeds ? callbacks.succeed_if : callbacks.fail_if)(
gasm_->WordEqual(object, RefNull()), BranchHint::kFalse);
}
Node* map = gasm_->LoadMap(object);
......@@ -5830,13 +5829,13 @@ void WasmGraphBuilder::TypeCheck(
if (config.reference_kind == kFunction) {
// Currently, the only way for a function to match an rtt is if its map
// is equal to that rtt.
callbacks.fail_if_not(gasm_->TaggedEqual(map, rtt));
callbacks.fail_if_not(gasm_->TaggedEqual(map, rtt), BranchHint::kTrue);
return;
}
DCHECK(config.reference_kind == kArrayOrStruct);
callbacks.succeed_if(gasm_->TaggedEqual(map, rtt));
callbacks.succeed_if(gasm_->TaggedEqual(map, rtt), BranchHint::kTrue);
Node* type_info = gasm_->LoadWasmTypeInfo(map);
Node* supertypes = gasm_->LoadSupertypes(type_info);
......@@ -5847,30 +5846,33 @@ void WasmGraphBuilder::TypeCheck(
? gasm_->Int32Constant(config.rtt_depth)
: BuildChangeSmiToInt32(gasm_->LoadFixedArrayLengthAsSmi(
gasm_->LoadSupertypes(gasm_->LoadWasmTypeInfo(rtt))));
callbacks.fail_if_not(gasm_->Uint32LessThan(rtt_depth, supertypes_length));
callbacks.fail_if_not(gasm_->Uint32LessThan(rtt_depth, supertypes_length),
BranchHint::kTrue);
Node* maybe_match = gasm_->LoadFixedArrayElement(
supertypes, rtt_depth, MachineType::TaggedPointer());
callbacks.fail_if_not(gasm_->TaggedEqual(maybe_match, rtt));
callbacks.fail_if_not(gasm_->TaggedEqual(maybe_match, rtt),
BranchHint::kTrue);
}
void WasmGraphBuilder::DataCheck(Node* object, bool object_can_be_null,
Callbacks callbacks) {
if (object_can_be_null) {
callbacks.fail_if(gasm_->WordEqual(object, RefNull()));
callbacks.fail_if(gasm_->WordEqual(object, RefNull()), BranchHint::kFalse);
}
callbacks.fail_if(gasm_->IsI31(object));
callbacks.fail_if(gasm_->IsI31(object), BranchHint::kFalse);
Node* map = gasm_->LoadMap(object);
callbacks.fail_if_not(gasm_->IsDataRefMap(map));
callbacks.fail_if_not(gasm_->IsDataRefMap(map), BranchHint::kTrue);
}
void WasmGraphBuilder::FuncCheck(Node* object, bool object_can_be_null,
Callbacks callbacks) {
if (object_can_be_null) {
callbacks.fail_if(gasm_->WordEqual(object, RefNull()));
callbacks.fail_if(gasm_->WordEqual(object, RefNull()), BranchHint::kFalse);
}
callbacks.fail_if(gasm_->IsI31(object));
callbacks.fail_if_not(gasm_->HasInstanceType(object, JS_FUNCTION_TYPE));
callbacks.fail_if(gasm_->IsI31(object), BranchHint::kFalse);
callbacks.fail_if_not(gasm_->HasInstanceType(object, JS_FUNCTION_TYPE),
BranchHint::kTrue);
}
Node* WasmGraphBuilder::BrOnCastAbs(
......
......@@ -39,6 +39,7 @@ enum class TrapId : uint32_t;
struct Int64LoweringSpecialCase;
template <size_t VarCount>
class GraphAssemblerLabel;
enum class BranchHint : uint8_t;
} // namespace compiler
namespace wasm {
......@@ -621,11 +622,11 @@ class WasmGraphBuilder {
// generates {index > max ? Smi(max) : Smi(index)}
Node* BuildConvertUint32ToSmiWithSaturation(Node* index, uint32_t maxval);
using NodeConsumer = std::function<void(Node*)>;
using BranchBuilder = std::function<void(Node*, BranchHint)>;
struct Callbacks {
NodeConsumer succeed_if;
NodeConsumer fail_if;
NodeConsumer fail_if_not;
BranchBuilder succeed_if;
BranchBuilder fail_if;
BranchBuilder fail_if_not;
};
// This type is used to collect control/effect nodes we need to merge at the
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment