Commit 70a452ff authored by Manos Koukoutos, committed by V8 LUCI CQ

[wasm-gc] Optimize away nominal type upcasts

We optimize away type upcasts for nominal types in WasmFullDecoder.
For nominal types, an upcast trivially holds (except for null values),
which is not the case for structural types. Note that we already
optimize away trivially-failing checks (i.e., when the types are
unrelated) for both nominal and structural types.
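
For example (a sketch based on the RefTrivialCastsStatic test added
below; type_index is the nominal supertype of subtype_index), the
following ref.test no longer needs a runtime check:

    // The operand's static type is already a subtype of type_index, so the
    // decoder folds this check to the constant 1 (or to a null check for
    // nullable inputs) instead of emitting a RefTest.
    const byte kRefTestUpcast = tester.DefineFunction(
        tester.sigs.i_v(), {},
        {WASM_REF_TEST_STATIC(WASM_STRUCT_NEW_DEFAULT(subtype_index),
                              type_index),
         kExprEnd});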

Bug: v8:7748
Change-Id: I720c9803cb8b4071aa4bae112ce06d587b7a68fa
Reviewed-on: https://chromium-review.googlesource.com/c/v8/v8/+/3306984
Commit-Queue: Manos Koukoutos <manoskouk@chromium.org>
Reviewed-by: Jakob Kummerow <jkummerow@chromium.org>
Cr-Commit-Position: refs/heads/main@{#78201}
parent c2f7f596
@@ -1316,6 +1316,10 @@ Node* WasmGraphBuilder::Unop(wasm::WasmOpcode opcode, Node* input,
                  : BuildIntConvertFloat(input, position, opcode);
     case wasm::kExprRefIsNull:
       return IsNull(input);
+    // We abuse ref.as_non_null, which isn't otherwise used in this switch, as
+    // a sentinel for the negation of ref.is_null.
+    case wasm::kExprRefAsNonNull:
+      return gasm_->Int32Sub(gasm_->Int32Constant(1), IsNull(input));
     case wasm::kExprI32AsmjsLoadMem8S:
       return BuildAsmjsLoadMem(MachineType::Int8(), input);
     case wasm::kExprI32AsmjsLoadMem8U:
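
The new case computes the logical negation of ref.is_null without
adding a dedicated opcode: IsNull produces 0 or 1, so subtracting it
from 1 flips the bit. A minimal standalone sketch of the identity
(plain C++, not V8 code):

    #include <cassert>

    // A 0/1 flag is negated by subtracting it from 1; this is the identity
    // behind Int32Sub(Int32Constant(1), IsNull(input)) above.
    int Negate01(int is_null) { return 1 - is_null; }

    int main() {
      assert(Negate01(0) == 1);
      assert(Negate01(1) == 0);
      return 0;
    }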
@@ -1736,7 +1736,10 @@ class LiftoffCompiler {
               __ emit_type_conversion(kExprI64UConvertI32, dst, c_call_dst,
                                       nullptr);
             });
-      case kExprRefIsNull: {
+      case kExprRefIsNull:
+      // We abuse ref.as_non_null, which isn't otherwise used in this switch, as
+      // a sentinel for the negation of ref.is_null.
+      case kExprRefAsNonNull: {
         LiftoffRegList pinned;
         LiftoffRegister ref = pinned.set(__ PopToRegister());
         LiftoffRegister null = __ GetUnusedRegister(kGpReg, pinned);
@@ -1744,7 +1747,8 @@ class LiftoffCompiler {
         // Prefer to overwrite one of the input registers with the result
         // of the comparison.
         LiftoffRegister dst = __ GetUnusedRegister(kGpReg, {ref, null}, {});
-        __ emit_ptrsize_set_cond(kEqual, dst.gp(), ref, null);
+        __ emit_ptrsize_set_cond(opcode == kExprRefIsNull ? kEqual : kUnequal,
+                                 dst.gp(), ref, null);
         __ PushRegister(kI32, dst);
         return;
       }
@@ -3347,7 +3351,9 @@ class LiftoffCompiler {
     CallRef(decoder, func_ref.type, sig, kTailCall);
   }

-  void BrOnNull(FullDecoder* decoder, const Value& ref_object, uint32_t depth) {
+  void BrOnNull(FullDecoder* decoder, const Value& ref_object, uint32_t depth,
+                bool pass_null_along_branch,
+                Value* /* result_on_fallthrough */) {
     // Before branching, materialize all constants. This avoids repeatedly
     // materializing them for each conditional branch.
     if (depth != decoder->control_depth() - 1) {
@@ -3362,7 +3368,7 @@ class LiftoffCompiler {
     LoadNullValue(null, pinned);
     __ emit_cond_jump(kUnequal, &cont_false, ref_object.type.kind(), ref.gp(),
                       null);
-
+    if (pass_null_along_branch) LoadNullValue(null, pinned);
     BrOrRet(decoder, depth, 0);
     __ bind(&cont_false);
     __ PushRegister(kRef, ref);
@@ -1018,7 +1018,8 @@ struct ControlBase : public PcForErrors<validate> {
     const Value args[])                                                     \
   F(ReturnCallIndirect, const Value& index,                                 \
     const CallIndirectImmediate<validate>& imm, const Value args[])         \
-  F(BrOnNull, const Value& ref_object, uint32_t depth)                      \
+  F(BrOnNull, const Value& ref_object, uint32_t depth,                      \
+    bool pass_null_along_branch, Value* result_on_fallthrough)              \
   F(BrOnNonNull, const Value& ref_object, uint32_t depth)                   \
   F(SimdOp, WasmOpcode opcode, base::Vector<Value> args, Value* result)     \
   F(SimdLaneOp, WasmOpcode opcode, const SimdLaneImmediate<validate>& imm,  \
@@ -2729,8 +2730,7 @@ class WasmFullDecoder : public WasmDecoder<validate, decoding_mode> {
         // The result of br_on_null has the same value as the argument (but a
         // non-nullable type).
         if (V8_LIKELY(current_code_reachable_and_ok_)) {
-          CALL_INTERFACE(BrOnNull, ref_object, imm.depth);
-          CALL_INTERFACE(Forward, ref_object, &result);
+          CALL_INTERFACE(BrOnNull, ref_object, imm.depth, false, &result);
           c->br_merge()->reached = true;
         }
         // In unreachable code, we still have to push a value of the correct
@@ -4017,9 +4017,21 @@ class WasmFullDecoder : public WasmDecoder<validate, decoding_mode> {
     }
   }

-  bool ObjectRelatedWithRtt(Value obj, Value rtt) {
-    return IsSubtypeOf(ValueType::Ref(rtt.type.ref_index(), kNonNullable),
-                       obj.type, this->module_) ||
+  // Checks if types are unrelated, thus type checking will always fail. Does
+  // not account for nullability.
+  bool TypeCheckAlwaysFails(Value obj, Value rtt) {
+    return !IsSubtypeOf(ValueType::Ref(rtt.type.ref_index(), kNonNullable),
+                        obj.type, this->module_) &&
+           !IsSubtypeOf(obj.type,
+                        ValueType::Ref(rtt.type.ref_index(), kNullable),
+                        this->module_);
+  }
+
+  // Checks if {obj} is a nominal type which is a subtype of {rtt}'s index,
+  // thus checking will always succeed. Does not account for nullability.
+  bool TypeCheckAlwaysSucceeds(Value obj, Value rtt) {
+    return obj.type.has_index() &&
+           this->module_->has_supertype(obj.type.ref_index()) &&
            IsSubtypeOf(obj.type,
                        ValueType::Ref(rtt.type.ref_index(), kNullable),
                        this->module_);
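
Together, the two helpers split a cast into three classes: it
statically succeeds (the nominal obj type is already below the rtt's
type), statically fails (the types are unrelated in both directions),
or needs a runtime check. An illustrative model of that decision, with
the subtype queries abstracted into booleans (hypothetical names, not
V8 code):

    enum class CastOutcome { kAlwaysSucceeds, kAlwaysFails, kRuntimeCheck };

    // obj_nominal_below_rtt ~ TypeCheckAlwaysSucceeds(obj, rtt);
    // obj_below_rtt / rtt_below_obj ~ the two IsSubtypeOf queries.
    CastOutcome Classify(bool obj_nominal_below_rtt, bool obj_below_rtt,
                         bool rtt_below_obj) {
      if (obj_nominal_below_rtt) return CastOutcome::kAlwaysSucceeds;
      if (!obj_below_rtt && !rtt_below_obj) return CastOutcome::kAlwaysFails;
      return CastOutcome::kRuntimeCheck;
    }

Neither helper accounts for nullability, which is why every caller
below still special-cases nullable inputs.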
@@ -4503,13 +4515,24 @@ class WasmFullDecoder : public WasmDecoder<validate, decoding_mode> {
       if (current_code_reachable_and_ok_) {
         // This logic ensures that code generation can assume that functions
         // can only be cast to function types, and data objects to data types.
-        if (V8_LIKELY(ObjectRelatedWithRtt(obj, rtt))) {
-          CALL_INTERFACE(RefTest, obj, rtt, &value);
-        } else {
+        if (V8_UNLIKELY(TypeCheckAlwaysSucceeds(obj, rtt))) {
+          // Drop rtt.
+          CALL_INTERFACE(Drop);
+          // Type checking can still fail for null.
+          if (obj.type.is_nullable()) {
+            // We abuse ref.as_non_null, which isn't otherwise used as a unary
+            // operator, as a sentinel for the negation of ref.is_null.
+            CALL_INTERFACE(UnOp, kExprRefAsNonNull, obj, &value);
+          } else {
+            CALL_INTERFACE(Drop);
+            CALL_INTERFACE(I32Const, &value, 1);
+          }
+        } else if (V8_UNLIKELY(TypeCheckAlwaysFails(obj, rtt))) {
           CALL_INTERFACE(Drop);
           CALL_INTERFACE(Drop);
-          // Unrelated types. Will always fail.
           CALL_INTERFACE(I32Const, &value, 0);
+        } else {
+          CALL_INTERFACE(RefTest, obj, rtt, &value);
         }
       }
       Drop(2);
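
The value ref.test now evaluates to in each class can be summarized as
follows; nullness is only known at runtime, which is why the
always-succeeds case for a nullable input emits the ref.as_non_null
sentinel rather than a constant. A sketch with hypothetical names, not
V8 code:

    #include <optional>

    // Models what ref.test evaluates to; nullopt means a genuine runtime
    // type check (RefTest) is still emitted.
    std::optional<int> RefTestResult(bool always_succeeds, bool always_fails,
                                     bool obj_is_null) {
      if (always_succeeds) return obj_is_null ? 0 : 1;  // null still fails
      if (always_fails) return 0;
      return std::nullopt;
    }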
@@ -4556,9 +4579,12 @@ class WasmFullDecoder : public WasmDecoder<validate, decoding_mode> {
       if (current_code_reachable_and_ok_) {
         // This logic ensures that code generation can assume that functions
         // can only be cast to function types, and data objects to data types.
-        if (V8_LIKELY(ObjectRelatedWithRtt(obj, rtt))) {
-          CALL_INTERFACE(RefCast, obj, rtt, &value);
-        } else {
+        if (V8_UNLIKELY(TypeCheckAlwaysSucceeds(obj, rtt))) {
+          // Drop the rtt from the stack, then forward the object value to the
+          // result.
+          CALL_INTERFACE(Drop);
+          CALL_INTERFACE(Forward, obj, &value);
+        } else if (V8_UNLIKELY(TypeCheckAlwaysFails(obj, rtt))) {
           // Unrelated types. The only way this will not trap is if the object
           // is null.
           if (obj.type.is_nullable()) {
@@ -4569,6 +4595,8 @@ class WasmFullDecoder : public WasmDecoder<validate, decoding_mode> {
             CALL_INTERFACE(Trap, TrapReason::kTrapIllegalCast);
             EndControl();
           }
+        } else {
+          CALL_INTERFACE(RefCast, obj, rtt, &value);
         }
       }
       Drop(2);
@@ -4628,20 +4656,30 @@ class WasmFullDecoder : public WasmDecoder<validate, decoding_mode> {
                         : ValueType::Ref(rtt.type.ref_index(), kNonNullable));
       Push(result_on_branch);
       if (!VALIDATE(TypeCheckBranch<true>(c, 0))) return 0;
-      // This logic ensures that code generation can assume that functions
-      // can only be cast to function types, and data objects to data types.
-      if (V8_LIKELY(ObjectRelatedWithRtt(obj, rtt))) {
-        // The {value_on_branch} parameter we pass to the interface must
-        // be pointer-identical to the object on the stack, so we can't
-        // reuse {result_on_branch} which was passed-by-value to {Push}.
-        Value* value_on_branch = stack_value(1);
-        if (V8_LIKELY(current_code_reachable_and_ok_)) {
+      if (V8_LIKELY(current_code_reachable_and_ok_)) {
+        // This logic ensures that code generation can assume that functions
+        // can only be cast to function types, and data objects to data types.
+        if (V8_UNLIKELY(TypeCheckAlwaysSucceeds(obj, rtt))) {
+          CALL_INTERFACE(Drop);  // rtt
+          // The branch will still not be taken on null.
+          if (obj.type.is_nullable()) {
+            CALL_INTERFACE(BrOnNonNull, obj, branch_depth.depth);
+          } else {
+            CALL_INTERFACE(BrOrRet, branch_depth.depth, 0);
+          }
+          c->br_merge()->reached = true;
+        } else if (V8_LIKELY(!TypeCheckAlwaysFails(obj, rtt))) {
+          // The {value_on_branch} parameter we pass to the interface must
+          // be pointer-identical to the object on the stack, so we can't
+          // reuse {result_on_branch} which was passed-by-value to {Push}.
+          Value* value_on_branch = stack_value(1);
           CALL_INTERFACE(BrOnCast, obj, rtt, value_on_branch,
                          branch_depth.depth);
           c->br_merge()->reached = true;
         }
+        // Otherwise the types are unrelated. Do not branch.
       }
-      // Otherwise the types are unrelated. Do not branch.
       Drop(result_on_branch);
       Push(obj);  // Restore stack state on fallthrough.
       return opcode_length + branch_depth.length;
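
When the cast statically succeeds, br_on_cast thus degenerates to a
null check: a br_on_non_null for nullable inputs, an unconditional
branch otherwise; when it statically fails, no branch is emitted at
all. A compact summary of that dispatch (hypothetical names, not V8
code):

    enum class BrOnCastLowering {
      kBrOnNonNull,          // always succeeds, nullable obj
      kUnconditionalBranch,  // always succeeds, non-nullable obj
      kNoBranch,             // always fails
      kRuntimeBrOnCast       // genuine runtime check
    };

    BrOnCastLowering LowerBrOnCast(bool always_succeeds, bool always_fails,
                                   bool obj_is_nullable) {
      if (always_succeeds)
        return obj_is_nullable ? BrOnCastLowering::kBrOnNonNull
                               : BrOnCastLowering::kUnconditionalBranch;
      if (always_fails) return BrOnCastLowering::kNoBranch;
      return BrOnCastLowering::kRuntimeBrOnCast;
    }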
@@ -4699,13 +4737,10 @@ class WasmFullDecoder : public WasmDecoder<validate, decoding_mode> {
           rtt.type.is_bottom()
               ? kWasmBottom
               : ValueType::Ref(rtt.type.ref_index(), kNonNullable));
-      // This logic ensures that code generation can assume that functions
-      // can only be cast to function types, and data objects to data types.
       if (V8_LIKELY(current_code_reachable_and_ok_)) {
-        if (V8_LIKELY(ObjectRelatedWithRtt(obj, rtt))) {
-          CALL_INTERFACE(BrOnCastFail, obj, rtt, &result_on_fallthrough,
-                         branch_depth.depth);
-        } else {
+        // This logic ensures that code generation can assume that functions
+        // can only be cast to function types, and data objects to data types.
+        if (V8_UNLIKELY(TypeCheckAlwaysFails(obj, rtt))) {
           // Drop {rtt} in the interface.
           CALL_INTERFACE(Drop);
           // Otherwise the types are unrelated. Always branch.
@@ -4713,8 +4748,25 @@ class WasmFullDecoder : public WasmDecoder<validate, decoding_mode> {
           // We know that the following code is not reachable, but according
           // to the spec it technically is. Set it to spec-only reachable.
           SetSucceedingCodeDynamicallyUnreachable();
+          c->br_merge()->reached = true;
+        } else if (V8_UNLIKELY(TypeCheckAlwaysSucceeds(obj, rtt))) {
+          // Drop {rtt} in the interface.
+          CALL_INTERFACE(Drop);
+          // The branch can still be taken on null.
+          if (obj.type.is_nullable()) {
+            CALL_INTERFACE(BrOnNull, obj, branch_depth.depth, true,
+                           &result_on_fallthrough);
+            c->br_merge()->reached = true;
+          } else {
+            // Drop {obj} in the interface.
+            CALL_INTERFACE(Drop);
+          }
+        } else {
+          CALL_INTERFACE(BrOnCastFail, obj, rtt, &result_on_fallthrough,
+                         branch_depth.depth);
+          c->br_merge()->reached = true;
         }
-        c->br_merge()->reached = true;
+        // Otherwise, the type check always succeeds. Do not branch.
       }
       // Make sure the correct value is on the stack state on fallthrough.
       Drop(obj);
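
br_on_cast_fail is the dual case: a statically failing check always
branches (with the fallthrough marked spec-only reachable), while a
statically succeeding check can branch only on null, which is what the
new pass_null_along_branch flag on BrOnNull expresses. The same kind
of sketch, under the same assumptions as above:

    enum class BrOnCastFailLowering {
      kAlwaysBranch,  // always fails
      kBrOnNull,      // always succeeds, nullable obj: branch only on null
      kNoBranch,      // always succeeds, non-nullable obj
      kRuntimeCheck   // genuine br_on_cast_fail
    };

    BrOnCastFailLowering LowerBrOnCastFail(bool always_succeeds,
                                           bool always_fails,
                                           bool obj_is_nullable) {
      if (always_fails) return BrOnCastFailLowering::kAlwaysBranch;
      if (always_succeeds)
        return obj_is_nullable ? BrOnCastFailLowering::kBrOnNull
                               : BrOnCastFailLowering::kNoBranch;
      return BrOnCastFailLowering::kRuntimeCheck;
    }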
@@ -784,7 +784,8 @@ class WasmGraphBuildingInterface {
                     args);
   }

-  void BrOnNull(FullDecoder* decoder, const Value& ref_object, uint32_t depth) {
+  void BrOnNull(FullDecoder* decoder, const Value& ref_object, uint32_t depth,
+                bool pass_null_along_branch, Value* result_on_fallthrough) {
     SsaEnv* false_env = ssa_env_;
     SsaEnv* true_env = Split(decoder->zone(), false_env);
     false_env->SetNotMerged();
@@ -792,8 +793,9 @@ class WasmGraphBuildingInterface {
                       &false_env->control);
     builder_->SetControl(false_env->control);
     SetEnv(true_env);
-    BrOrRet(decoder, depth, 1);
+    BrOrRet(decoder, depth, pass_null_along_branch ? 0 : 1);
     SetEnv(false_env);
+    result_on_fallthrough->node = ref_object.node;
   }

   void BrOnNonNull(FullDecoder* decoder, const Value& ref_object,
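
A note on the BrOrRet change: its third argument is a drop count (an
assumption inferred from how this call site and the Liftoff
implementation above use it), so passing the null value along the
branch means dropping nothing, whereas the old code always dropped the
ref before branching. The fallthrough value is also wired up
explicitly now, since the decoder no longer issues a separate Forward
call:

    // Assumed meaning of BrOrRet's third argument at this call site;
    // hypothetical helper, not V8 code.
    uint32_t BrOnNullDropCount(bool pass_null_along_branch) {
      return pass_null_along_branch ? 0 : 1;
    }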
@@ -1446,6 +1446,7 @@ WASM_COMPILED_EXEC_TEST(RttFreshSub) {
 }

 WASM_COMPILED_EXEC_TEST(RefTrivialCasts) {
+  // TODO(7748): Add tests for branch_on_*.
   WasmGCTester tester(execution_tier);
   byte type_index = tester.DefineStruct({F(wasm::kWasmI32, true)});
   byte subtype_index =
@@ -1458,6 +1459,7 @@ WASM_COMPILED_EXEC_TEST(RefTrivialCasts) {
       tester.sigs.i_v(), {},
       {WASM_REF_TEST(WASM_REF_NULL(type_index), WASM_RTT_CANON(subtype_index)),
        kExprEnd});
+  // Upcasts should not be optimized away for structural types.
   const byte kRefTestUpcast = tester.DefineFunction(
       tester.sigs.i_v(), {},
       {WASM_REF_TEST(
@@ -1466,6 +1468,12 @@ WASM_COMPILED_EXEC_TEST(RefTrivialCasts) {
            WASM_RTT_SUB(subtype_index, WASM_RTT_CANON(type_index))),
            WASM_RTT_CANON(type_index)),
        kExprEnd});
+  const byte kRefTestUpcastFail = tester.DefineFunction(
+      tester.sigs.i_v(), {},
+      {WASM_REF_TEST(WASM_STRUCT_NEW_DEFAULT_WITH_RTT(
+                         subtype_index, WASM_RTT_CANON(subtype_index)),
+                     WASM_RTT_CANON(type_index)),
+       kExprEnd});
   const byte kRefTestUpcastNull = tester.DefineFunction(
       tester.sigs.i_v(), {},
       {WASM_REF_TEST(WASM_REF_NULL(subtype_index), WASM_RTT_CANON(type_index)),
@@ -1532,6 +1540,7 @@ WASM_COMPILED_EXEC_TEST(RefTrivialCasts) {

   tester.CheckResult(kRefTestNull, 0);
   tester.CheckResult(kRefTestUpcast, 1);
+  tester.CheckResult(kRefTestUpcastFail, 0);
   tester.CheckResult(kRefTestUpcastNull, 0);
   tester.CheckResult(kRefTestUnrelated, 0);
   tester.CheckResult(kRefTestUnrelatedNull, 0);
@@ -1546,6 +1555,7 @@ WASM_COMPILED_EXEC_TEST(RefTrivialCasts) {
 }

 WASM_COMPILED_EXEC_TEST(RefTrivialCastsStatic) {
+  // TODO(7748): Add tests for branch_on_*.
   WasmGCTester tester(execution_tier);
   byte type_index =
       tester.DefineStruct({F(wasm::kWasmI32, true)}, kGenericSuperType);
@@ -1559,6 +1569,7 @@ WASM_COMPILED_EXEC_TEST(RefTrivialCastsStatic) {
       tester.sigs.i_v(), {},
       {WASM_REF_TEST_STATIC(WASM_REF_NULL(type_index), subtype_index),
        kExprEnd});
+  // Upcasts should be optimized away for nominal types.
   const byte kRefTestUpcast = tester.DefineFunction(
       tester.sigs.i_v(), {},
       {WASM_REF_TEST_STATIC(WASM_STRUCT_NEW_DEFAULT(subtype_index), type_index),