wasm-inlining.cc 20.8 KB
Newer Older
1 2 3 4 5 6
// Copyright 2021 the V8 project authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "src/compiler/wasm-inlining.h"

7
#include "src/compiler/all-nodes.h"
8
#include "src/compiler/compiler-source-position-table.h"
9 10 11 12 13 14
#include "src/compiler/node-matchers.h"
#include "src/compiler/wasm-compiler.h"
#include "src/wasm/function-body-decoder.h"
#include "src/wasm/graph-builder-interface.h"
#include "src/wasm/wasm-features.h"
#include "src/wasm/wasm-module.h"
15
#include "src/wasm/wasm-subtyping.h"
16 17 18 19 20 21

namespace v8 {
namespace internal {
namespace compiler {

Reduction WasmInliner::Reduce(Node* node) {
  // Only call sites (regular calls and tail calls) are inlining candidates;
  // every other node is left untouched.
  const IrOpcode::Value opcode = node->opcode();
  if (opcode == IrOpcode::kCall || opcode == IrOpcode::kTailCall) {
    return ReduceCall(node);
  }
  return NoChange();
}

31
// Emits trace output via PrintF when --trace-wasm-inlining is enabled.
// The body is wrapped in `do { } while (false)` so the macro expands to a
// single statement: a bare `if` would silently capture an `else` written
// after a `TRACE(...);` in an unbraced if/else.
#define TRACE(...)                                     \
  do {                                                 \
    if (FLAG_trace_wasm_inlining) PrintF(__VA_ARGS__); \
  } while (false)

// Traces the inlining decision taken for the call node {call} whose target is
// function {inlinee}.
void WasmInliner::Trace(Node* call, int inlinee, const char* decision) {
  const NodeId call_id = call->id();
  TRACE("[function %d: considering node %d, call to %d: %s]\n",
        function_index_, call_id, inlinee, decision);
}

// Returns the index of the function whose body produced {call}: either the
// function being compiled, or one of the functions inlined into it so far.
// {first_node_id_[k]} is the first node id of the k-th inlined body, and
// {inlined_functions_[k]} is that body's function index.
uint32_t WasmInliner::FindOriginatingFunction(Node* call) {
  DCHECK_EQ(inlined_functions_.size(), first_node_id_.size());
  const NodeId id = call->id();
  // Nodes created before any inlined body belong to the caller itself.
  if (inlined_functions_.empty() || id < first_node_id_[0]) {
    return function_index_;
  }
  // Advance while {id} is at or beyond the start of the next inlined body.
  size_t next = 1;
  while (next < first_node_id_.size() && id >= first_node_id_[next]) {
    ++next;
  }
  if (next == first_node_id_.size()) {
    // {id} lies in the most recently inlined body.
    DCHECK_GE(id, first_node_id_.back());
    return inlined_functions_.back();
  }
  return inlined_functions_[next - 1];
}

// Returns the recorded absolute call frequency for {call}, or 0 when
// speculative inlining is disabled or no feedback is available for it.
// Feedback is looked up by the function the call originates from (which may
// be an already-inlined function) and by the call's script offset.
int WasmInliner::GetCallCount(Node* call) {
  if (!FLAG_wasm_speculative_inlining) return 0;
  base::MutexGuard guard(&module()->type_feedback.mutex);
  wasm::WasmCodePosition position =
      source_positions_->GetSourcePosition(call).ScriptOffset();
  uint32_t func = FindOriginatingFunction(call);
  const auto& feedback_for_function =
      module()->type_feedback.feedback_for_function;
  auto maybe_feedback = feedback_for_function.find(func);
  if (maybe_feedback == feedback_for_function.end()) {
    return 0;
  }
  // Bind by reference instead of copying the whole feedback record (vector
  // and position map). The reference remains valid because {guard} holds the
  // feedback mutex until this function returns.
  const wasm::FunctionTypeFeedback& feedback = maybe_feedback->second;
  // It's possible that we haven't processed the feedback yet. Currently,
  // this can happen for targets of call_direct that haven't gotten hot yet,
  // and for functions where Liftoff bailed out.
  if (feedback.feedback_vector.size() == 0) return 0;
  auto index_in_vector = feedback.positions.find(position);
  if (index_in_vector == feedback.positions.end()) return 0;
  return feedback.feedback_vector[index_in_vector->second]
      .absolute_call_frequency;
}
73

74 75
// TODO(12166): Save inlined frames for trap/--trace-wasm purposes. Consider
//              tail calls.
76
Reduction WasmInliner::ReduceCall(Node* call) {
77 78
  DCHECK(call->opcode() == IrOpcode::kCall ||
         call->opcode() == IrOpcode::kTailCall);
79 80 81 82 83 84 85 86

  if (seen_.find(call) != seen_.end()) {
    TRACE("function %d: have already seen node %d, skipping\n", function_index_,
          call->id());
    return NoChange();
  }
  seen_.insert(call);

87 88 89 90
  Node* callee = NodeProperties::GetValueInput(call, 0);
  IrOpcode::Value reloc_opcode = mcgraph_->machine()->Is32()
                                     ? IrOpcode::kRelocatableInt32Constant
                                     : IrOpcode::kRelocatableInt64Constant;
91 92 93 94 95
  if (callee->opcode() != reloc_opcode) {
    TRACE("[function %d: considering node %d... not a relocatable constant]\n",
          function_index_, call->id());
    return NoChange();
  }
96
  auto info = OpParameter<RelocatablePtrConstantInfo>(callee->op());
97
  uint32_t inlinee_index = static_cast<uint32_t>(info.value());
98
  if (info.rmode() != RelocInfo::WASM_CALL) {
99
    Trace(call, inlinee_index, "not a wasm call");
100 101 102
    return NoChange();
  }
  if (inlinee_index < module()->num_imported_functions) {
103
    Trace(call, inlinee_index, "imported function");
104 105
    return NoChange();
  }
106
  if (inlinee_index == function_index_) {
107
    Trace(call, inlinee_index, "recursive call");
108 109
    return NoChange();
  }
110

111 112 113
  Trace(call, inlinee_index, "adding to inlining candidates!");

  int call_count = GetCallCount(call);
114 115 116 117

  CHECK_LT(inlinee_index, module()->functions.size());
  const wasm::WasmFunction* inlinee = &module()->functions[inlinee_index];
  base::Vector<const byte> function_bytes = wire_bytes_->GetCode(inlinee->code);
118

119 120
  CandidateInfo candidate{call, inlinee_index, call_count,
                          function_bytes.length()};
121

122 123 124 125
  inlining_candidates_.push(candidate);
  return NoChange();
}

126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142
// Returns whether a candidate of {candidate_size} wire bytes may be inlined
// into a graph that currently has {current_graph_size} nodes, according to
// WasmInliner::graph_size_allows_inlining.
bool SmallEnoughToInline(size_t current_graph_size, uint32_t candidate_size) {
  if (WasmInliner::graph_size_allows_inlining(current_graph_size)) {
    return true;
  }
  // For truly tiny functions, let's be a bit more generous: pretend the graph
  // is smaller by a fixed number of nodes.
  constexpr uint32_t kTinyFunctionMaxWireBytes = 10;
  constexpr size_t kTinyFunctionExtraNodes = 100;
  if (candidate_size >= kTinyFunctionMaxWireBytes) return false;
  // Saturate at zero: an unsigned subtraction would wrap around for graphs
  // with fewer than kTinyFunctionExtraNodes nodes and turn the bonus into a
  // rejection.
  const size_t discounted_size =
      current_graph_size > kTinyFunctionExtraNodes
          ? current_graph_size - kTinyFunctionExtraNodes
          : 0;
  return WasmInliner::graph_size_allows_inlining(discounted_size);
}

// Traces the decision taken for a queued inlining candidate.
void WasmInliner::Trace(const CandidateInfo& candidate, const char* decision) {
  // PrintF is variadic, so every argument must match its %d conversion
  // exactly. The wire byte size is a size_t (it is initialised from
  // Vector::length()), and passing it for %d is undefined behaviour on LP64,
  // so the integer fields are cast explicitly.
  TRACE(
      "  [function %d: considering candidate {@%d, index=%d, count=%d, "
      "size=%d}: %s]\n",
      function_index_, static_cast<int>(candidate.node->id()),
      static_cast<int>(candidate.inlinee_index),
      static_cast<int>(candidate.call_count),
      static_cast<int>(candidate.wire_byte_size), decision);
}

143
void WasmInliner::Finalize() {
144 145 146
  TRACE("function %d %s: going though inlining candidates...\n",
        function_index_, debug_name_);
  if (inlining_candidates_.empty()) return;
147 148 149 150 151
  while (!inlining_candidates_.empty()) {
    CandidateInfo candidate = inlining_candidates_.top();
    inlining_candidates_.pop();
    Node* call = candidate.node;
    if (call->IsDead()) {
152 153 154 155 156 157 158 159 160 161 162 163 164
      Trace(candidate, "dead node");
      continue;
    }
    int min_count_for_inlining = candidate.wire_byte_size / 2;
    if (candidate.call_count < min_count_for_inlining) {
      Trace(candidate, "not called often enough");
      continue;
    }
    // We could build the candidate's graph first and consider its node count,
    // but it turns out that wire byte size and node count are quite strongly
    // correlated, at about 1.16 nodes per wire byte (measured for J2Wasm).
    if (!SmallEnoughToInline(current_graph_size_, candidate.wire_byte_size)) {
      Trace(candidate, "not enough inlining budget");
165 166 167 168 169 170
      continue;
    }
    const wasm::WasmFunction* inlinee =
        &module()->functions[candidate.inlinee_index];
    base::Vector<const byte> function_bytes =
        wire_bytes_->GetCode(inlinee->code);
171 172 173 174 175 176
    // We use the signature based on the real argument types stored in the call
    // node. This is more specific than the callee's formal signature and might
    // enable some optimizations.
    const wasm::FunctionSig* real_sig =
        CallDescriptorOf(call->op())->wasm_sig();

177 178 179 180 181 182
#if DEBUG
    // Check that the real signature is a subtype of the formal one.
    const wasm::FunctionSig* formal_sig =
        WasmGraphBuilder::Int64LoweredSig(zone(), inlinee->sig);
    CHECK_EQ(real_sig->parameter_count(), formal_sig->parameter_count());
    CHECK_EQ(real_sig->return_count(), formal_sig->return_count());
183
    for (size_t i = 0; i < real_sig->parameter_count(); i++) {
184 185
      CHECK(wasm::IsSubtypeOf(real_sig->GetParam(i), formal_sig->GetParam(i),
                              module()));
186 187
    }
    for (size_t i = 0; i < real_sig->return_count(); i++) {
188 189
      CHECK(wasm::IsSubtypeOf(formal_sig->GetReturn(i), real_sig->GetReturn(i),
                              module()));
190
    }
191
#endif
192 193

    const wasm::FunctionBody inlinee_body(real_sig, inlinee->code.offset(),
194 195 196 197 198
                                          function_bytes.begin(),
                                          function_bytes.end());
    wasm::WasmFeatures detected;
    WasmGraphBuilder builder(env_, zone(), mcgraph_, inlinee_body.sig,
                             source_positions_);
199
    std::vector<WasmLoopInfo> inlinee_loop_infos;
200 201 202 203 204 205

    size_t subgraph_min_node_id = graph()->NodeCount();
    Node* inlinee_start;
    Node* inlinee_end;
    {
      Graph::SubgraphScope scope(graph());
206 207 208 209 210 211 212
      wasm::DecodeResult result = wasm::BuildTFGraph(
          zone()->allocator(), env_->enabled_features, module(), &builder,
          &detected, inlinee_body, &inlinee_loop_infos, node_origins_,
          candidate.inlinee_index,
          NodeProperties::IsExceptionalCall(call)
              ? wasm::kInlinedHandledCall
              : wasm::kInlinedNonHandledCall);
213 214 215
      if (result.failed()) {
        // This can happen if the inlinee has never been compiled before and is
        // invalid. Return, as there is no point to keep optimizing.
216 217 218 219 220 221 222 223 224 225

        // TODO(jkummerow): This can also happen as a consequence of the
        // opportunistic signature specialization we did above! When parameters
        // are reassigned (as locals), the subtypes can make that invalid.
        // Fix this for now by detecting when it happens and retrying the
        // inlining with the original signature.
        // A better long-term fix would be to port check elimination to the
        // TF graph, so we won't need the signature "trick" and more.

        Trace(candidate, "failed to compile");
226 227 228 229
        return;
      }

      builder.LowerInt64(WasmGraphBuilder::kCalledFromWasm);
230 231 232 233 234
      inlinee_start = graph()->start();
      inlinee_end = graph()->end();
    }

    size_t additional_nodes = graph()->NodeCount() - subgraph_min_node_id;
235
    Trace(candidate, "inlining!");
236
    current_graph_size_ += additional_nodes;
237 238 239
    inlined_functions_.push_back(candidate.inlinee_index);
    static_assert(std::is_same_v<NodeId, uint32_t>);
    first_node_id_.push_back(static_cast<uint32_t>(subgraph_min_node_id));
240 241 242 243 244 245 246

    if (call->opcode() == IrOpcode::kCall) {
      InlineCall(call, inlinee_start, inlinee_end, inlinee->sig,
                 subgraph_min_node_id);
    } else {
      InlineTailCall(call, inlinee_start, inlinee_end);
    }
247
    call->Kill();
248 249
    loop_infos_->insert(loop_infos_->end(), inlinee_loop_infos.begin(),
                        inlinee_loop_infos.end());
250 251
    // Returning after only one inlining has been tried and found worse.
  }
252 253
}

254 255 256 257 258
// Connects the inlinee's entry to the call site: each formal Parameter of the
// callee is replaced by the corresponding actual argument of {call}, and every
// effect/control use of {callee_start} is redirected to the call's effect and
// control inputs.
void WasmInliner::RewireFunctionEntry(Node* call, Node* callee_start) {
  Node* const call_control = NodeProperties::GetControlInput(call);
  Node* const call_effect = NodeProperties::GetEffectInput(call);

  for (Edge edge : callee_start->use_edges()) {
    Node* const user = edge.from();

    if (user->opcode() == IrOpcode::kParameter) {
      // Value input 0 of the call is the callee itself, so formal parameter
      // i corresponds to value input i + 1.
      const int argument_index = 1 + ParameterIndexOf(user->op());
      Replace(user, NodeProperties::GetValueInput(call, argument_index));
      continue;
    }

    if (NodeProperties::IsEffectEdge(edge)) {
      edge.UpdateTo(call_effect);
    } else if (NodeProperties::IsControlEdge(edge)) {
      // Projections pointing to the inlinee start are floating control.
      // They should point to the graph's start.
      Node* const new_control = user->opcode() == IrOpcode::kProjection
                                    ? graph()->start()
                                    : call_control;
      edge.UpdateTo(new_control);
    } else {
      UNREACHABLE();
    }
    Revisit(user);
  }
}

289 290
// Inlines the tail call {call}. Nothing in the caller runs after a tail call,
// so the inlinee's terminators simply become terminators of the caller graph,
// and the call node itself is disconnected and killed.
void WasmInliner::InlineTailCall(Node* call, Node* callee_start,
                                 Node* callee_end) {
  DCHECK_EQ(call->opcode(), IrOpcode::kTailCall);

  // Entry: hook the inlinee's parameters, effect and control up to the call.
  RewireFunctionEntry(call, callee_start);

  // Exit: merge every terminator of the inlinee into the caller's End node.
  for (Node* const terminator : callee_end->inputs()) {
    DCHECK(IrOpcode::IsGraphTerminator(terminator->opcode()));
    NodeProperties::MergeControlToEnd(graph(), common(), terminator);
  }

  // The tail call's only user is the caller's End node; point those edges at
  // Dead so the call can be removed.
  for (Edge end_edge : call->use_edges()) {
    DCHECK_EQ(end_edge.from(), graph()->end());
    end_edge.UpdateTo(mcgraph()->Dead());
  }

  callee_end->Kill();
  call->Kill();
  Revisit(graph()->end());
}

309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343
namespace {
// graph-builder-interface generates a dangling exception handler for each
// throwing call in the inlinee. This might be followed by a LoopExit node.
// Returns the dangling IfException (or the LoopExit that consumes it) for
// {call}, or nullptr if the handler is already connected to something else.
Node* DanglingHandler(Node* call) {
  Node* handler = nullptr;
  for (Node* user : call->uses()) {
    if (user->opcode() != IrOpcode::kIfException) continue;
    handler = user;
    break;
  }
  DCHECK_NOT_NULL(handler);

  // An unused IfException is dangling by itself.
  if (handler->UseCount() == 0) return handler;

  // Otherwise, the handler counts as dangling only through a LoopExit whose
  // users are exclusively LoopExitEffect/LoopExitValue nodes.
  for (Node* user : handler->uses()) {
    if (user->opcode() != IrOpcode::kLoopExit) continue;
    for (Node* exit_user : user->uses()) {
      const IrOpcode::Value exit_opcode = exit_user->opcode();
      const bool is_exit_projection =
          exit_opcode == IrOpcode::kLoopExitEffect ||
          exit_opcode == IrOpcode::kLoopExitValue;
      if (!is_exit_projection) return nullptr;
    }
    return user;
  }

  return nullptr;
}
}  // namespace

344 345 346
// Inlines the regular (non-tail) call {call}, whose inlinee graph runs from
// {callee_start} to {callee_end}. Nodes with id >= {subgraph_min_node_id}
// belong to the inlinee. Steps:
//   0) if {call} has an exception handler, collect the inlinee's dangling
//      exception handlers;
//   1) rewire the inlinee's entry to the call's inputs;
//   2) classify the inlinee's terminators (returns are collected, throws and
//      similar are merged into the caller's End, tail calls become
//      call + return);
//   3) route the dangling handlers to {call}'s handler;
//   4) merge all returns and replace {call}'s value/effect/control uses.
void WasmInliner::InlineCall(Node* call, Node* callee_start, Node* callee_end,
                             const wasm::FunctionSig* inlinee_sig,
                             size_t subgraph_min_node_id) {
  DCHECK_EQ(call->opcode(), IrOpcode::kCall);

  // 0) Before doing anything, if {call} has an exception handler, collect all
  // unhandled calls in the subgraph.
  Node* handler = nullptr;
  std::vector<Node*> dangling_handlers;
  if (NodeProperties::IsExceptionalCall(call, &handler)) {
    AllNodes subgraph_nodes(zone(), callee_end, graph());
    for (Node* node : subgraph_nodes.reachable) {
      // Only nodes created while building the inlinee can carry the dangling
      // handlers generated for it.
      if (node->id() >= subgraph_min_node_id &&
          !node->op()->HasProperty(Operator::kNoThrow)) {
        Node* dangling_handler = DanglingHandler(node);
        if (dangling_handler != nullptr) {
          dangling_handlers.push_back(dangling_handler);
        }
      }
    }
  }

  // 1) Rewire function entry.
  RewireFunctionEntry(call, callee_start);

  // 2) Handle all graph terminators for the callee.
  NodeVector return_nodes(zone());
  for (Node* const input : callee_end->inputs()) {
    DCHECK(IrOpcode::IsGraphTerminator(input->opcode()));
    switch (input->opcode()) {
      case IrOpcode::kReturn:
        // Returns are collected to be rewired into the caller graph later.
        return_nodes.push_back(input);
        break;
      case IrOpcode::kDeoptimize:
      case IrOpcode::kTerminate:
      case IrOpcode::kThrow:
        // These never return to the call site; they become terminators of
        // the caller graph.
        NodeProperties::MergeControlToEnd(graph(), common(), input);
        Revisit(graph()->end());
        break;
      case IrOpcode::kTailCall: {
        // A tail call in the callee inlined in a regular call in the caller has
        // to be transformed into a regular call, and then returned from the
        // inlinee. It will then be handled like any other return.
        auto descriptor = CallDescriptorOf(input->op());
        NodeProperties::ChangeOp(input, common()->Call(descriptor));
        int return_arity = static_cast<int>(inlinee_sig->return_count());
        NodeVector return_inputs(zone());
        // The first input of a return node is always the 0 constant.
        return_inputs.push_back(graph()->NewNode(common()->Int32Constant(0)));
        if (return_arity == 1) {
          return_inputs.push_back(input);
        } else if (return_arity > 1) {
          for (int i = 0; i < return_arity; i++) {
            return_inputs.push_back(
                graph()->NewNode(common()->Projection(i), input, input));
          }
        }

        // Add effect and control inputs.
        return_inputs.push_back(input->op()->EffectOutputCount() > 0
                                    ? input
                                    : NodeProperties::GetEffectInput(input));
        return_inputs.push_back(input->op()->ControlOutputCount() > 0
                                    ? input
                                    : NodeProperties::GetControlInput(input));

        Node* ret = graph()->NewNode(common()->Return(return_arity),
                                     static_cast<int>(return_inputs.size()),
                                     return_inputs.data());
        return_nodes.push_back(ret);
        break;
      }
      default:
        UNREACHABLE();
    }
  }
  callee_end->Kill();

  // 3) Rewire unhandled calls to the handler.
  int handler_count = static_cast<int>(dangling_handlers.size());

  if (handler_count > 0) {
    Node* control_output =
        graph()->NewNode(common()->Merge(handler_count), handler_count,
                         dangling_handlers.data());
    // Phi/EffectPhi inputs: one entry per dangling handler, then the merge.
    std::vector<Node*> effects;
    std::vector<Node*> values;
    effects.reserve(dangling_handlers.size() + 1);
    values.reserve(dangling_handlers.size() + 1);
    for (Node* control : dangling_handlers) {
      if (control->opcode() == IrOpcode::kIfException) {
        effects.push_back(control);
        values.push_back(control);
      } else {
        // The handler leaves a loop; its effect and value have to go through
        // LoopExitEffect/LoopExitValue as well.
        DCHECK_EQ(control->opcode(), IrOpcode::kLoopExit);
        Node* if_exception = control->InputAt(0);
        DCHECK_EQ(if_exception->opcode(), IrOpcode::kIfException);
        effects.push_back(graph()->NewNode(common()->LoopExitEffect(),
                                           if_exception, control));
        values.push_back(graph()->NewNode(
            common()->LoopExitValue(MachineRepresentation::kTagged),
            if_exception, control));
      }
    }

    effects.push_back(control_output);
    values.push_back(control_output);
    Node* value_output = graph()->NewNode(
        common()->Phi(MachineRepresentation::kTagged, handler_count),
        handler_count + 1, values.data());
    Node* effect_output = graph()->NewNode(common()->EffectPhi(handler_count),
                                           handler_count + 1, effects.data());
    ReplaceWithValue(handler, value_output, effect_output, control_output);
  } else if (handler != nullptr) {
    // Nothing in the inlined function can throw. Remove the handler.
    ReplaceWithValue(handler, mcgraph()->Dead(), mcgraph()->Dead(),
                     mcgraph()->Dead());
  }

  if (return_nodes.size() > 0) {
    // 4) Collect all return site value, effect, and control inputs into phis
    // and merges.
    int const return_count = static_cast<int>(return_nodes.size());
    NodeVector controls(zone());
    NodeVector effects(zone());
    for (Node* const return_node : return_nodes) {
      controls.push_back(NodeProperties::GetControlInput(return_node));
      effects.push_back(NodeProperties::GetEffectInput(return_node));
    }
    Node* control_output = graph()->NewNode(common()->Merge(return_count),
                                            return_count, controls.data());
    effects.push_back(control_output);
    Node* effect_output =
        graph()->NewNode(common()->EffectPhi(return_count),
                         static_cast<int>(effects.size()), effects.data());

    // The first input of a return node is discarded. This is because Wasm
    // functions always return an additional 0 constant as a first return value.
    DCHECK(
        Int32Matcher(NodeProperties::GetValueInput(return_nodes[0], 0)).Is(0));
    int const return_arity = return_nodes[0]->op()->ValueInputCount() - 1;
    NodeVector values(zone());
    for (int i = 0; i < return_arity; i++) {
      NodeVector ith_values(zone());
      for (Node* const return_node : return_nodes) {
        Node* value = NodeProperties::GetValueInput(return_node, i + 1);
        ith_values.push_back(value);
      }
      ith_values.push_back(control_output);
      // Find the correct machine representation for the return values from the
      // inlinee signature.
      MachineRepresentation repr =
          inlinee_sig->GetReturn(i).machine_representation();
      Node* ith_value_output = graph()->NewNode(
          common()->Phi(repr, return_count),
          static_cast<int>(ith_values.size()), ith_values.data());
      values.push_back(ith_value_output);
    }
    for (Node* return_node : return_nodes) return_node->Kill();

    if (return_arity == 0) {
      // Void function, no value uses.
      ReplaceWithValue(call, mcgraph()->Dead(), effect_output, control_output);
    } else if (return_arity == 1) {
      // One return value. Just replace value uses of the call node with it.
      ReplaceWithValue(call, values[0], effect_output, control_output);
    } else {
      // Multiple returns. We have to find the projections of the call node and
      // replace them with the returned values.
      for (Edge use_edge : call->use_edges()) {
        if (NodeProperties::IsValueEdge(use_edge)) {
          Node* use = use_edge.from();
          DCHECK_EQ(use->opcode(), IrOpcode::kProjection);
          ReplaceWithValue(use, values[ProjectionIndexOf(use->op())]);
        }
      }
      // All value inputs are replaced by the above loop, so it is ok to use
      // Dead() as a dummy for value replacement.
      ReplaceWithValue(call, mcgraph()->Dead(), effect_output, control_output);
    }
  } else {
    // The callee can never return. The call node and all its uses are dead.
    ReplaceWithValue(call, mcgraph()->Dead(), mcgraph()->Dead(),
                     mcgraph()->Dead());
  }
}

// The wasm module being compiled, taken from the compilation environment.
const wasm::WasmModule* WasmInliner::module() const {
  const wasm::WasmModule* const compiled_module = env_->module;
  return compiled_module;
}

532 533
#undef TRACE

534 535 536
}  // namespace compiler
}  // namespace internal
}  // namespace v8