Commit de33e4ba authored by ritesht's avatar ritesht Committed by Commit bot

[wasm] Adding feature to JIT a wasm function at runtime and hook up the...

[wasm] Adding feature to JIT a wasm function at runtime and hook up the compiled code into the indirect function table

The runtime JIT function is passed in the function table to hook up the compiled code and the starting address of the memory to locate the bytes to be compiled.

BUG=5044

Review-Url: https://codereview.chromium.org/2137993003
Cr-Commit-Position: refs/heads/master@{#37735}
parent b0711ccc
......@@ -361,6 +361,10 @@ Node* WasmGraphBuilder::NumberConstant(int32_t value) {
return jsgraph()->Constant(value);
}
Node* WasmGraphBuilder::Uint32Constant(uint32_t value) {
return jsgraph()->Uint32Constant(value);
}
Node* WasmGraphBuilder::Int32Constant(int32_t value) {
return jsgraph()->Int32Constant(value);
}
......@@ -2094,6 +2098,100 @@ Node* WasmGraphBuilder::CallIndirect(uint32_t index, Node** args,
return BuildWasmCall(sig, args, position);
}
Node* WasmGraphBuilder::JITSingleFunction(Node* const base, Node* const length,
Node* const index,
const uint32_t sig_index,
wasm::FunctionSig* const sig,
wasm::WasmCodePosition position) {
MachineOperatorBuilder* machine = jsgraph()->machine();
// Bounds check the memory access
{
Node* base_negative =
graph()->NewNode(machine->Uint32LessThan(), base, Int32Constant(0));
trap_->AddTrapIfTrue(wasm::kTrapMemOutOfBounds, base_negative, position);
Node* length_negative = graph()->NewNode(machine->Uint32LessThanOrEqual(),
length, Int32Constant(0));
trap_->AddTrapIfTrue(wasm::kTrapFuncInvalid, length_negative, position);
Node* in_bounds = graph()->NewNode(
machine->Uint32LessThanOrEqual(),
graph()->NewNode(machine->Int32Add(), base, length), MemSize(0));
trap_->AddTrapIfFalse(wasm::kTrapMemOutOfBounds, in_bounds, position);
}
// Bounds check the index.
{
int table_size = static_cast<int>(module_->FunctionTableSize());
if (table_size > 0) {
// Bounds check against the table size.
Node* size = Int32Constant(static_cast<int>(table_size));
Node* in_bounds =
graph()->NewNode(machine->Uint32LessThan(), index, size);
trap_->AddTrapIfFalse(wasm::kTrapInvalidIndex, in_bounds, position);
} else {
// No function table. Generate a trap and return a constant.
trap_->AddTrapIfFalse(wasm::kTrapFuncInvalid, Int32Constant(0), position);
return trap_->GetTrapValue(module_->GetSignature(sig_index));
}
}
const size_t runtime_input_params = 7;
const size_t runtime_env_params = 5;
Runtime::FunctionId f = Runtime::kJITSingleFunction;
const Runtime::Function* fun = Runtime::FunctionForId(f);
// CEntryStubConstant nodes have to be created and cached in the main
// thread. At the moment this is only done for CEntryStubConstant(1).
DCHECK_EQ(1, fun->result_size);
const uint32_t return_count = static_cast<uint32_t>(sig->return_count());
const uint32_t parameter_count =
static_cast<uint32_t>(sig->parameter_count());
const uint32_t inputs_size = runtime_input_params + runtime_env_params +
return_count + parameter_count;
Node** inputs = Buffer(inputs_size);
inputs[0] = jsgraph()->CEntryStubConstant(fun->result_size);
inputs[1] = BuildChangeUint32ToSmi(base);
inputs[2] = BuildChangeUint32ToSmi(length);
inputs[3] = BuildChangeUint32ToSmi(index);
inputs[4] = FunctionTable();
inputs[5] = Uint32Constant(sig_index);
inputs[6] = BuildChangeUint32ToSmi(Uint32Constant(return_count));
// Pass in parameters and return types in to the runtime function
// to allow it to regenerate signature
for (uint32_t i = 0; i < return_count; ++i) {
inputs[i + runtime_input_params] = BuildChangeUint32ToSmi(
Uint32Constant(static_cast<int>(sig->GetReturn(i))));
}
for (uint32_t i = 0; i < parameter_count; ++i) {
inputs[i + runtime_input_params + return_count] = BuildChangeUint32ToSmi(
Uint32Constant(static_cast<int>(sig->GetParam(i))));
}
const uint32_t args_offset = inputs_size - runtime_env_params;
inputs[args_offset] = jsgraph()->ExternalConstant(
ExternalReference(f, jsgraph()->isolate())); // ref
inputs[args_offset + 1] = jsgraph()->Int32Constant(args_offset - 1); // arity
inputs[args_offset + 2] =
HeapConstant(module_->instance->context); // context
inputs[args_offset + 3] = *effect_;
inputs[args_offset + 4] = *control_;
// Use the module context to call the runtime.
CallDescriptor* desc = Linkage::GetRuntimeCallDescriptor(
jsgraph()->zone(), f, args_offset - 1, Operator::kNoProperties,
CallDescriptor::kNoFlags);
Node* node =
graph()->NewNode(jsgraph()->common()->Call(desc), inputs_size, inputs);
*control_ = node;
*effect_ = node;
return node;
}
Node* WasmGraphBuilder::BuildI32Rol(Node* left, Node* right) {
// Implement Rol by Ror since TurboFan does not have Rol opcode.
// TODO(weiliang): support Word32Rol opcode in TurboFan.
......
......@@ -118,6 +118,7 @@ class WasmGraphBuilder {
Node* Phi(wasm::LocalType type, unsigned count, Node** vals, Node* control);
Node* EffectPhi(unsigned count, Node** effects, Node* control);
Node* NumberConstant(int32_t value);
Node* Uint32Constant(uint32_t value);
Node* Int32Constant(int32_t value);
Node* Int64Constant(int64_t value);
Node* Float32Constant(float value);
......@@ -149,6 +150,9 @@ class WasmGraphBuilder {
wasm::WasmCodePosition position);
Node* CallIndirect(uint32_t index, Node** args,
wasm::WasmCodePosition position);
Node* JITSingleFunction(Node* base, Node* length, Node* index,
uint32_t sig_index, wasm::FunctionSig* sig,
wasm::WasmCodePosition position);
void BuildJSToWasmWrapper(Handle<Code> wasm_code, wasm::FunctionSig* sig);
void BuildWasmToJSWrapper(Handle<JSFunction> function,
wasm::FunctionSig* sig);
......
......@@ -492,7 +492,8 @@ class CallSite {
T(WasmTrapFloatUnrepresentable, "integer result unrepresentable") \
T(WasmTrapFuncInvalid, "invalid function") \
T(WasmTrapFuncSigMismatch, "function signature mismatch") \
T(WasmTrapMemAllocationFail, "failed to allocate memory")
T(WasmTrapMemAllocationFail, "failed to allocate memory") \
T(WasmTrapInvalidIndex, "invalid index into function table")
class MessageTemplate {
public:
......
......@@ -80,6 +80,14 @@ namespace internal {
int32_t name = 0; \
CHECK(args[index]->ToInt32(&name));
// Assert that the given argument is a number within the Uint32 range
// and convert it to uint32_t. If the argument is not an Uint32 call
// IllegalOperation and return.
#define CONVERT_UINT32_ARG_CHECKED(name, index) \
CHECK(args[index]->IsNumber()); \
uint32_t name = 0; \
CHECK(args[index]->ToUint32(&name));
// Cast the given argument to PropertyAttributes and store its value in a
// variable with the given name. If the argument is not a Smi or the
// enum value is out of range, we crash safely.
......
......@@ -6,6 +6,7 @@
#include "src/arguments.h"
#include "src/assembler.h"
#include "src/compiler/wasm-compiler.h"
#include "src/conversions.h"
#include "src/debug/debug.h"
#include "src/factory.h"
......@@ -17,6 +18,10 @@
namespace v8 {
namespace internal {
namespace {
const int kWasmMemArrayBuffer = 2;
}
RUNTIME_FUNCTION(Runtime_WasmGrowMemory) {
HandleScope scope(isolate);
DCHECK_EQ(1, args.length());
......@@ -40,7 +45,6 @@ RUNTIME_FUNCTION(Runtime_WasmGrowMemory) {
Address old_mem_start, new_mem_start;
uint32_t old_size, new_size;
const int kWasmMemArrayBuffer = 2;
// Get mem buffer associated with module object
Handle<Object> obj(module_object->GetInternalField(kWasmMemArrayBuffer),
......@@ -110,5 +114,81 @@ RUNTIME_FUNCTION(Runtime_WasmGrowMemory) {
wasm::WasmModule::kPageSize);
}
RUNTIME_FUNCTION(Runtime_JITSingleFunction) {
const int fixed_args = 6;
HandleScope scope(isolate);
DCHECK_LE(fixed_args, args.length());
CONVERT_SMI_ARG_CHECKED(base, 0);
CONVERT_SMI_ARG_CHECKED(length, 1);
CONVERT_SMI_ARG_CHECKED(index, 2);
CONVERT_ARG_HANDLE_CHECKED(FixedArray, function_table, 3);
CONVERT_UINT32_ARG_CHECKED(sig_index, 4);
CONVERT_SMI_ARG_CHECKED(return_count, 5);
Handle<JSObject> module_object;
{
// Get the module JSObject
DisallowHeapAllocation no_allocation;
const Address entry = Isolate::c_entry_fp(isolate->thread_local_top());
Address pc =
Memory::Address_at(entry + StandardFrameConstants::kCallerPCOffset);
Code* code =
isolate->inner_pointer_to_code_cache()->GetCacheEntry(pc)->code;
FixedArray* deopt_data = code->deoptimization_data();
DCHECK(deopt_data->length() == 2);
module_object = Handle<JSObject>::cast(handle(deopt_data->get(0), isolate));
CHECK(!module_object->IsNull(isolate));
}
// Get mem buffer associated with module object
Handle<Object> obj(module_object->GetInternalField(kWasmMemArrayBuffer),
isolate);
if (obj->IsUndefined(isolate)) {
return isolate->heap()->undefined_value();
}
Handle<JSArrayBuffer> mem_buffer = Handle<JSArrayBuffer>::cast(obj);
wasm::WasmModule module(reinterpret_cast<byte*>(mem_buffer->backing_store()));
wasm::ErrorThrower thrower(isolate, "JITSingleFunction");
wasm::ModuleEnv module_env;
module_env.module = &module;
module_env.instance = nullptr;
module_env.origin = wasm::kWasmOrigin;
uint32_t signature_size = args.length() - fixed_args;
wasm::LocalType* sig_types = new wasm::LocalType[signature_size];
for (uint32_t i = 0; i < signature_size; ++i) {
CONVERT_SMI_ARG_CHECKED(sig_type, i + fixed_args);
sig_types[i] = static_cast<wasm::LocalType>(sig_type);
}
wasm::FunctionSig sig(return_count, signature_size - return_count, sig_types);
wasm::WasmFunction func;
func.sig = &sig;
func.func_index = index;
func.sig_index = sig_index;
func.name_offset = 0;
func.name_length = 0;
func.code_start_offset = base;
func.code_end_offset = base + length;
Handle<Code> code = compiler::WasmCompilationUnit::CompileWasmFunction(
&thrower, isolate, &module_env, &func);
delete[] sig_types;
if (thrower.error()) {
return isolate->heap()->undefined_value();
}
function_table->set(index, Smi::FromInt(sig_index));
function_table->set(index + function_table->length() / 2, *code);
return isolate->heap()->undefined_value();
}
} // namespace internal
} // namespace v8
......@@ -30,6 +30,8 @@ namespace internal {
// Entries have the form F(name, number of arguments, number of values):
// A variable number of arguments is specified by a -1, additional restrictions
// are specified by inline comments
#define FOR_EACH_INTRINSIC_ARRAY(F) \
F(FinishArrayPrototypeSetup, 1, 1) \
......@@ -294,6 +296,7 @@ namespace internal {
F(ThrowGeneratorRunning, 0, 1) \
F(ThrowStackOverflow, 0, 1) \
F(ThrowWasmError, 2, 1) \
F(JITSingleFunction, -1 /* >= 7 */, 1) \
F(PromiseRejectEvent, 3, 1) \
F(PromiseRevokeReject, 1, 1) \
F(StackGuard, 0, 1) \
......
......@@ -193,6 +193,23 @@ class WasmDecoder : public Decoder {
return false;
}
inline bool Complete(const byte* pc, JITSingleFunctionOperand& operand) {
ModuleEnv* m = module_;
if (m && m->module && operand.sig_index < m->module->signatures.size()) {
operand.sig = m->module->signatures[operand.sig_index];
return true;
}
return false;
}
inline bool Validate(const byte* pc, JITSingleFunctionOperand& operand) {
if (Complete(pc, operand)) {
return true;
}
error(pc, pc + 1, "invalid signature index");
return false;
}
inline bool Validate(const byte* pc, BreakDepthOperand& operand,
ZoneVector<Control>& control) {
if (operand.arity > 1) {
......@@ -287,6 +304,8 @@ class WasmDecoder : public Decoder {
ReturnArityOperand operand(this, pc);
return operand.arity;
}
case kExprJITSingleFunction:
return 3;
#define DECLARE_OPCODE_CASE(name, opcode, sig) \
case kExpr##name: \
......@@ -340,6 +359,10 @@ class WasmDecoder : public Decoder {
return 1 + operand.length;
}
case kExprJITSingleFunction: {
JITSingleFunctionOperand operand(this, pc);
return 1 + operand.length;
}
case kExprSetLocal:
case kExprGetLocal: {
LocalIndexOperand operand(this, pc);
......@@ -996,6 +1019,21 @@ class WasmFullDecoder : public WasmDecoder {
len = 1 + operand.length;
break;
}
case kExprJITSingleFunction: {
if (FLAG_wasm_jit_prototype) {
JITSingleFunctionOperand operand(this, pc_);
if (Validate(pc_, operand)) {
Value index = Pop(2, kAstI32);
Value length = Pop(1, kAstI32);
Value base = Pop(0, kAstI32);
TFNode* call =
BUILD(JITSingleFunction, base.node, length.node, index.node,
operand.sig_index, operand.sig, position());
Push(kAstI32, call);
break;
}
}
}
default:
error("Invalid opcode");
return;
......@@ -1620,6 +1658,13 @@ bool PrintAst(base::AccountingAllocator* allocator, const FunctionBody& body,
}
break;
}
case kExprJITSingleFunction: {
JITSingleFunctionOperand operand(&i, i.pc());
if (decoder.Complete(i.pc(), operand)) {
os << " // sig #" << operand.sig_index << ": " << *operand.sig;
}
break;
}
case kExprReturn: {
ReturnArityOperand operand(&i, i.pc());
os << " // arity=" << operand.arity;
......
......@@ -150,6 +150,16 @@ struct CallImportOperand {
}
};
struct JITSingleFunctionOperand {
uint32_t sig_index;
FunctionSig* sig;
unsigned length;
inline JITSingleFunctionOperand(Decoder* decoder, const byte* pc) {
sig_index = decoder->checked_read_u32v(pc, 1, &length, "signature index");
sig = nullptr;
}
};
struct BranchTableOperand {
uint32_t arity;
uint32_t table_count;
......
......@@ -449,8 +449,8 @@ void FlushAssemblyCache(Isolate* isolate, Handle<FixedArray> functions) {
} // namespace
WasmModule::WasmModule()
: module_start(nullptr),
WasmModule::WasmModule(byte* module_start)
: module_start(module_start),
module_end(nullptr),
min_mem_pages(0),
max_mem_pages(0),
......
......@@ -190,7 +190,8 @@ struct WasmModule {
// switch to libc-2.21 or higher.
base::SmartPointer<base::Semaphore> pending_tasks;
WasmModule();
WasmModule() : WasmModule(nullptr) {}
explicit WasmModule(byte* module_start);
// Get a string stored in the module bytes representing a name.
WasmName GetName(uint32_t offset, uint32_t length) const {
......
......@@ -406,6 +406,9 @@ const WasmCodePosition kNoCodePosition = -1;
V(S128Xor, 0xe578, s_ss) \
V(S128Not, 0xe579, s_s)
// For enabling JIT functionality
#define FOREACH_JIT_OPCODE(V) V(JITSingleFunction, 0xf0, _)
// All opcodes.
#define FOREACH_OPCODE(V) \
FOREACH_CONTROL_OPCODE(V) \
......@@ -416,7 +419,8 @@ const WasmCodePosition kNoCodePosition = -1;
FOREACH_LOAD_MEM_OPCODE(V) \
FOREACH_MISC_MEM_OPCODE(V) \
FOREACH_ASMJS_COMPAT_OPCODE(V) \
FOREACH_SIMD_OPCODE(V)
FOREACH_SIMD_OPCODE(V) \
FOREACH_JIT_OPCODE(V)
// All signatures.
#define FOREACH_SIGNATURE(V) \
......@@ -478,7 +482,8 @@ enum WasmOpcode {
V(TrapFloatUnrepresentable) \
V(TrapFuncInvalid) \
V(TrapFuncSigMismatch) \
V(TrapMemAllocationFail)
V(TrapMemAllocationFail) \
V(TrapInvalidIndex)
enum TrapReason {
#define DECLARE_ENUM(name) k##name,
......
// Copyright 2015 the V8 project authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
// Flags: --expose-wasm
// Flags: --wasm-jit-prototype
load("test/mjsunit/wasm/wasm-constants.js");
load("test/mjsunit/wasm/wasm-module-builder.js");
var module = (function () {
var builder = new WasmModuleBuilder();
var kSig_i_iiiii =
makeSig([kAstI32, kAstI32, kAstI32, kAstI32, kAstI32], [kAstI32]);
var sig_index2 = builder.addType(kSig_i_ii);
var sig_index5 = builder.addType(kSig_i_iiiii);
builder.addMemory(1, 1, true);
var wasm_bytes_sub = [
0,
kExprGetLocal, 0,
kExprGetLocal, 1,
kExprI32Sub
];
builder.addDataSegment(0, wasm_bytes_sub, false);
var wasm_bytes_mul = [
0,
kExprGetLocal, 0,
kExprGetLocal, 1,
kExprI32Mul
];
builder.addDataSegment(6, wasm_bytes_mul, false);
builder.addPadFunctionTable(10);
builder.addImport("add", sig_index2);
builder.addFunction("add", sig_index2)
.addBody([
kExprGetLocal, 0, kExprGetLocal, 1, kExprCallImport, kArity2, 0
]);
builder.addFunction("main", sig_index5)
.addBody([
kExprGetLocal, 0,
kExprGetLocal, 1,
kExprGetLocal, 2,
kExprJITSingleFunction, sig_index2,
kExprGetLocal, 2,
kExprGetLocal, 3,
kExprGetLocal, 4,
kExprCallIndirect, kArity2, sig_index2
])
.exportFunc()
builder.appendToTable([0, 1]);
return builder.instantiate({add: function(a, b) { return a + b | 0; }});
})();
// Check that the module exists
assertFalse(module === undefined);
assertFalse(module === null);
assertFalse(module === 0);
assertEquals("object", typeof module.exports);
assertEquals("function", typeof module.exports.main);
// Check that the bytes referred to lie in the bounds of the memory buffer
assertTraps(kTrapMemOutOfBounds, "module.exports.main(0, 100000, 3, 55, 99)");
assertTraps(kTrapMemOutOfBounds, "module.exports.main(65536, 1, 3, 55, 99)");
// Check that the index lies in the bounds of the table size
assertTraps(kTrapInvalidIndex, "module.exports.main(0, 6, 10, 55, 99)");
assertTraps(kTrapInvalidIndex, "module.exports.main(0, 6, -1, 55, 99)");
// args: base offset, size of func_bytes, index, param1, param2
assertEquals(-444, module.exports.main(0, 6, 3, 555, 999)); // JIT sub function
assertEquals(13, module.exports.main(0, 6, 9, 45, 32)); // JIT sub function
assertEquals(187, module.exports.main(6, 6, 6, 17, 11)); // JIT mul function
assertEquals(30525, module.exports.main(6, 6, 9, 555, 55)); // JIT mul function
......@@ -304,6 +304,8 @@ var kExprI32Rol = 0xb7;
var kExprI64Ror = 0xb8;
var kExprI64Rol = 0xb9;
var kExprJITSingleFunction = 0xf0;
var kTrapUnreachable = 0;
var kTrapMemOutOfBounds = 1;
var kTrapDivByZero = 2;
......@@ -313,6 +315,7 @@ var kTrapFloatUnrepresentable = 5;
var kTrapFuncInvalid = 6;
var kTrapFuncSigMismatch = 7;
var kTrapMemAllocationFail = 8;
var kTrapInvalidIndex = 9;
var kTrapMsgs = [
"unreachable",
......@@ -323,7 +326,8 @@ var kTrapMsgs = [
"integer result unrepresentable",
"invalid function",
"function signature mismatch",
"failed to allocate memory"
"failed to allocate memory",
"invalid index into function table"
];
function assertTraps(trap, code) {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment