Commit b927dc15 authored by Manos Koukoutos's avatar Manos Koukoutos Committed by V8 LUCI CQ

[wasm][turbofan] Store real signature on call nodes for inlining

In each wasm CallDescriptor, we store the signature of the call based on
the real parameters passed to the call. This signature is more precise
than the formal function signature. We use this signature in inlining
to enable more optimizations.

Changes:
- Add wasm_sig_ field to CallDescriptor.
- Construct the real signature in {DoCall} and {DoReturnCall} in
  graph-builder-interface, and pass it to all call-related functions in
  WasmGraphBuilder.
- Update {ReplaceTypeInCallDescriptorWith} to use ValueType over
  MachineType. Construct the updated function signature.
- In wasm-inlining, kill the Call node after inlining.
- Add two tests.

Bug: v8:11510
Change-Id: Ica711b6b4d83945ecb7201be26577eab7db3c060
Reviewed-on: https://chromium-review.googlesource.com/c/v8/v8/+/3270539Reviewed-by: 's avatarJakob Kummerow <jkummerow@chromium.org>
Reviewed-by: 's avatarNico Hartmann <nicohartmann@chromium.org>
Commit-Queue: Manos Koukoutos <manoskouk@chromium.org>
Cr-Commit-Position: refs/heads/main@{#77889}
parent 725f92cb
......@@ -517,6 +517,9 @@ CallDescriptor* Linkage::GetStubCallDescriptor(
CallDescriptor::kCanUseRoots | flags, // flags
descriptor.DebugName(), // debug name
descriptor.GetStackArgumentOrder(), // stack order
#if V8_ENABLE_WEBASSEMBLY
nullptr, // wasm function signature
#endif
allocatable_registers);
}
......
......@@ -255,6 +255,9 @@ class V8_EXPORT_PRIVATE CallDescriptor final
RegList callee_saved_fp_registers, Flags flags,
const char* debug_name = "",
StackArgumentOrder stack_order = StackArgumentOrder::kDefault,
#if V8_ENABLE_WEBASSEMBLY
const wasm::FunctionSig* wasm_sig = nullptr,
#endif
const RegList allocatable_registers = 0,
size_t return_slot_count = 0)
: kind_(kind),
......@@ -269,7 +272,11 @@ class V8_EXPORT_PRIVATE CallDescriptor final
allocatable_registers_(allocatable_registers),
flags_(flags),
stack_order_(stack_order),
debug_name_(debug_name) {}
#if V8_ENABLE_WEBASSEMBLY
wasm_sig_(wasm_sig),
#endif
debug_name_(debug_name) {
}
CallDescriptor(const CallDescriptor&) = delete;
CallDescriptor& operator=(const CallDescriptor&) = delete;
......@@ -292,6 +299,9 @@ class V8_EXPORT_PRIVATE CallDescriptor final
// Returns {true} if this descriptor is a call to a Wasm C API function.
bool IsWasmCapiFunction() const { return kind_ == kCallWasmCapiFunction; }
// Returns the wasm signature for this call based on the real parameter types.
const wasm::FunctionSig* wasm_sig() const { return wasm_sig_; }
#endif // V8_ENABLE_WEBASSEMBLY
bool RequiresFrameAsIncoming() const {
......@@ -453,6 +463,9 @@ class V8_EXPORT_PRIVATE CallDescriptor final
const RegList allocatable_registers_;
const Flags flags_;
const StackArgumentOrder stack_order_;
#if V8_ENABLE_WEBASSEMBLY
const wasm::FunctionSig* wasm_sig_;
#endif
const char* const debug_name_;
mutable base::Optional<size_t> gp_param_count_;
......
This diff is collapsed.
......@@ -333,25 +333,32 @@ class WasmGraphBuilder {
void Trap(wasm::TrapReason reason, wasm::WasmCodePosition position);
Node* CallDirect(uint32_t index, base::Vector<Node*> args,
base::Vector<Node*> rets, wasm::WasmCodePosition position);
// In all six call-related public functions, we pass a signature based on the
// real arguments for this call. This signature gets stored in the Call node
// and will later help us generate better code if this call gets inlined.
Node* CallDirect(uint32_t index, wasm::FunctionSig* real_sig,
base::Vector<Node*> args, base::Vector<Node*> rets,
wasm::WasmCodePosition position);
Node* CallIndirect(uint32_t table_index, uint32_t sig_index,
base::Vector<Node*> args, base::Vector<Node*> rets,
wasm::WasmCodePosition position);
Node* CallRef(const wasm::FunctionSig* sig, base::Vector<Node*> args,
wasm::FunctionSig* real_sig, base::Vector<Node*> args,
base::Vector<Node*> rets, wasm::WasmCodePosition position);
Node* CallRef(const wasm::FunctionSig* real_sig, base::Vector<Node*> args,
base::Vector<Node*> rets, CheckForNull null_check,
wasm::WasmCodePosition position);
void CompareToExternalFunctionAtIndex(Node* func_ref, uint32_t function_index,
Node** success_control,
Node** failure_control);
Node* ReturnCall(uint32_t index, base::Vector<Node*> args,
wasm::WasmCodePosition position);
Node* ReturnCall(uint32_t index, const wasm::FunctionSig* real_sig,
base::Vector<Node*> args, wasm::WasmCodePosition position);
Node* ReturnCallIndirect(uint32_t table_index, uint32_t sig_index,
wasm::FunctionSig* real_sig,
base::Vector<Node*> args,
wasm::WasmCodePosition position);
Node* ReturnCallRef(const wasm::FunctionSig* sig, base::Vector<Node*> args,
CheckForNull null_check, wasm::WasmCodePosition position);
Node* ReturnCallRef(const wasm::FunctionSig* real_sig,
base::Vector<Node*> args, CheckForNull null_check,
wasm::WasmCodePosition position);
void CompareToExternalFunctionAtIndex(Node* func_ref, uint32_t function_index,
Node** success_control,
Node** failure_control);
void BrOnNull(Node* ref_object, Node** non_null_node, Node** null_node);
......@@ -530,6 +537,7 @@ class WasmGraphBuilder {
MachineGraph* mcgraph() { return mcgraph_; }
Graph* graph();
Zone* graph_zone();
void AddBytecodePositionDecorator(NodeOriginTable* node_origins,
wasm::Decoder* decoder);
......@@ -588,7 +596,8 @@ class WasmGraphBuilder {
Node** ift_sig_ids, Node** ift_targets,
Node** ift_instances);
Node* BuildIndirectCall(uint32_t table_index, uint32_t sig_index,
base::Vector<Node*> args, base::Vector<Node*> rets,
wasm::FunctionSig* real_sig, base::Vector<Node*> args,
base::Vector<Node*> rets,
wasm::WasmCodePosition position,
IsReturnCall continuation);
Node* BuildWasmCall(const wasm::FunctionSig* sig, base::Vector<Node*> args,
......@@ -606,9 +615,9 @@ class WasmGraphBuilder {
base::Vector<Node*> rets,
wasm::WasmCodePosition position, Node* func_index,
IsReturnCall continuation);
Node* BuildCallRef(const wasm::FunctionSig* sig, base::Vector<Node*> args,
base::Vector<Node*> rets, CheckForNull null_check,
IsReturnCall continuation,
Node* BuildCallRef(const wasm::FunctionSig* real_sig,
base::Vector<Node*> args, base::Vector<Node*> rets,
CheckForNull null_check, IsReturnCall continuation,
wasm::WasmCodePosition position);
Node* BuildF32CopySign(Node* left, Node* right);
......@@ -811,9 +820,6 @@ V8_EXPORT_PRIVATE CallDescriptor* GetWasmCallDescriptor(
V8_EXPORT_PRIVATE CallDescriptor* GetI32WasmCallDescriptor(
Zone* zone, const CallDescriptor* call_descriptor);
V8_EXPORT_PRIVATE CallDescriptor* GetI32WasmCallDescriptorForSimd(
Zone* zone, CallDescriptor* call_descriptor);
AssemblerOptions WasmAssemblerOptions();
AssemblerOptions WasmStubAssemblerOptions();
......
......@@ -12,6 +12,7 @@
#include "src/wasm/graph-builder-interface.h"
#include "src/wasm/wasm-features.h"
#include "src/wasm/wasm-module.h"
#include "src/wasm/wasm-subtyping.h"
namespace v8 {
namespace internal {
......@@ -122,7 +123,26 @@ void WasmInliner::Finalize() {
&module()->functions[candidate.inlinee_index];
base::Vector<const byte> function_bytes =
wire_bytes_->GetCode(inlinee->code);
const wasm::FunctionBody inlinee_body(inlinee->sig, inlinee->code.offset(),
// We use the signature based on the real argument types stored in the call
// node. This is more specific than the callee's formal signature and might
// enable some optimizations.
const wasm::FunctionSig* real_sig =
CallDescriptorOf(call->op())->wasm_sig();
// DCHECK that the real signature is a subtype of the formal one.
DCHECK_EQ(real_sig->parameter_count(), inlinee->sig->parameter_count());
DCHECK_EQ(real_sig->return_count(), inlinee->sig->return_count());
for (size_t i = 0; i < real_sig->parameter_count(); i++) {
DCHECK(wasm::IsSubtypeOf(real_sig->GetParam(i), inlinee->sig->GetParam(i),
module()));
}
for (size_t i = 0; i < real_sig->return_count(); i++) {
DCHECK(wasm::IsSubtypeOf(inlinee->sig->GetReturn(i),
real_sig->GetReturn(i), module()));
}
// End DCHECK.
const wasm::FunctionBody inlinee_body(real_sig, inlinee->code.offset(),
function_bytes.begin(),
function_bytes.end());
wasm::WasmFeatures detected;
......@@ -168,6 +188,7 @@ void WasmInliner::Finalize() {
} else {
InlineTailCall(call, inlinee_start, inlinee_end);
}
call->Kill();
// Returning after only one inlining has been tried and found worse.
}
}
......
......@@ -1634,6 +1634,18 @@ class WasmGraphBuildingInterface {
const Value args[], Value returns[]) {
size_t param_count = sig->parameter_count();
size_t return_count = sig->return_count();
// Construct a function signature based on the real function parameters.
FunctionSig::Builder real_sig_builder(builder_->graph_zone(), return_count,
param_count);
for (size_t i = 0; i < param_count; i++) {
real_sig_builder.AddParam(args[i].type);
}
for (size_t i = 0; i < return_count; i++) {
real_sig_builder.AddReturn(sig->GetReturn(i));
}
FunctionSig* real_sig = real_sig_builder.Build();
NodeVector arg_nodes(param_count + 1);
base::SmallVector<TFNode*, 1> return_nodes(return_count);
arg_nodes[0] = (call_info.call_mode() == CallInfo::kCallDirect)
......@@ -1648,19 +1660,20 @@ class WasmGraphBuildingInterface {
CheckForException(
decoder, builder_->CallIndirect(
call_info.table_index(), call_info.sig_index(),
base::VectorOf(arg_nodes),
real_sig, base::VectorOf(arg_nodes),
base::VectorOf(return_nodes), decoder->position()));
break;
case CallInfo::kCallDirect:
CheckForException(
decoder, builder_->CallDirect(
call_info.callee_index(), base::VectorOf(arg_nodes),
base::VectorOf(return_nodes), decoder->position()));
decoder, builder_->CallDirect(call_info.callee_index(), real_sig,
base::VectorOf(arg_nodes),
base::VectorOf(return_nodes),
decoder->position()));
break;
case CallInfo::kCallRef:
CheckForException(
decoder,
builder_->CallRef(sig, base::VectorOf(arg_nodes),
builder_->CallRef(real_sig, base::VectorOf(arg_nodes),
base::VectorOf(return_nodes),
call_info.null_check(), decoder->position()));
break;
......@@ -1677,6 +1690,17 @@ class WasmGraphBuildingInterface {
const FunctionSig* sig, const Value args[]) {
size_t arg_count = sig->parameter_count();
// Construct a function signature based on the real function parameters.
FunctionSig::Builder real_sig_builder(builder_->graph_zone(),
sig->return_count(), arg_count);
for (size_t i = 0; i < arg_count; i++) {
real_sig_builder.AddParam(args[i].type);
}
for (size_t i = 0; i < sig->return_count(); i++) {
real_sig_builder.AddReturn(sig->GetReturn(i));
}
FunctionSig* real_sig = real_sig_builder.Build();
ValueVector arg_values(arg_count + 1);
if (call_info.call_mode() == CallInfo::kCallDirect) {
arg_values[0].node = nullptr;
......@@ -1699,22 +1723,23 @@ class WasmGraphBuildingInterface {
switch (call_info.call_mode()) {
case CallInfo::kCallIndirect:
CheckForException(decoder,
builder_->ReturnCallIndirect(
call_info.table_index(), call_info.sig_index(),
base::VectorOf(arg_nodes), decoder->position()));
CheckForException(
decoder,
builder_->ReturnCallIndirect(
call_info.table_index(), call_info.sig_index(), real_sig,
base::VectorOf(arg_nodes), decoder->position()));
break;
case CallInfo::kCallDirect:
CheckForException(decoder,
builder_->ReturnCall(call_info.callee_index(),
base::VectorOf(arg_nodes),
decoder->position()));
CheckForException(
decoder, builder_->ReturnCall(call_info.callee_index(), real_sig,
base::VectorOf(arg_nodes),
decoder->position()));
break;
case CallInfo::kCallRef:
CheckForException(
decoder, builder_->ReturnCallRef(sig, base::VectorOf(arg_nodes),
call_info.null_check(),
decoder->position()));
CheckForException(decoder,
builder_->ReturnCallRef(
real_sig, base::VectorOf(arg_nodes),
call_info.null_check(), decoder->position()));
break;
}
}
......
......@@ -3,6 +3,7 @@
// found in the LICENSE file.
// Flags: --wasm-inlining --no-liftoff --experimental-wasm-return-call
// Flags: --experimental-wasm-gc
d8.file.execute("test/mjsunit/wasm/wasm-module-builder.js");
......@@ -276,3 +277,45 @@ d8.file.execute("test/mjsunit/wasm/wasm-module-builder.js");
let instance = builder.instantiate();
assertEquals(25, instance.exports.main(10));
})();
(function InlineSubtypeSignatureTest() {
print(arguments.callee.name);
let builder = new WasmModuleBuilder();
let struct = builder.addStruct([makeField(kWasmI32, true)]);
let callee = builder
.addFunction("callee", makeSig([wasmOptRefType(struct)], [kWasmI32]))
.addBody([kExprLocalGet, 0, kGCPrefix, kExprStructGet, struct, 0]);
// When inlining "callee", TF should pass the real parameter type (ref 0) and
// thus eliminate the null check for struct.get.
builder.addFunction("main", makeSig([wasmRefType(struct)], [kWasmI32]))
.addBody([kExprLocalGet, 0, kExprCallFunction, callee.index])
.exportFunc();
builder.instantiate({});
})();
(function InliningAndEscapeAnalysisTest() {
print(arguments.callee.name);
let builder = new WasmModuleBuilder();
let struct = builder.addStruct([makeField(kWasmI32, true)]);
let callee = builder
.addFunction("callee", makeSig([wasmOptRefType(struct)], [kWasmI32]))
.addBody([kExprLocalGet, 0, kGCPrefix, kExprStructGet, struct, 0]);
// The allocation should be removed.
builder.addFunction("main", kSig_i_i)
.addBody([
kExprLocalGet, 0, kExprI32Const, 1, kExprI32Add,
kGCPrefix, kExprRttCanon, struct,
kGCPrefix, kExprStructNewWithRtt, struct,
kExprCallFunction, callee.index])
.exportFunc();
let instance = builder.instantiate({});
assertEquals(11, instance.exports.main(10));
})();
......@@ -47,6 +47,9 @@ class LinkageTailCall : public TestWithZone {
0, // callee-saved fp
CallDescriptor::kNoFlags, // flags,
"", StackArgumentOrder::kDefault,
#if V8_ENABLE_WEBASSEMBLY
nullptr, // wasm function sig
#endif
0, // allocatable_registers
stack_returns);
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment