wasm-inlining.cc 18.8 KB
Newer Older
1 2 3 4 5 6
// Copyright 2021 the V8 project authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "src/compiler/wasm-inlining.h"

7
#include "src/compiler/all-nodes.h"
8
#include "src/compiler/compiler-source-position-table.h"
9 10 11 12 13 14
#include "src/compiler/node-matchers.h"
#include "src/compiler/wasm-compiler.h"
#include "src/wasm/function-body-decoder.h"
#include "src/wasm/graph-builder-interface.h"
#include "src/wasm/wasm-features.h"
#include "src/wasm/wasm-module.h"
15
#include "src/wasm/wasm-subtyping.h"
16 17 18 19 20 21

namespace v8 {
namespace internal {
namespace compiler {

// Graph-reducer entry point. Only call and tail-call nodes are of interest to
// the inliner; they are forwarded to ReduceCall(). Every other node is left
// untouched.
Reduction WasmInliner::Reduce(Node* node) {
  switch (node->opcode()) {
    case IrOpcode::kCall:
    case IrOpcode::kTailCall:
      return ReduceCall(node);
    default:
      return NoChange();
  }
}

31 32
// Prints a trace message via PrintF when FLAG_trace_wasm_inlining is set.
// The expansion already ends in ';', which is why some call sites below use
// TRACE(...) without a trailing semicolon.
// NOTE(review): the bare `if` is not wrapped in `do { } while (0)`, so placing
// TRACE(...) as the unbraced body of an if/else would let a following `else`
// bind to this macro's `if`. Current call sites do not do this.
#define TRACE(...) \
  if (FLAG_trace_wasm_inlining) PrintF(__VA_ARGS__);
33

34 35
// TODO(12166): Save inlined frames for trap/--trace-wasm purposes. Consider
//              tail calls.
36
// Records {call} as an inlining candidate if it qualifies. This never changes
// the graph (it always returns NoChange()); the actual inlining happens later
// in Finalize().
//
// A call qualifies when its target (value input 0) is a relocatable pointer
// constant with RelocInfo::WASM_CALL mode whose function index names a
// non-imported function other than the one being compiled. With
// --wasm-speculative-inlining, the call's type-feedback entry (looked up by
// source position) supplies its call count and marks it as a speculative
// call_ref. The inlinee's body size is recorded alongside.
Reduction WasmInliner::ReduceCall(Node* call) {
  DCHECK(call->opcode() == IrOpcode::kCall ||
         call->opcode() == IrOpcode::kTailCall);

  // Nodes can be visited more than once (e.g. after Revisit()); make sure each
  // call is considered, and thus queued, at most once.
  if (seen_.find(call) != seen_.end()) {
    TRACE("function %d: have already seen node %d, skipping\n", function_index_,
          call->id());
    return NoChange();
  }
  seen_.insert(call);

  Node* callee = NodeProperties::GetValueInput(call, 0);
  IrOpcode::Value reloc_opcode = mcgraph_->machine()->Is32()
                                     ? IrOpcode::kRelocatableInt32Constant
                                     : IrOpcode::kRelocatableInt64Constant;
  if (callee->opcode() != reloc_opcode) {
    TRACE("[function %d: considering node %d... not a relocatable constant]\n",
          function_index_, call->id());
    return NoChange();
  }
  auto info = OpParameter<RelocatablePtrConstantInfo>(callee->op());
  uint32_t inlinee_index = static_cast<uint32_t>(info.value());
  TRACE("[function %d: considering node %d, call to %d... ", function_index_,
        call->id(), inlinee_index)
  if (info.rmode() != RelocInfo::WASM_CALL) {
    TRACE("not a wasm call]\n")
    return NoChange();
  }
  if (inlinee_index < module()->num_imported_functions) {
    TRACE("imported function]\n")
    return NoChange();
  }
  if (inlinee_index == function_index_) {
    TRACE("recursive call]\n")
    return NoChange();
  }

  TRACE("adding to inlining candidates!]\n")

  bool is_speculative_call_ref = false;
  int call_count = 0;
  if (FLAG_wasm_speculative_inlining) {
    base::MutexGuard guard(&module()->type_feedback.mutex);
    auto maybe_feedback =
        module()->type_feedback.feedback_for_function.find(function_index_);
    if (maybe_feedback != module()->type_feedback.feedback_for_function.end()) {
      // Bind by reference: the feedback is only read while {guard} holds the
      // mutex, so copying its position map and vector would be pure overhead.
      const wasm::FunctionTypeFeedback& feedback = maybe_feedback->second;
      wasm::WasmCodePosition position =
          source_positions_->GetSourcePosition(call).ScriptOffset();
      DCHECK_NE(position, wasm::kNoCodePosition);
      auto index_in_feedback_vector = feedback.positions.find(position);
      if (index_in_feedback_vector != feedback.positions.end()) {
        is_speculative_call_ref = true;
        call_count = feedback.feedback_vector[index_in_feedback_vector->second]
                         .absolute_call_frequency;
      }
    }
  }

  CHECK_LT(inlinee_index, module()->functions.size());
  const wasm::WasmFunction* inlinee = &module()->functions[inlinee_index];
  base::Vector<const byte> function_bytes = wire_bytes_->GetCode(inlinee->code);

  CandidateInfo candidate{call, inlinee_index, is_speculative_call_ref,
                          call_count, function_bytes.length()};

  inlining_candidates_.push(candidate);
  return NoChange();
}

void WasmInliner::Finalize() {
  TRACE("function %d: going though inlining candidates...\n", function_index_);
  while (!inlining_candidates_.empty()) {
    CandidateInfo candidate = inlining_candidates_.top();
    inlining_candidates_.pop();
    Node* call = candidate.node;
    TRACE(
        "  [function %d: considering candidate {@%d, index=%d, type=%s, "
        "count=%d, size=%d}... ",
        function_index_, call->id(), candidate.inlinee_index,
        candidate.is_speculative_call_ref ? "ref" : "direct",
        candidate.call_count, candidate.wire_byte_size);
    if (call->IsDead()) {
      TRACE("dead node]\n");
      continue;
    }
    const wasm::WasmFunction* inlinee =
        &module()->functions[candidate.inlinee_index];
    base::Vector<const byte> function_bytes =
        wire_bytes_->GetCode(inlinee->code);
126 127 128 129 130 131
    // We use the signature based on the real argument types stored in the call
    // node. This is more specific than the callee's formal signature and might
    // enable some optimizations.
    const wasm::FunctionSig* real_sig =
        CallDescriptorOf(call->op())->wasm_sig();

132 133 134 135 136 137
#if DEBUG
    // Check that the real signature is a subtype of the formal one.
    const wasm::FunctionSig* formal_sig =
        WasmGraphBuilder::Int64LoweredSig(zone(), inlinee->sig);
    CHECK_EQ(real_sig->parameter_count(), formal_sig->parameter_count());
    CHECK_EQ(real_sig->return_count(), formal_sig->return_count());
138
    for (size_t i = 0; i < real_sig->parameter_count(); i++) {
139 140
      CHECK(wasm::IsSubtypeOf(real_sig->GetParam(i), formal_sig->GetParam(i),
                              module()));
141 142
    }
    for (size_t i = 0; i < real_sig->return_count(); i++) {
143 144
      CHECK(wasm::IsSubtypeOf(formal_sig->GetReturn(i), real_sig->GetReturn(i),
                              module()));
145
    }
146
#endif
147 148

    const wasm::FunctionBody inlinee_body(real_sig, inlinee->code.offset(),
149 150 151 152 153
                                          function_bytes.begin(),
                                          function_bytes.end());
    wasm::WasmFeatures detected;
    WasmGraphBuilder builder(env_, zone(), mcgraph_, inlinee_body.sig,
                             source_positions_);
154
    std::vector<WasmLoopInfo> inlinee_loop_infos;
155 156 157 158 159 160

    size_t subgraph_min_node_id = graph()->NodeCount();
    Node* inlinee_start;
    Node* inlinee_end;
    {
      Graph::SubgraphScope scope(graph());
161 162 163 164 165 166 167
      wasm::DecodeResult result = wasm::BuildTFGraph(
          zone()->allocator(), env_->enabled_features, module(), &builder,
          &detected, inlinee_body, &inlinee_loop_infos, node_origins_,
          candidate.inlinee_index,
          NodeProperties::IsExceptionalCall(call)
              ? wasm::kInlinedHandledCall
              : wasm::kInlinedNonHandledCall);
168 169 170 171 172 173 174 175
      if (result.failed()) {
        // This can happen if the inlinee has never been compiled before and is
        // invalid. Return, as there is no point to keep optimizing.
        TRACE("failed to compile]\n")
        return;
      }

      builder.LowerInt64(WasmGraphBuilder::kCalledFromWasm);
176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197
      inlinee_start = graph()->start();
      inlinee_end = graph()->end();
    }

    size_t additional_nodes = graph()->NodeCount() - subgraph_min_node_id;
    if (current_graph_size_ + additional_nodes >
        size_limit(initial_graph_size_)) {
      // This is not based on the accurate graph size, as it may have been
      // shrunk by other optimizations. We could recompute the accurate size
      // with a traversal, but it is most probably not worth the time.
      TRACE("not enough inlining budget]\n");
      continue;
    }
    TRACE("inlining!]\n");
    current_graph_size_ += additional_nodes;

    if (call->opcode() == IrOpcode::kCall) {
      InlineCall(call, inlinee_start, inlinee_end, inlinee->sig,
                 subgraph_min_node_id);
    } else {
      InlineTailCall(call, inlinee_start, inlinee_end);
    }
198
    call->Kill();
199 200
    loop_infos_->insert(loop_infos_->end(), inlinee_loop_infos.begin(),
                        inlinee_loop_infos.end());
201 202
    // Returning after only one inlining has been tried and found worse.
  }
203 204
}

205 206 207 208 209
// Connects the inlinee's entry to the call site:
//  - each Parameter hanging off {callee_start} is replaced by the matching
//    value input of {call} (value input 0 is the callee, so parameter i maps
//    to value input i + 1);
//  - every other effect use of {callee_start} is redirected to {call}'s effect
//    input, and every control use to {call}'s control input, except for
//    Projections, which are floating control and go to the caller graph's
//    start instead.
// Users whose effect/control input was redirected are scheduled for
// revisiting.
void WasmInliner::RewireFunctionEntry(Node* call, Node* callee_start) {
  Node* control = NodeProperties::GetControlInput(call);
  Node* effect = NodeProperties::GetEffectInput(call);

  for (Edge edge : callee_start->use_edges()) {
    Node* use = edge.from();
    switch (use->opcode()) {
      case IrOpcode::kParameter: {
        // Index 0 is the callee node.
        int index = 1 + ParameterIndexOf(use->op());
        Replace(use, NodeProperties::GetValueInput(call, index));
        break;
      }
      default:
        if (NodeProperties::IsEffectEdge(edge)) {
          edge.UpdateTo(effect);
        } else if (NodeProperties::IsControlEdge(edge)) {
          // Projections pointing to the inlinee start are floating control.
          // They should point to the graph's start.
          edge.UpdateTo(use->opcode() == IrOpcode::kProjection
                            ? graph()->start()
                            : control);
        } else {
          UNREACHABLE();
        }
        // UpdateTo() changes the edge's target, not its source, so {use} is
        // still the node that owns this edge.
        Revisit(use);
        break;
    }
  }
}

240 241
// Inlines the tail call {call}, whose callee has been built as a subgraph
// delimited by {callee_start} / {callee_end}. Because control never comes back
// from a tail call, the inlinee's terminators simply become terminators of the
// caller graph; there are no return values to rewire.
void WasmInliner::InlineTailCall(Node* call, Node* callee_start,
                                 Node* callee_end) {
  DCHECK_EQ(call->opcode(), IrOpcode::kTailCall);
  // 1) Rewire function entry.
  RewireFunctionEntry(call, callee_start);
  // 2) For tail calls, all we have to do is rewire all terminators of the
  // inlined graph to the end of the caller graph.
  for (Node* const input : callee_end->inputs()) {
    DCHECK(IrOpcode::IsGraphTerminator(input->opcode()));
    NodeProperties::MergeControlToEnd(graph(), common(), input);
  }
  // The tail call was itself a terminator of the caller: its uses are inputs
  // of the caller's End (DCHECKed below), which are replaced by Dead.
  for (Edge edge_to_end : call->use_edges()) {
    DCHECK_EQ(edge_to_end.from(), graph()->end());
    edge_to_end.UpdateTo(mcgraph()->Dead());
  }
  callee_end->Kill();
  // NOTE(review): Finalize() also calls call->Kill() after this returns;
  // presumably killing an already-killed node is a no-op — confirm.
  call->Kill();
  Revisit(graph()->end());
}

260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294
namespace {
// The graph-builder-interface gives every throwing call in the inlinee its own
// IfException handler, possibly followed by a LoopExit. Given such a throwing
// {call}, returns the node that still needs to be connected to the caller's
// exception handler:
//  - the IfException itself, if it has no uses;
//  - otherwise, the first LoopExit using it, provided that LoopExit is used
//    only by LoopExitEffect / LoopExitValue nodes;
//  - nullptr in every other case (the handler is already wired up).
// {call} is expected (DCHECK) to have an IfException use.
Node* DanglingHandler(Node* call) {
  Node* exception_projection = nullptr;
  for (Node* user : call->uses()) {
    if (user->opcode() != IrOpcode::kIfException) continue;
    exception_projection = user;
    break;
  }
  DCHECK_NOT_NULL(exception_projection);

  // An unused handler is dangling by definition.
  if (exception_projection->UseCount() == 0) return exception_projection;

  for (Node* user : exception_projection->uses()) {
    if (user->opcode() != IrOpcode::kLoopExit) continue;
    // Found the LoopExit; it dangles unless something other than its
    // effect/value companions consumes it.
    for (Node* exit_user : user->uses()) {
      const bool is_exit_companion =
          exit_user->opcode() == IrOpcode::kLoopExitEffect ||
          exit_user->opcode() == IrOpcode::kLoopExitValue;
      if (!is_exit_companion) return nullptr;
    }
    return user;
  }

  return nullptr;
}
}  // namespace

295 296 297
// Inlines the regular (non-tail) call {call}. The callee has already been
// built as a subgraph delimited by {callee_start} / {callee_end}, and its nodes
// have ids >= {subgraph_min_node_id}. {inlinee_sig} supplies the machine
// representations of the returned values.
//
// Steps:
//  0) If {call} has an exception handler, collect the dangling exception
//     handlers of the inlinee's potentially-throwing nodes.
//  1) Rewire the inlinee's entry to the call site.
//  2) Collect the inlinee's Return nodes. Tail calls are turned into calls
//     followed by a new Return. Deoptimize/Terminate/Throw are merged into the
//     caller's End.
//  3) Merge the dangling handlers into {call}'s handler, or kill that handler
//     if nothing in the inlinee can throw.
//  4) Merge all returns and replace {call}'s value/effect/control uses with
//     the merged results. If there are no returns, {call} is replaced by Dead.
void WasmInliner::InlineCall(Node* call, Node* callee_start, Node* callee_end,
                             const wasm::FunctionSig* inlinee_sig,
                             size_t subgraph_min_node_id) {
  DCHECK_EQ(call->opcode(), IrOpcode::kCall);

  // 0) Before doing anything, if {call} has an exception handler, collect all
  // unhandled calls in the subgraph. Only nodes created for the inlinee (id >=
  // {subgraph_min_node_id}) that are not kNoThrow are considered.
  Node* handler = nullptr;
  std::vector<Node*> dangling_handlers;
  if (NodeProperties::IsExceptionalCall(call, &handler)) {
    AllNodes subgraph_nodes(zone(), callee_end, graph());
    for (Node* node : subgraph_nodes.reachable) {
      if (node->id() >= subgraph_min_node_id &&
          !node->op()->HasProperty(Operator::kNoThrow)) {
        Node* dangling_handler = DanglingHandler(node);
        if (dangling_handler != nullptr) {
          dangling_handlers.push_back(dangling_handler);
        }
      }
    }
  }

  // 1) Rewire function entry.
  RewireFunctionEntry(call, callee_start);

  // 2) Handle all graph terminators for the callee.
  NodeVector return_nodes(zone());
  for (Node* const input : callee_end->inputs()) {
    DCHECK(IrOpcode::IsGraphTerminator(input->opcode()));
    switch (input->opcode()) {
      case IrOpcode::kReturn:
        // Returns are collected to be rewired into the caller graph later.
        return_nodes.push_back(input);
        break;
      case IrOpcode::kDeoptimize:
      case IrOpcode::kTerminate:
      case IrOpcode::kThrow:
        NodeProperties::MergeControlToEnd(graph(), common(), input);
        Revisit(graph()->end());
        break;
      case IrOpcode::kTailCall: {
        // A tail call in the callee inlined in a regular call in the caller has
        // to be transformed into a regular call, and then returned from the
        // inlinee. It will then be handled like any other return.
        auto descriptor = CallDescriptorOf(input->op());
        NodeProperties::ChangeOp(input, common()->Call(descriptor));
        int return_arity = static_cast<int>(inlinee_sig->return_count());
        NodeVector return_inputs(zone());
        // The first input of a return node is always the 0 constant.
        return_inputs.push_back(graph()->NewNode(common()->Int32Constant(0)));
        if (return_arity == 1) {
          return_inputs.push_back(input);
        } else if (return_arity > 1) {
          for (int i = 0; i < return_arity; i++) {
            return_inputs.push_back(
                graph()->NewNode(common()->Projection(i), input, input));
          }
        }

        // Add effect and control inputs.
        return_inputs.push_back(input->op()->EffectOutputCount() > 0
                                    ? input
                                    : NodeProperties::GetEffectInput(input));
        return_inputs.push_back(input->op()->ControlOutputCount() > 0
                                    ? input
                                    : NodeProperties::GetControlInput(input));

        Node* ret = graph()->NewNode(common()->Return(return_arity),
                                     static_cast<int>(return_inputs.size()),
                                     return_inputs.data());
        return_nodes.push_back(ret);
        break;
      }
      default:
        UNREACHABLE();
    }
  }
  callee_end->Kill();

  // 3) Rewire unhandled calls to the handler.
  int handler_count = static_cast<int>(dangling_handlers.size());

  if (handler_count > 0) {
    Node* control_output =
        graph()->NewNode(common()->Merge(handler_count), handler_count,
                         dangling_handlers.data());
    // One value/effect per dangling handler, plus the merge as the phis'
    // trailing control input.
    std::vector<Node*> effects;
    std::vector<Node*> values;
    effects.reserve(dangling_handlers.size() + 1);
    values.reserve(dangling_handlers.size() + 1);
    for (Node* control : dangling_handlers) {
      if (control->opcode() == IrOpcode::kIfException) {
        // The IfException node itself is both the exception value and effect.
        effects.push_back(control);
        values.push_back(control);
      } else {
        // Leaving a loop: take value and effect through the LoopExit.
        DCHECK_EQ(control->opcode(), IrOpcode::kLoopExit);
        Node* if_exception = control->InputAt(0);
        DCHECK_EQ(if_exception->opcode(), IrOpcode::kIfException);
        effects.push_back(graph()->NewNode(common()->LoopExitEffect(),
                                           if_exception, control));
        values.push_back(graph()->NewNode(
            common()->LoopExitValue(MachineRepresentation::kTagged),
            if_exception, control));
      }
    }

    effects.push_back(control_output);
    values.push_back(control_output);
    Node* value_output = graph()->NewNode(
        common()->Phi(MachineRepresentation::kTagged, handler_count),
        handler_count + 1, values.data());
    Node* effect_output = graph()->NewNode(common()->EffectPhi(handler_count),
                                           handler_count + 1, effects.data());
    ReplaceWithValue(handler, value_output, effect_output, control_output);
  } else if (handler != nullptr) {
    // Nothing in the inlined function can throw. Remove the handler.
    ReplaceWithValue(handler, mcgraph()->Dead(), mcgraph()->Dead(),
                     mcgraph()->Dead());
  }

  if (!return_nodes.empty()) {
    // 4) Collect all return site value, effect, and control inputs into phis
    // and merges.
    int const return_count = static_cast<int>(return_nodes.size());
    NodeVector controls(zone());
    NodeVector effects(zone());
    for (Node* const return_node : return_nodes) {
      controls.push_back(NodeProperties::GetControlInput(return_node));
      effects.push_back(NodeProperties::GetEffectInput(return_node));
    }
    Node* control_output = graph()->NewNode(common()->Merge(return_count),
                                            return_count, controls.data());
    effects.push_back(control_output);
    Node* effect_output =
        graph()->NewNode(common()->EffectPhi(return_count),
                         static_cast<int>(effects.size()), effects.data());

    // The first input of a return node is discarded. This is because Wasm
    // functions always return an additional 0 constant as a first return value.
    DCHECK(
        Int32Matcher(NodeProperties::GetValueInput(return_nodes[0], 0)).Is(0));
    // NOTE(review): {return_arity} is read from the (Int64-lowered) Return
    // node, while {inlinee_sig} is indexed with it below; confirm both agree
    // when i64 values are split into i32 pairs (32-bit targets).
    int const return_arity = return_nodes[0]->op()->ValueInputCount() - 1;
    NodeVector values(zone());
    for (int i = 0; i < return_arity; i++) {
      NodeVector ith_values(zone());
      for (Node* const return_node : return_nodes) {
        Node* value = NodeProperties::GetValueInput(return_node, i + 1);
        ith_values.push_back(value);
      }
      ith_values.push_back(control_output);
      // Find the correct machine representation for the return values from the
      // inlinee signature.
      MachineRepresentation repr =
          inlinee_sig->GetReturn(i).machine_representation();
      Node* ith_value_output = graph()->NewNode(
          common()->Phi(repr, return_count),
          static_cast<int>(ith_values.size()), ith_values.data());
      values.push_back(ith_value_output);
    }
    for (Node* return_node : return_nodes) return_node->Kill();

    if (return_arity == 0) {
      // Void function, no value uses.
      ReplaceWithValue(call, mcgraph()->Dead(), effect_output, control_output);
    } else if (return_arity == 1) {
      // One return value. Just replace value uses of the call node with it.
      ReplaceWithValue(call, values[0], effect_output, control_output);
    } else {
      // Multiple returns. We have to find the projections of the call node and
      // replace them with the returned values.
      for (Edge use_edge : call->use_edges()) {
        if (NodeProperties::IsValueEdge(use_edge)) {
          Node* use = use_edge.from();
          DCHECK_EQ(use->opcode(), IrOpcode::kProjection);
          ReplaceWithValue(use, values[ProjectionIndexOf(use->op())]);
        }
      }
      // All value inputs are replaced by the above loop, so it is ok to use
      // Dead() as a dummy for value replacement.
      ReplaceWithValue(call, mcgraph()->Dead(), effect_output, control_output);
    }
  } else {
    // The callee can never return. The call node and all its uses are dead.
    ReplaceWithValue(call, mcgraph()->Dead(), mcgraph()->Dead(),
                     mcgraph()->Dead());
  }
}

// Accessor for the wasm module referenced by the compilation environment
// {env_}.
const wasm::WasmModule* WasmInliner::module() const {
  const wasm::WasmModule* const env_module = env_->module;
  return env_module;
}

483 484
#undef TRACE

485 486 487
}  // namespace compiler
}  // namespace internal
}  // namespace v8