Commit c03354b4 authored by Ng Zhi An's avatar Ng Zhi An Committed by V8 LUCI CQ

Reland "[wasm-simd][arm64] Fuse add and extmul"

This is a reland of 65515ddd

Fix is to use AddWithWraparound for signed additions to avoid UB.

Original change's description:
> [wasm-simd][arm64] Fuse add and extmul
>
> We can select a better instruction for add+extmul, using one of the
> multiply-long-accumulate instruction.
>
> Define a helper struct to pattern match Add(x, OP(y, z)) and
> Add(OP(x, y) z), and ensure that the matched OP is always on the
> LHS, to simplify checking for matches.
>
> Bug: v8:11548
> Change-Id: I7ab488b262aa9f749785f973549ccd9fad72f4c8
> Reviewed-on: https://chromium-review.googlesource.com/c/v8/v8/+/2826725
> Reviewed-by: Deepti Gandluri <gdeepti@chromium.org>
> Commit-Queue: Zhi An Ng <zhin@chromium.org>
> Cr-Commit-Position: refs/heads/main@{#76708}

Bug: v8:11548
Change-Id: I675ab8b78d9c6c30b82a8c96c8e7098a548c6a60
Reviewed-on: https://chromium-review.googlesource.com/c/v8/v8/+/3144379
Commit-Queue: Zhi An Ng <zhin@chromium.org>
Reviewed-by: 's avatarDeepti Gandluri <gdeepti@chromium.org>
Cr-Commit-Position: refs/heads/main@{#76712}
parent cea787e2
...@@ -1201,6 +1201,24 @@ CodeGenerator::CodeGenResult CodeGenerator::AssembleArchInstruction( ...@@ -1201,6 +1201,24 @@ CodeGenerator::CodeGenResult CodeGenerator::AssembleArchInstruction(
i.InputSimd128Register(0).Format(src_f), 0); i.InputSimd128Register(0).Format(src_f), 0);
break; break;
} }
case kArm64Smlal: {
VectorFormat dst_f = VectorFormatFillQ(LaneSizeField::decode(opcode));
VectorFormat src_f = VectorFormatHalfWidth(dst_f);
DCHECK_EQ(i.OutputSimd128Register(), i.InputSimd128Register(0));
__ Smlal(i.OutputSimd128Register().Format(dst_f),
i.InputSimd128Register(1).Format(src_f),
i.InputSimd128Register(2).Format(src_f));
break;
}
case kArm64Smlal2: {
VectorFormat dst_f = VectorFormatFillQ(LaneSizeField::decode(opcode));
VectorFormat src_f = VectorFormatHalfWidthDoubleLanes(dst_f);
DCHECK_EQ(i.OutputSimd128Register(), i.InputSimd128Register(0));
__ Smlal2(i.OutputSimd128Register().Format(dst_f),
i.InputSimd128Register(1).Format(src_f),
i.InputSimd128Register(2).Format(src_f));
break;
}
case kArm64Smull: { case kArm64Smull: {
if (instr->InputAt(0)->IsRegister()) { if (instr->InputAt(0)->IsRegister()) {
__ Smull(i.OutputRegister(), i.InputRegister32(0), __ Smull(i.OutputRegister(), i.InputRegister32(0),
...@@ -1223,6 +1241,24 @@ CodeGenerator::CodeGenResult CodeGenerator::AssembleArchInstruction( ...@@ -1223,6 +1241,24 @@ CodeGenerator::CodeGenResult CodeGenerator::AssembleArchInstruction(
i.InputSimd128Register(1).Format(src_f)); i.InputSimd128Register(1).Format(src_f));
break; break;
} }
case kArm64Umlal: {
VectorFormat dst_f = VectorFormatFillQ(LaneSizeField::decode(opcode));
VectorFormat src_f = VectorFormatHalfWidth(dst_f);
DCHECK_EQ(i.OutputSimd128Register(), i.InputSimd128Register(0));
__ Umlal(i.OutputSimd128Register().Format(dst_f),
i.InputSimd128Register(1).Format(src_f),
i.InputSimd128Register(2).Format(src_f));
break;
}
case kArm64Umlal2: {
VectorFormat dst_f = VectorFormatFillQ(LaneSizeField::decode(opcode));
VectorFormat src_f = VectorFormatHalfWidthDoubleLanes(dst_f);
DCHECK_EQ(i.OutputSimd128Register(), i.InputSimd128Register(0));
__ Umlal2(i.OutputSimd128Register().Format(dst_f),
i.InputSimd128Register(1).Format(src_f),
i.InputSimd128Register(2).Format(src_f));
break;
}
case kArm64Umull: { case kArm64Umull: {
if (instr->InputAt(0)->IsRegister()) { if (instr->InputAt(0)->IsRegister()) {
__ Umull(i.OutputRegister(), i.InputRegister32(0), __ Umull(i.OutputRegister(), i.InputRegister32(0),
......
...@@ -43,10 +43,14 @@ namespace compiler { ...@@ -43,10 +43,14 @@ namespace compiler {
V(Arm64Sub32) \ V(Arm64Sub32) \
V(Arm64Mul) \ V(Arm64Mul) \
V(Arm64Mul32) \ V(Arm64Mul32) \
V(Arm64Smlal) \
V(Arm64Smlal2) \
V(Arm64Smull) \ V(Arm64Smull) \
V(Arm64Smull2) \ V(Arm64Smull2) \
V(Arm64Uadalp) \ V(Arm64Uadalp) \
V(Arm64Uaddlp) \ V(Arm64Uaddlp) \
V(Arm64Umlal) \
V(Arm64Umlal2) \
V(Arm64Umull) \ V(Arm64Umull) \
V(Arm64Umull2) \ V(Arm64Umull2) \
V(Arm64Madd) \ V(Arm64Madd) \
......
...@@ -44,10 +44,14 @@ int InstructionScheduler::GetTargetInstructionFlags( ...@@ -44,10 +44,14 @@ int InstructionScheduler::GetTargetInstructionFlags(
case kArm64Sub32: case kArm64Sub32:
case kArm64Mul: case kArm64Mul:
case kArm64Mul32: case kArm64Mul32:
case kArm64Smlal:
case kArm64Smlal2:
case kArm64Smull: case kArm64Smull:
case kArm64Smull2: case kArm64Smull2:
case kArm64Uadalp: case kArm64Uadalp:
case kArm64Uaddlp: case kArm64Uaddlp:
case kArm64Umlal:
case kArm64Umlal2:
case kArm64Umull: case kArm64Umull:
case kArm64Umull2: case kArm64Umull2:
case kArm64Madd: case kArm64Madd:
......
...@@ -3779,28 +3779,53 @@ void InstructionSelector::VisitI64x2Mul(Node* node) { ...@@ -3779,28 +3779,53 @@ void InstructionSelector::VisitI64x2Mul(Node* node) {
namespace { namespace {
// Used for pattern matching SIMD Add operations where one of the inputs matches
// |opcode| and ensure that the matched input is on the LHS (input 0).
struct SimdAddOpMatcher : public NodeMatcher {
explicit SimdAddOpMatcher(Node* node, IrOpcode::Value opcode)
: NodeMatcher(node),
opcode_(opcode),
left_(InputAt(0)),
right_(InputAt(1)) {
DCHECK(HasProperty(Operator::kCommutative));
PutOpOnLeft();
}
bool Matches() { return left_->opcode() == opcode_; }
Node* left() const { return left_; }
Node* right() const { return right_; }
private:
void PutOpOnLeft() {
if (right_->opcode() == opcode_) {
std::swap(left_, right_);
node()->ReplaceInput(0, left_);
node()->ReplaceInput(1, right_);
}
}
IrOpcode::Value opcode_;
Node* left_;
Node* right_;
};
bool ShraHelper(InstructionSelector* selector, Node* node, int lane_size, bool ShraHelper(InstructionSelector* selector, Node* node, int lane_size,
InstructionCode shra_code, InstructionCode add_code, InstructionCode shra_code, InstructionCode add_code,
IrOpcode::Value shift_op) { IrOpcode::Value shift_op) {
Arm64OperandGenerator g(selector); Arm64OperandGenerator g(selector);
Node* left = node->InputAt(0); SimdAddOpMatcher m(node, shift_op);
Node* right = node->InputAt(1); if (!m.Matches() || !selector->CanCover(node, m.left())) return false;
if (right->opcode() == shift_op) { if (!g.IsIntegerConstant(m.left()->InputAt(1))) return false;
std::swap(left, right);
} else if (left->opcode() != shift_op) {
return false;
}
if (!selector->CanCover(node, left) || !g.IsIntegerConstant(left->InputAt(1)))
return false;
// If shifting by zero, just do the addition // If shifting by zero, just do the addition
if (g.GetIntegerConstantValue(left->InputAt(1)) % lane_size == 0) { if (g.GetIntegerConstantValue(m.left()->InputAt(1)) % lane_size == 0) {
selector->Emit(add_code, g.DefineAsRegister(node), selector->Emit(add_code, g.DefineAsRegister(node),
g.UseRegister(left->InputAt(0)), g.UseRegister(right)); g.UseRegister(m.left()->InputAt(0)),
g.UseRegister(m.right()));
} else { } else {
selector->Emit(shra_code | LaneSizeField::encode(lane_size), selector->Emit(shra_code | LaneSizeField::encode(lane_size),
g.DefineSameAsFirst(node), g.UseRegister(right), g.DefineSameAsFirst(node), g.UseRegister(m.right()),
g.UseRegister(left->InputAt(0)), g.UseRegister(m.left()->InputAt(0)),
g.UseImmediate(left->InputAt(1))); g.UseImmediate(m.left()->InputAt(1)));
} }
return true; return true;
} }
...@@ -3808,39 +3833,36 @@ bool ShraHelper(InstructionSelector* selector, Node* node, int lane_size, ...@@ -3808,39 +3833,36 @@ bool ShraHelper(InstructionSelector* selector, Node* node, int lane_size,
bool AdalpHelper(InstructionSelector* selector, Node* node, int lane_size, bool AdalpHelper(InstructionSelector* selector, Node* node, int lane_size,
InstructionCode adalp_code, IrOpcode::Value ext_op) { InstructionCode adalp_code, IrOpcode::Value ext_op) {
Arm64OperandGenerator g(selector); Arm64OperandGenerator g(selector);
Node* left = node->InputAt(0); SimdAddOpMatcher m(node, ext_op);
Node* right = node->InputAt(1); if (!m.Matches() || !selector->CanCover(node, m.left())) return false;
if (right->opcode() == ext_op) { selector->Emit(adalp_code | LaneSizeField::encode(lane_size),
std::swap(left, right); g.DefineSameAsFirst(node), g.UseRegister(m.right()),
} else if (left->opcode() != ext_op) { g.UseRegister(m.left()->InputAt(0)));
return false; return true;
}
if (selector->CanCover(node, left)) {
selector->Emit(adalp_code | LaneSizeField::encode(lane_size),
g.DefineSameAsFirst(node), g.UseRegister(right),
g.UseRegister(left->InputAt(0)));
return true;
}
return false;
} }
bool MlaHelper(InstructionSelector* selector, Node* node, bool MlaHelper(InstructionSelector* selector, Node* node,
InstructionCode mla_code, IrOpcode::Value mul_op) { InstructionCode mla_code, IrOpcode::Value mul_op) {
Arm64OperandGenerator g(selector); Arm64OperandGenerator g(selector);
Node* left = node->InputAt(0); SimdAddOpMatcher m(node, mul_op);
Node* right = node->InputAt(1); if (!m.Matches() || !selector->CanCover(node, m.left())) return false;
if (right->opcode() == mul_op) { selector->Emit(mla_code, g.DefineSameAsFirst(node), g.UseRegister(m.right()),
std::swap(left, right); g.UseRegister(m.left()->InputAt(0)),
} else if (left->opcode() != mul_op) { g.UseRegister(m.left()->InputAt(1)));
return false; return true;
} }
if (selector->CanCover(node, left)) {
selector->Emit(mla_code, g.DefineSameAsFirst(node), g.UseRegister(right), bool SmlalHelper(InstructionSelector* selector, Node* node, int lane_size,
g.UseRegister(left->InputAt(0)), InstructionCode smlal_code, IrOpcode::Value ext_mul_op) {
g.UseRegister(left->InputAt(1))); Arm64OperandGenerator g(selector);
return true; SimdAddOpMatcher m(node, ext_mul_op);
} if (!m.Matches() || !selector->CanCover(node, m.left())) return false;
return false;
selector->Emit(smlal_code | LaneSizeField::encode(lane_size),
g.DefineSameAsFirst(node), g.UseRegister(m.right()),
g.UseRegister(m.left()->InputAt(0)),
g.UseRegister(m.left()->InputAt(1)));
return true;
} }
} // namespace } // namespace
...@@ -3890,6 +3912,18 @@ void InstructionSelector::VisitI8x16Add(Node* node) { ...@@ -3890,6 +3912,18 @@ void InstructionSelector::VisitI8x16Add(Node* node) {
IrOpcode::k##Type##ShrU)) { \ IrOpcode::k##Type##ShrU)) { \
return; \ return; \
} \ } \
/* Select Smlal/Umlal(x, y, z) for Add(x, ExtMulLow(y, z)) and \
* Smlal2/Umlal2(x, y, z) for Add(x, ExtMulHigh(y, z)). */ \
if (SmlalHelper(this, node, LaneSize, kArm64Smlal, \
IrOpcode::k##Type##ExtMulLow##PairwiseType##S) || \
SmlalHelper(this, node, LaneSize, kArm64Smlal2, \
IrOpcode::k##Type##ExtMulHigh##PairwiseType##S) || \
SmlalHelper(this, node, LaneSize, kArm64Umlal, \
IrOpcode::k##Type##ExtMulLow##PairwiseType##U) || \
SmlalHelper(this, node, LaneSize, kArm64Umlal2, \
IrOpcode::k##Type##ExtMulHigh##PairwiseType##U)) { \
return; \
} \
VisitRRR(this, kArm64IAdd | LaneSizeField::encode(LaneSize), node); \ VisitRRR(this, kArm64IAdd | LaneSizeField::encode(LaneSize), node); \
} }
......
...@@ -1685,6 +1685,77 @@ WASM_SIMD_TEST(I64x2ExtMulHighI32x4U) { ...@@ -1685,6 +1685,77 @@ WASM_SIMD_TEST(I64x2ExtMulHighI32x4U) {
MulHalf::kHigh); MulHalf::kHigh);
} }
namespace {
// Test add(mul(x, y, z) optimizations.
template <typename S, typename T>
void RunExtMulAddOptimizationTest(TestExecutionTier execution_tier,
WasmOpcode ext_mul, WasmOpcode narrow_splat,
WasmOpcode wide_splat, WasmOpcode wide_add,
std::function<T(T, T)> addop) {
WasmRunner<int32_t, S, T> r(execution_tier);
T* g = r.builder().template AddGlobal<T>(kWasmS128);
// global[0] =
// add(
// splat(local[1]),
// extmul(splat(local[0]), splat(local[0])))
BUILD(r,
WASM_GLOBAL_SET(
0, WASM_SIMD_BINOP(
wide_add, WASM_SIMD_UNOP(wide_splat, WASM_LOCAL_GET(1)),
WASM_SIMD_BINOP(
ext_mul, WASM_SIMD_UNOP(narrow_splat, WASM_LOCAL_GET(0)),
WASM_SIMD_UNOP(narrow_splat, WASM_LOCAL_GET(0))))),
WASM_ONE);
constexpr int lanes = kSimd128Size / sizeof(T);
for (S x : compiler::ValueHelper::GetVector<S>()) {
for (T y : compiler::ValueHelper::GetVector<T>()) {
r.Call(x, y);
T expected = addop(MultiplyLong<T, S>(x, x), y);
for (int i = 0; i < lanes; i++) {
CHECK_EQ(expected, ReadLittleEndianValue<T>(&g[i]));
}
}
}
}
} // namespace
// Helper which defines high/low, signed/unsigned test cases for extmul + add
// optimization.
#define EXTMUL_ADD_OPTIMIZATION_TEST(NarrowType, NarrowShape, WideType, \
WideShape) \
WASM_SIMD_TEST(WideShape##ExtMulLow##NarrowShape##SAddOptimization) { \
RunExtMulAddOptimizationTest<NarrowType, WideType>( \
execution_tier, kExpr##WideShape##ExtMulLow##NarrowShape##S, \
kExpr##NarrowShape##Splat, kExpr##WideShape##Splat, \
kExpr##WideShape##Add, base::AddWithWraparound<WideType>); \
} \
WASM_SIMD_TEST(WideShape##ExtMulHigh##NarrowShape##SAddOptimization) { \
RunExtMulAddOptimizationTest<NarrowType, WideType>( \
execution_tier, kExpr##WideShape##ExtMulHigh##NarrowShape##S, \
kExpr##NarrowShape##Splat, kExpr##WideShape##Splat, \
kExpr##WideShape##Add, base::AddWithWraparound<WideType>); \
} \
WASM_SIMD_TEST(WideShape##ExtMulLow##NarrowShape##UAddOptimization) { \
RunExtMulAddOptimizationTest<u##NarrowType, u##WideType>( \
execution_tier, kExpr##WideShape##ExtMulLow##NarrowShape##U, \
kExpr##NarrowShape##Splat, kExpr##WideShape##Splat, \
kExpr##WideShape##Add, std::plus<u##WideType>()); \
} \
WASM_SIMD_TEST(WideShape##ExtMulHigh##NarrowShape##UAddOptimization) { \
RunExtMulAddOptimizationTest<u##NarrowType, u##WideType>( \
execution_tier, kExpr##WideShape##ExtMulHigh##NarrowShape##U, \
kExpr##NarrowShape##Splat, kExpr##WideShape##Splat, \
kExpr##WideShape##Add, std::plus<u##WideType>()); \
}
EXTMUL_ADD_OPTIMIZATION_TEST(int8_t, I8x16, int16_t, I16x8)
EXTMUL_ADD_OPTIMIZATION_TEST(int16_t, I16x8, int32_t, I32x4)
#undef EXTMUL_ADD_OPTIMIZATION_TEST
WASM_SIMD_TEST(I32x4DotI16x8S) { WASM_SIMD_TEST(I32x4DotI16x8S) {
WasmRunner<int32_t, int16_t, int16_t> r(execution_tier); WasmRunner<int32_t, int16_t, int16_t> r(execution_tier);
int32_t* g = r.builder().template AddGlobal<int32_t>(kWasmS128); int32_t* g = r.builder().template AddGlobal<int32_t>(kWasmS128);
......
...@@ -2328,6 +2328,83 @@ INSTANTIATE_TEST_SUITE_P(InstructionSelectorTest, ...@@ -2328,6 +2328,83 @@ INSTANTIATE_TEST_SUITE_P(InstructionSelectorTest,
InstructionSelectorSIMDShrAddTest, InstructionSelectorSIMDShrAddTest,
::testing::ValuesIn(kSIMDShrAddInstructions)); ::testing::ValuesIn(kSIMDShrAddInstructions));
namespace {
struct SIMDAddExtMulInst {
const char* mul_constructor_name;
const Operator* (MachineOperatorBuilder::*mul_operator)();
const Operator* (MachineOperatorBuilder::*add_operator)();
ArchOpcode multiply_add_arch_opcode;
MachineType machine_type;
int lane_size;
};
} // namespace
static const SIMDAddExtMulInst kSimdAddExtMulInstructions[] = {
{"I16x8ExtMulLowI8x16S", &MachineOperatorBuilder::I16x8ExtMulLowI8x16S,
&MachineOperatorBuilder::I16x8Add, kArm64Smlal, MachineType::Simd128(),
16},
{"I16x8ExtMulHighI8x16S", &MachineOperatorBuilder::I16x8ExtMulHighI8x16S,
&MachineOperatorBuilder::I16x8Add, kArm64Smlal2, MachineType::Simd128(),
16},
{"I16x8ExtMulLowI8x16U", &MachineOperatorBuilder::I16x8ExtMulLowI8x16U,
&MachineOperatorBuilder::I16x8Add, kArm64Umlal, MachineType::Simd128(),
16},
{"I16x8ExtMulHighI8x16U", &MachineOperatorBuilder::I16x8ExtMulHighI8x16U,
&MachineOperatorBuilder::I16x8Add, kArm64Umlal2, MachineType::Simd128(),
16},
{"I32x4ExtMulLowI16x8S", &MachineOperatorBuilder::I32x4ExtMulLowI16x8S,
&MachineOperatorBuilder::I32x4Add, kArm64Smlal, MachineType::Simd128(),
32},
{"I32x4ExtMulHighI16x8S", &MachineOperatorBuilder::I32x4ExtMulHighI16x8S,
&MachineOperatorBuilder::I32x4Add, kArm64Smlal2, MachineType::Simd128(),
32},
{"I32x4ExtMulLowI16x8U", &MachineOperatorBuilder::I32x4ExtMulLowI16x8U,
&MachineOperatorBuilder::I32x4Add, kArm64Umlal, MachineType::Simd128(),
32},
{"I32x4ExtMulHighI16x8U", &MachineOperatorBuilder::I32x4ExtMulHighI16x8U,
&MachineOperatorBuilder::I32x4Add, kArm64Umlal2, MachineType::Simd128(),
32}};
using InstructionSelectorSIMDAddExtMulTest =
InstructionSelectorTestWithParam<SIMDAddExtMulInst>;
// TODO(zhin): This can be merged with InstructionSelectorSIMDDPWithSIMDMulTest
// once sub+extmul matching is implemented.
TEST_P(InstructionSelectorSIMDAddExtMulTest, AddExtMul) {
const SIMDAddExtMulInst mdpi = GetParam();
const MachineType type = mdpi.machine_type;
{
// Test Add(x, ExtMul(y, z)).
StreamBuilder m(this, type, type, type, type);
Node* n = m.AddNode((m.machine()->*mdpi.mul_operator)(), m.Parameter(1),
m.Parameter(2));
m.Return(m.AddNode((m.machine()->*mdpi.add_operator)(), m.Parameter(0), n));
Stream s = m.Build();
ASSERT_EQ(1U, s.size());
EXPECT_EQ(mdpi.multiply_add_arch_opcode, s[0]->arch_opcode());
EXPECT_EQ(mdpi.lane_size, LaneSizeField::decode(s[0]->opcode()));
EXPECT_EQ(3U, s[0]->InputCount());
EXPECT_EQ(1U, s[0]->OutputCount());
}
{
// Test Add(ExtMul(y, z), x), making sure it's commutative.
StreamBuilder m(this, type, type, type, type);
Node* n = m.AddNode((m.machine()->*mdpi.mul_operator)(), m.Parameter(0),
m.Parameter(1));
m.Return(m.AddNode((m.machine()->*mdpi.add_operator)(), n, m.Parameter(2)));
Stream s = m.Build();
ASSERT_EQ(1U, s.size());
EXPECT_EQ(mdpi.multiply_add_arch_opcode, s[0]->arch_opcode());
EXPECT_EQ(mdpi.lane_size, LaneSizeField::decode(s[0]->opcode()));
EXPECT_EQ(3U, s[0]->InputCount());
EXPECT_EQ(1U, s[0]->OutputCount());
}
}
INSTANTIATE_TEST_SUITE_P(InstructionSelectorTest,
InstructionSelectorSIMDAddExtMulTest,
::testing::ValuesIn(kSimdAddExtMulInstructions));
struct SIMDMulDupInst { struct SIMDMulDupInst {
const uint8_t shuffle[16]; const uint8_t shuffle[16];
int32_t lane; int32_t lane;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment