Commit 4b03f024 authored by Manos Koukoutos's avatar Manos Koukoutos Committed by Commit Bot

[wasm-gc] ref.cast forwards null input

According to the new wasm-gc spec, ref.cast should forward a null input
without trapping.

Bug: v8:7748
Change-Id: Ifee17f02a572e7028c14482bc94f0e1c7fc82a5b
Reviewed-on: https://chromium-review.googlesource.com/c/v8/v8/+/2647261
Commit-Queue: Manos Koukoutos <manoskouk@chromium.org>
Reviewed-by: 's avatarJakob Kummerow <jkummerow@chromium.org>
Cr-Commit-Position: refs/heads/master@{#72358}
parent 2919e543
...@@ -5792,12 +5792,11 @@ Node* WasmGraphBuilder::RefCast(Node* object, Node* rtt, ...@@ -5792,12 +5792,11 @@ Node* WasmGraphBuilder::RefCast(Node* object, Node* rtt,
} else { } else {
AssertFalse(mcgraph(), gasm_.get(), gasm_->IsI31(object)); AssertFalse(mcgraph(), gasm_.get(), gasm_->IsI31(object));
} }
auto done = gasm_->MakeLabel();
if (config.object_can_be_null) { if (config.object_can_be_null) {
TrapIfTrue(wasm::kTrapIllegalCast, gasm_->WordEqual(object, RefNull()), gasm_->GotoIf(gasm_->WordEqual(object, RefNull()), &done);
position);
} }
Node* map = gasm_->LoadMap(object); Node* map = gasm_->LoadMap(object);
auto done = gasm_->MakeLabel();
gasm_->GotoIf(gasm_->TaggedEqual(map, rtt), &done); gasm_->GotoIf(gasm_->TaggedEqual(map, rtt), &done);
if (!config.object_must_be_data_ref) { if (!config.object_must_be_data_ref) {
TrapIfFalse(wasm::kTrapIllegalCast, gasm_->IsDataRefMap(map), position); TrapIfFalse(wasm::kTrapIllegalCast, gasm_->IsDataRefMap(map), position);
......
...@@ -4391,10 +4391,16 @@ class LiftoffCompiler { ...@@ -4391,10 +4391,16 @@ class LiftoffCompiler {
__ PushRegister(rtt_value_type, LiftoffRegister(kReturnRegister0)); __ PushRegister(rtt_value_type, LiftoffRegister(kReturnRegister0));
} }
enum NullSucceeds : bool { // --
kNullSucceeds = true,
kNullFails = false
};
// Falls through on match (=successful type check). // Falls through on match (=successful type check).
// Returns the register containing the object. // Returns the register containing the object.
LiftoffRegister SubtypeCheck(FullDecoder* decoder, const Value& obj, LiftoffRegister SubtypeCheck(FullDecoder* decoder, const Value& obj,
const Value& rtt, Label* no_match, const Value& rtt, Label* no_match,
NullSucceeds null_succeeds,
LiftoffRegList pinned = {}, LiftoffRegList pinned = {},
Register opt_scratch = no_reg) { Register opt_scratch = no_reg) {
Label match; Label match;
...@@ -4414,7 +4420,8 @@ class LiftoffCompiler { ...@@ -4414,7 +4420,8 @@ class LiftoffCompiler {
} }
if (obj.type.is_nullable()) { if (obj.type.is_nullable()) {
LoadNullValue(tmp1.gp(), pinned); LoadNullValue(tmp1.gp(), pinned);
__ emit_cond_jump(kEqual, no_match, obj.type, obj_reg.gp(), tmp1.gp()); __ emit_cond_jump(kEqual, null_succeeds ? &match : no_match, obj.type,
obj_reg.gp(), tmp1.gp());
} }
// At this point, the object is neither null nor an i31ref. Perform // At this point, the object is neither null nor an i31ref. Perform
...@@ -4463,7 +4470,8 @@ class LiftoffCompiler { ...@@ -4463,7 +4470,8 @@ class LiftoffCompiler {
LiftoffRegList pinned; LiftoffRegList pinned;
LiftoffRegister result = pinned.set(__ GetUnusedRegister(kGpReg, {})); LiftoffRegister result = pinned.set(__ GetUnusedRegister(kGpReg, {}));
SubtypeCheck(decoder, obj, rtt, &return_false, pinned, result.gp()); SubtypeCheck(decoder, obj, rtt, &return_false, kNullFails, pinned,
result.gp());
__ LoadConstant(result, WasmValue(1)); __ LoadConstant(result, WasmValue(1));
// TODO(jkummerow): Emit near jumps on platforms where it's more efficient. // TODO(jkummerow): Emit near jumps on platforms where it's more efficient.
...@@ -4479,9 +4487,10 @@ class LiftoffCompiler { ...@@ -4479,9 +4487,10 @@ class LiftoffCompiler {
Value* result) { Value* result) {
Label* trap_label = AddOutOfLineTrap(decoder->position(), Label* trap_label = AddOutOfLineTrap(decoder->position(),
WasmCode::kThrowWasmTrapIllegalCast); WasmCode::kThrowWasmTrapIllegalCast);
LiftoffRegister obj_reg = SubtypeCheck(decoder, obj, rtt, trap_label); LiftoffRegister obj_reg =
__ PushRegister(ValueType::Ref(rtt.type.ref_index(), kNonNullable), SubtypeCheck(decoder, obj, rtt, trap_label, kNullSucceeds);
obj_reg); __ PushRegister(
ValueType::Ref(rtt.type.ref_index(), obj.type.nullability()), obj_reg);
} }
void BrOnCast(FullDecoder* decoder, const Value& obj, const Value& rtt, void BrOnCast(FullDecoder* decoder, const Value& obj, const Value& rtt,
...@@ -4494,11 +4503,13 @@ class LiftoffCompiler { ...@@ -4494,11 +4503,13 @@ class LiftoffCompiler {
} }
Label cont_false; Label cont_false;
LiftoffRegister obj_reg = SubtypeCheck(decoder, obj, rtt, &cont_false); LiftoffRegister obj_reg =
SubtypeCheck(decoder, obj, rtt, &cont_false, kNullFails);
__ PushRegister(rtt.type.is_bottom() __ PushRegister(
rtt.type.is_bottom()
? kWasmBottom ? kWasmBottom
: ValueType::Ref(rtt.type.ref_index(), kNonNullable), : ValueType::Ref(rtt.type.ref_index(), obj.type.nullability()),
obj_reg); obj_reg);
BrOrRet(decoder, depth); BrOrRet(decoder, depth);
......
...@@ -4090,7 +4090,6 @@ class WasmFullDecoder : public WasmDecoder<validate> { ...@@ -4090,7 +4090,6 @@ class WasmFullDecoder : public WasmDecoder<validate> {
return 0; return 0;
} }
Value obj = Pop(0, kWasmAnyRef); Value obj = Pop(0, kWasmAnyRef);
Value* value = Push(ValueType::Ref(imm.index, kNonNullable));
if (obj.type != kWasmBottom) { if (obj.type != kWasmBottom) {
if (!VALIDATE(IsSubtypeOf(ValueType::Ref(imm.index, kNonNullable), if (!VALIDATE(IsSubtypeOf(ValueType::Ref(imm.index, kNonNullable),
obj.type, this->module_))) { obj.type, this->module_))) {
...@@ -4098,6 +4097,8 @@ class WasmFullDecoder : public WasmDecoder<validate> { ...@@ -4098,6 +4097,8 @@ class WasmFullDecoder : public WasmDecoder<validate> {
"supertype of type " + std::to_string(imm.index)); "supertype of type " + std::to_string(imm.index));
return 0; return 0;
} }
Value* value =
Push(ValueType::Ref(imm.index, obj.type.nullability()));
CALL_INTERFACE_IF_REACHABLE(RefCast, obj, rtt, value); CALL_INTERFACE_IF_REACHABLE(RefCast, obj, rtt, value);
} }
return opcode_length + imm.length; return opcode_length + imm.length;
......
...@@ -262,6 +262,10 @@ class ValueType { ...@@ -262,6 +262,10 @@ class ValueType {
CONSTEXPR_DCHECK(has_index()); CONSTEXPR_DCHECK(has_index());
return HeapTypeField::decode(bit_field_); return HeapTypeField::decode(bit_field_);
} }
constexpr Nullability nullability() const {
CONSTEXPR_DCHECK(is_object_reference_type());
return kind() == kOptRef ? kNullable : kNonNullable;
}
// Useful when serializing this type to store it into a runtime object. // Useful when serializing this type to store it into a runtime object.
constexpr uint32_t raw_bit_field() const { return bit_field_; } constexpr uint32_t raw_bit_field() const { return bit_field_; }
......
...@@ -923,7 +923,7 @@ WASM_COMPILED_EXEC_TEST(FunctionRefs) { ...@@ -923,7 +923,7 @@ WASM_COMPILED_EXEC_TEST(FunctionRefs) {
tester.AddGlobal(ValueType::Ref(sig_index, kNullable), false, tester.AddGlobal(ValueType::Ref(sig_index, kNullable), false,
WasmInitExpr::RefFuncConst(func_index)); WasmInitExpr::RefFuncConst(func_index));
ValueType func_type = ValueType::Ref(sig_index, kNonNullable); ValueType func_type = ValueType::Ref(sig_index, kNullable);
FunctionSig sig_func(1, 0, &func_type); FunctionSig sig_func(1, 0, &func_type);
ValueType rtt0 = ValueType::Rtt(sig_index, 0); ValueType rtt0 = ValueType::Rtt(sig_index, 0);
...@@ -1003,14 +1003,13 @@ WASM_COMPILED_EXEC_TEST(RefTestCastNull) { ...@@ -1003,14 +1003,13 @@ WASM_COMPILED_EXEC_TEST(RefTestCastNull) {
kExprEnd}); kExprEnd});
const byte kRefCastNull = tester.DefineFunction( const byte kRefCastNull = tester.DefineFunction(
tester.sigs.i_i(), // Argument and return value ignored tester.sigs.i_v(), {},
{}, {WASM_REF_IS_NULL(WASM_REF_CAST(type_index, WASM_REF_NULL(type_index),
{WASM_REF_CAST(type_index, WASM_REF_NULL(type_index), WASM_RTT_CANON(type_index))),
WASM_RTT_CANON(type_index)), kExprEnd});
kExprDrop, WASM_I32V(0), kExprEnd});
tester.CompileModule(); tester.CompileModule();
tester.CheckResult(kRefTestNull, 0); tester.CheckResult(kRefTestNull, 0);
tester.CheckHasThrown(kRefCastNull, 0); tester.CheckResult(kRefCastNull, 1);
} }
WASM_COMPILED_EXEC_TEST(BasicI31) { WASM_COMPILED_EXEC_TEST(BasicI31) {
......
...@@ -4243,7 +4243,7 @@ TEST_F(FunctionBodyDecoderTest, RefTestCast) { ...@@ -4243,7 +4243,7 @@ TEST_F(FunctionBodyDecoderTest, RefTestCast) {
ValueType test_reps[] = {kWasmI32, ValueType::Ref(from_heap, kNullable)}; ValueType test_reps[] = {kWasmI32, ValueType::Ref(from_heap, kNullable)};
FunctionSig test_sig(1, 1, test_reps); FunctionSig test_sig(1, 1, test_reps);
ValueType cast_reps[] = {ValueType::Ref(to_heap, kNonNullable), ValueType cast_reps[] = {ValueType::Ref(to_heap, kNonNullable),
ValueType::Ref(from_heap, kNullable)}; ValueType::Ref(from_heap, kNonNullable)};
FunctionSig cast_sig(1, 1, cast_reps); FunctionSig cast_sig(1, 1, cast_reps);
ExpectValidates(&test_sig, ExpectValidates(&test_sig,
{WASM_REF_TEST(WASM_HEAP_TYPE(to_heap), WASM_LOCAL_GET(0), {WASM_REF_TEST(WASM_HEAP_TYPE(to_heap), WASM_LOCAL_GET(0),
...@@ -4263,7 +4263,7 @@ TEST_F(FunctionBodyDecoderTest, RefTestCast) { ...@@ -4263,7 +4263,7 @@ TEST_F(FunctionBodyDecoderTest, RefTestCast) {
HeapType to_heap = HeapType(pair.second); HeapType to_heap = HeapType(pair.second);
ValueType test_reps[] = {kWasmI32, ValueType::Ref(from_heap, kNullable)}; ValueType test_reps[] = {kWasmI32, ValueType::Ref(from_heap, kNullable)};
FunctionSig test_sig(1, 1, test_reps); FunctionSig test_sig(1, 1, test_reps);
ValueType cast_reps[] = {ValueType::Ref(to_heap, kNonNullable), ValueType cast_reps[] = {ValueType::Ref(to_heap, kNullable),
ValueType::Ref(from_heap, kNullable)}; ValueType::Ref(from_heap, kNullable)};
FunctionSig cast_sig(1, 1, cast_reps); FunctionSig cast_sig(1, 1, cast_reps);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment