[x86] Generate AVX512 fixed-point instructions (#7129)
* clean-up abs and saturating_pmulhrs, fix AVX512 saturating_ ops

* add test coverage for AVX512 fp ops

* generate vpabs on AVX512

* faster AVX2 lowering of saturating_pmulhrs
rootjalex committed Oct 31, 2022
1 parent bad945f commit 5da5dfd
Showing 5 changed files with 135 additions and 87 deletions.
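As a rough illustration of what this commit enables (not part of the change itself), a Halide pipeline of the following shape should now vectorize to vpmulhrsw on zmm registers when compiled for an AVX512 target. This is a minimal sketch; the names (a, b, f, x), the vector width, and the target string are hypothetical:

#include "Halide.h"
using namespace Halide;
using namespace Halide::ConciseCasts;

int main() {
    ImageParam a(Int(16), 1), b(Int(16), 1);
    Var x("x");
    Func f("f");
    // Rounding multiply-shift-right by 15 with saturation: the
    // saturating_pmulhrs pattern matched in CodeGen_X86::visit below.
    f(x) = i16_sat((i32(a(x)) * i32(b(x)) + 16384) / 32768);
    f.vectorize(x, 32);
    f.compile_to_assembly("f.s", {a, b}, Target("x86-64-linux-avx512_skylake"));
    return 0;
}
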
54 changes: 51 additions & 3 deletions src/CodeGen_X86.cpp
@@ -35,6 +35,9 @@ Target complete_x86_target(Target t) {
if (t.has_feature(Target::AVX512_Cannonlake) ||
t.has_feature(Target::AVX512_Skylake) ||
t.has_feature(Target::AVX512_KNL)) {
t.set_feature(Target::AVX512);
}
if (t.has_feature(Target::AVX512)) {
t.set_feature(Target::AVX2);
}
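// Feature implication chain: any AVX512 variant implies baseline AVX512,
// which in turn implies AVX2, so the AVX2 lowering rules and wrappers below
// remain available on AVX512 targets.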
if (t.has_feature(Target::AVX2)) {
@@ -111,6 +114,12 @@ struct x86Intrinsic {

// clang-format off
const x86Intrinsic intrinsic_defs[] = {
// AVX2/SSSE3 LLVM intrinsics for pabs fail in JIT. The integer wrappers
// just call `llvm.abs`, which takes a second i1 argument saying whether an
// INT_MIN input is poison.
// AVX512BW's pabs instructions aren't directly exposed by LLVM.
{"abs_i8x64", UInt(8, 64), "abs", {Int(8, 64)}, Target::AVX512_Skylake},
{"abs_i16x32", UInt(16, 32), "abs", {Int(16, 32)}, Target::AVX512_Skylake},
{"abs_i32x16", UInt(32, 16), "abs", {Int(32, 16)}, Target::AVX512_Skylake},
{"abs_i8x32", UInt(8, 32), "abs", {Int(8, 32)}, Target::AVX2},
{"abs_i16x16", UInt(16, 16), "abs", {Int(16, 16)}, Target::AVX2},
{"abs_i32x8", UInt(32, 8), "abs", {Int(32, 8)}, Target::AVX2},
@@ -125,15 +134,19 @@ const x86Intrinsic intrinsic_defs[] = {
{"round_f32x8", Float(32, 8), "round", {Float(32, 8)}, Target::AVX},
{"round_f64x4", Float(64, 4), "round", {Float(64, 4)}, Target::AVX},

{"llvm.sadd.sat.v64i8", Int(8, 64), "saturating_add", {Int(8, 64), Int(8, 64)}, Target::AVX512_Skylake},
{"llvm.sadd.sat.v32i8", Int(8, 32), "saturating_add", {Int(8, 32), Int(8, 32)}, Target::AVX2},
{"llvm.sadd.sat.v16i8", Int(8, 16), "saturating_add", {Int(8, 16), Int(8, 16)}},
{"llvm.sadd.sat.v8i8", Int(8, 8), "saturating_add", {Int(8, 8), Int(8, 8)}},
{"llvm.ssub.sat.v64i8", Int(8, 64), "saturating_sub", {Int(8, 64), Int(8, 64)}, Target::AVX512_Skylake},
{"llvm.ssub.sat.v32i8", Int(8, 32), "saturating_sub", {Int(8, 32), Int(8, 32)}, Target::AVX2},
{"llvm.ssub.sat.v16i8", Int(8, 16), "saturating_sub", {Int(8, 16), Int(8, 16)}},
{"llvm.ssub.sat.v8i8", Int(8, 8), "saturating_sub", {Int(8, 8), Int(8, 8)}},

{"llvm.sadd.sat.v32i16", Int(16, 32), "saturating_add", {Int(16, 32), Int(16, 32)}, Target::AVX512_Skylake},
{"llvm.sadd.sat.v16i16", Int(16, 16), "saturating_add", {Int(16, 16), Int(16, 16)}, Target::AVX2},
{"llvm.sadd.sat.v8i16", Int(16, 8), "saturating_add", {Int(16, 8), Int(16, 8)}},
{"llvm.ssub.sat.v32i16", Int(16, 32), "saturating_sub", {Int(16, 32), Int(16, 32)}, Target::AVX512_Skylake},
{"llvm.ssub.sat.v16i16", Int(16, 16), "saturating_sub", {Int(16, 16), Int(16, 16)}, Target::AVX2},
{"llvm.ssub.sat.v8i16", Int(16, 8), "saturating_sub", {Int(16, 8), Int(16, 8)}},

@@ -149,13 +162,17 @@ const x86Intrinsic intrinsic_defs[] = {
// Target::AVX instead of Target::AVX2 as the feature flag
// requirement.
// TODO: Just use llvm.*add/*sub.sat, and verify the above comment?
{"llvm.uadd.sat.v64i8", UInt(8, 64), "saturating_add", {UInt(8, 64), UInt(8, 64)}, Target::AVX512_Skylake},
{"paddusbx32", UInt(8, 32), "saturating_add", {UInt(8, 32), UInt(8, 32)}, Target::AVX},
{"paddusbx16", UInt(8, 16), "saturating_add", {UInt(8, 16), UInt(8, 16)}},
{"llvm.usub.sat.v64i8", UInt(8, 64), "saturating_sub", {UInt(8, 64), UInt(8, 64)}, Target::AVX512_Skylake},
{"psubusbx32", UInt(8, 32), "saturating_sub", {UInt(8, 32), UInt(8, 32)}, Target::AVX},
{"psubusbx16", UInt(8, 16), "saturating_sub", {UInt(8, 16), UInt(8, 16)}},

{"llvm.uadd.sat.v32i16", UInt(16, 32), "saturating_add", {UInt(16, 32), UInt(16, 32)}, Target::AVX512_Skylake},
{"padduswx16", UInt(16, 16), "saturating_add", {UInt(16, 16), UInt(16, 16)}, Target::AVX},
{"padduswx8", UInt(16, 8), "saturating_add", {UInt(16, 8), UInt(16, 8)}},
{"llvm.usub.sat.v32i16", UInt(16, 32), "saturating_sub", {UInt(16, 32), UInt(16, 32)}, Target::AVX512_Skylake},
{"psubuswx16", UInt(16, 16), "saturating_sub", {UInt(16, 16), UInt(16, 16)}, Target::AVX},
{"psubuswx8", UInt(16, 8), "saturating_sub", {UInt(16, 8), UInt(16, 8)}},

@@ -180,14 +197,15 @@ const x86Intrinsic intrinsic_defs[] = {
{"wmul_pmaddwd_sse2", Int(32, 4), "widening_mul", {Int(16, 4), Int(16, 4)}},

// Multiply keep high half
{"llvm.x86.avx512.pmulh.w.512", Int(16, 32), "pmulh", {Int(16, 32), Int(16, 32)}, Target::AVX512_Skylake},
{"llvm.x86.avx2.pmulh.w", Int(16, 16), "pmulh", {Int(16, 16), Int(16, 16)}, Target::AVX2},
{"llvm.x86.avx512.pmulhu.w.512", UInt(16, 32), "pmulh", {UInt(16, 32), UInt(16, 32)}, Target::AVX512_Skylake},
{"llvm.x86.avx2.pmulhu.w", UInt(16, 16), "pmulh", {UInt(16, 16), UInt(16, 16)}, Target::AVX2},
{"llvm.x86.avx512.pmul.hr.sw.512", Int(16, 32), "pmulhrs", {Int(16, 32), Int(16, 32)}, Target::AVX512_Skylake},
{"llvm.x86.avx2.pmul.hr.sw", Int(16, 16), "pmulhrs", {Int(16, 16), Int(16, 16)}, Target::AVX2},
{"saturating_pmulhrswx16", Int(16, 16), "saturating_pmulhrs", {Int(16, 16), Int(16, 16)}, Target::AVX2},
{"llvm.x86.sse2.pmulh.w", Int(16, 8), "pmulh", {Int(16, 8), Int(16, 8)}},
{"llvm.x86.sse2.pmulhu.w", UInt(16, 8), "pmulh", {UInt(16, 8), UInt(16, 8)}},
{"llvm.x86.ssse3.pmul.hr.sw.128", Int(16, 8), "pmulhrs", {Int(16, 8), Int(16, 8)}, Target::SSE41},
{"saturating_pmulhrswx8", Int(16, 8), "saturating_pmulhrs", {Int(16, 8), Int(16, 8)}, Target::SSE41},

// Convert FP32 to BF16
{"vcvtne2ps2bf16x32", BFloat(16, 32), "f32_to_bf16", {Float(32, 32)}, Target::AVX512_SapphireRapids},
@@ -582,7 +600,6 @@ void CodeGen_X86::visit(const Call *op) {
static Pattern patterns[] = {
{"pmulh", mul_shift_right(wild_i16x_, wild_i16x_, 16)},
{"pmulh", mul_shift_right(wild_u16x_, wild_u16x_, 16)},
{"saturating_pmulhrs", rounding_mul_shift_right(wild_i16x_, wild_i16x_, 15)},
{"saturating_narrow", i16_sat(wild_i32x_)},
{"saturating_narrow", u16_sat(wild_i32x_)},
{"saturating_narrow", i8_sat(wild_i16x_)},
@@ -600,6 +617,37 @@ void CodeGen_X86::visit(const Call *op) {
}
}

// Check for saturating_pmulhrs. On x86, pmulhrs truncates (wraps) rather than
// saturates, but it's still faster to use pmulhrs plus a fix-up for the single
// overflowing input pair than to lower the pattern into widening multiplies.
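// The only input pair that can overflow is a == b == -32768: the exact result
// (-32768 * -32768 + (1 << 14)) >> 15 == 32768 does not fit in int16, so
// pmulhrs wraps it to -32768, whereas saturating_pmulhrs must return 32767.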
static Expr saturating_pmulhrs = rounding_mul_shift_right(wild_i16x_, wild_i16x_, 15);
if (expr_match(saturating_pmulhrs, op, matches)) {
// Rewrite so that we can take advantage of pmulhrs.
internal_assert(matches.size() == 2);
internal_assert(op->type.element_of() == Int(16));
const Expr &a = matches[0];
const Expr &b = matches[1];

Expr pmulhrs = i16(rounding_shift_right(widening_mul(a, b), 15));

Expr i16_min = op->type.min();
Expr i16_max = op->type.max();

// Handle edge case of possible overflow.
// See https://github.com/halide/Halide/pull/7129/files#r1008331426
// On AVX512 (and with enough lanes) we can use a mask register.
if (target.has_feature(Target::AVX512) && op->type.lanes() >= 32) {
Expr expr = select((a == i16_min) && (b == i16_min), i16_max, pmulhrs);
expr.accept(this);
} else {
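// Without AVX512 masks, exploit the wrap-around: in the sole overflow case
// pmulhrs yields exactly i16_min, and i16_min ^ -1 == i16_max, so XOR-ing
// with an all-ones mask (chosen only when both inputs are i16_min) applies
// the saturating fix-up; otherwise the mask is zero and the XOR is a no-op.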
Expr mask = select(max(a, b) == i16_min, cast(op->type, -1), cast(op->type, 0));
Expr expr = mask ^ pmulhrs;
expr.accept(this);
}

return;
}

CodeGen_Posix::visit(op);
}

31 changes: 9 additions & 22 deletions src/runtime/x86_avx2.ll
@@ -32,35 +32,22 @@ define weak_odr <16 x i16> @packusdwx16(<16 x i32> %arg) nounwind alwaysinline
declare <16 x i16> @llvm.x86.avx2.packusdw(<8 x i32>, <8 x i32>) nounwind readnone

define weak_odr <32 x i8> @abs_i8x32(<32 x i8> %arg) {
%1 = sub <32 x i8> zeroinitializer, %arg
%2 = icmp sgt <32 x i8> %arg, zeroinitializer
%3 = select <32 x i1> %2, <32 x i8> %arg, <32 x i8> %1
ret <32 x i8> %3
%1 = tail call <32 x i8> @llvm.abs.v32i8(<32 x i8> %arg, i1 false)
ret <32 x i8> %1
}
declare <32 x i8> @llvm.abs.v32i8(<32 x i8>, i1) nounwind readnone

define weak_odr <16 x i16> @abs_i16x16(<16 x i16> %arg) {
%1 = sub <16 x i16> zeroinitializer, %arg
%2 = icmp sgt <16 x i16> %arg, zeroinitializer
%3 = select <16 x i1> %2, <16 x i16> %arg, <16 x i16> %1
ret <16 x i16> %3
%1 = tail call <16 x i16> @llvm.abs.v16i16(<16 x i16> %arg, i1 false)
ret <16 x i16> %1
}
declare <16 x i16> @llvm.abs.v16i16(<16 x i16>, i1) nounwind readnone

define weak_odr <8 x i32> @abs_i32x8(<8 x i32> %arg) {
%1 = sub <8 x i32> zeroinitializer, %arg
%2 = icmp sgt <8 x i32> %arg, zeroinitializer
%3 = select <8 x i1> %2, <8 x i32> %arg, <8 x i32> %1
ret <8 x i32> %3
}

define weak_odr <16 x i16> @saturating_pmulhrswx16(<16 x i16> %a, <16 x i16> %b) nounwind uwtable readnone alwaysinline {
%1 = tail call <16 x i16> @llvm.x86.avx2.pmul.hr.sw(<16 x i16> %a, <16 x i16> %b)
%2 = icmp eq <16 x i16> %a, <i16 -32768, i16 -32768, i16 -32768, i16 -32768, i16 -32768, i16 -32768, i16 -32768, i16 -32768, i16 -32768, i16 -32768, i16 -32768, i16 -32768, i16 -32768, i16 -32768, i16 -32768, i16 -32768>
%3 = icmp eq <16 x i16> %b, <i16 -32768, i16 -32768, i16 -32768, i16 -32768, i16 -32768, i16 -32768, i16 -32768, i16 -32768, i16 -32768, i16 -32768, i16 -32768, i16 -32768, i16 -32768, i16 -32768, i16 -32768, i16 -32768>
%4 = and <16 x i1> %2, %3
%5 = select <16 x i1> %4, <16 x i16> <i16 32767, i16 32767, i16 32767, i16 32767, i16 32767, i16 32767, i16 32767, i16 32767, i16 32767, i16 32767, i16 32767, i16 32767, i16 32767, i16 32767, i16 32767, i16 32767>, <16 x i16> %1
ret <16 x i16> %5
%1 = tail call <8 x i32> @llvm.abs.v8i32(<8 x i32> %arg, i1 false)
ret <8 x i32> %1
}
declare <16 x i16> @llvm.x86.avx2.pmul.hr.sw(<16 x i16>, <16 x i16>) nounwind readnone
declare <8 x i32> @llvm.abs.v8i32(<8 x i32>, i1) nounwind readnone

define weak_odr <16 x i16> @hadd_pmadd_u8_avx2(<32 x i8> %a) nounwind alwaysinline {
%1 = tail call <16 x i16> @llvm.x86.avx2.pmadd.ub.sw(<32 x i8> %a, <32 x i8> <i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1>)
18 changes: 18 additions & 0 deletions src/runtime/x86_avx512.ll
@@ -138,3 +138,21 @@ define weak_odr <4 x i32> @dpwssdsx4(<4 x i32> %init, <8 x i16> %a, <8 x i16> %
ret <4 x i32> %3
}
declare <4 x i32> @llvm.x86.avx512.vpdpwssds.128(<4 x i32>, <4 x i32>, <4 x i32>)

define weak_odr <64 x i8> @abs_i8x64(<64 x i8> %arg) {
%1 = tail call <64 x i8> @llvm.abs.v64i8(<64 x i8> %arg, i1 false)
ret <64 x i8> %1
}
declare <64 x i8> @llvm.abs.v64i8(<64 x i8>, i1) nounwind readnone

define weak_odr <32 x i16> @abs_i16x32(<32 x i16> %arg) {
%1 = tail call <32 x i16> @llvm.abs.v32i16(<32 x i16> %arg, i1 false)
ret <32 x i16> %1
}
declare <32 x i16> @llvm.abs.v32i16(<32 x i16>, i1) nounwind readnone

define weak_odr <16 x i32> @abs_i32x16(<16 x i32> %arg) {
%1 = tail call <16 x i32> @llvm.abs.v16i32(<16 x i32> %arg, i1 false)
ret <16 x i32> %1
}
declare <16 x i32> @llvm.abs.v16i32(<16 x i32>, i1) nounwind readnone
31 changes: 9 additions & 22 deletions src/runtime/x86_sse41.ll
@@ -52,35 +52,22 @@ define weak_odr <2 x double> @trunc_f64x2(<2 x double> %x) nounwind uwtable read
}

define weak_odr <16 x i8> @abs_i8x16(<16 x i8> %x) nounwind uwtable readnone alwaysinline {
%1 = sub <16 x i8> zeroinitializer, %x
%2 = icmp sgt <16 x i8> %x, zeroinitializer
%3 = select <16 x i1> %2, <16 x i8> %x, <16 x i8> %1
ret <16 x i8> %3
%1 = tail call <16 x i8> @llvm.abs.v16i8(<16 x i8> %x, i1 false)
ret <16 x i8> %1
}
declare <16 x i8> @llvm.abs.v16i8(<16 x i8>, i1) nounwind readnone

define weak_odr <8 x i16> @abs_i16x8(<8 x i16> %x) nounwind uwtable readnone alwaysinline {
%1 = sub <8 x i16> zeroinitializer, %x
%2 = icmp sgt <8 x i16> %x, zeroinitializer
%3 = select <8 x i1> %2, <8 x i16> %x, <8 x i16> %1
ret <8 x i16> %3
%1 = tail call <8 x i16> @llvm.abs.v8i16(<8 x i16> %x, i1 false)
ret <8 x i16> %1
}
declare <8 x i16> @llvm.abs.v8i16(<8 x i16>, i1) nounwind readnone

define weak_odr <4 x i32> @abs_i32x4(<4 x i32> %x) nounwind uwtable readnone alwaysinline {
%1 = sub <4 x i32> zeroinitializer, %x
%2 = icmp sgt <4 x i32> %x, zeroinitializer
%3 = select <4 x i1> %2, <4 x i32> %x, <4 x i32> %1
ret <4 x i32> %3
}

define weak_odr <8 x i16> @saturating_pmulhrswx8(<8 x i16> %a, <8 x i16> %b) nounwind uwtable readnone alwaysinline {
%1 = tail call <8 x i16> @llvm.x86.ssse3.pmul.hr.sw.128(<8 x i16> %a, <8 x i16> %b)
%2 = icmp eq <8 x i16> %a, <i16 -32768, i16 -32768, i16 -32768, i16 -32768, i16 -32768, i16 -32768, i16 -32768, i16 -32768>
%3 = icmp eq <8 x i16> %b, <i16 -32768, i16 -32768, i16 -32768, i16 -32768, i16 -32768, i16 -32768, i16 -32768, i16 -32768>
%4 = and <8 x i1> %2, %3
%5 = select <8 x i1> %4, <8 x i16> <i16 32767, i16 32767, i16 32767, i16 32767, i16 32767, i16 32767, i16 32767, i16 32767>, <8 x i16> %1
ret <8 x i16> %5
%1 = tail call <4 x i32> @llvm.abs.v4i32(<4 x i32> %x, i1 false)
ret <4 x i32> %1
}
declare <8 x i16> @llvm.x86.ssse3.pmul.hr.sw.128(<8 x i16>, <8 x i16>) nounwind readnone
declare <4 x i32> @llvm.abs.v4i32(<4 x i32>, i1) nounwind readnone

define weak_odr <8 x i16> @hadd_pmadd_u8_sse3(<16 x i8> %a) nounwind alwaysinline {
%1 = tail call <8 x i16> @llvm.x86.ssse3.pmadd.ub.sw.128(<16 x i8> %a, <16 x i8> <i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1>)
88 changes: 48 additions & 40 deletions test/correctness/simd_op_check.cpp
@@ -448,51 +448,59 @@ class SimdOpCheck : public SimdOpCheckTest {
// AVX 2

if (use_avx2) {
check("vpaddb*ymm", 32, u8_1 + u8_2);
check("vpsubb*ymm", 32, u8_1 - u8_2);
check("vpaddsb*ymm", 32, i8_sat(i16(i8_1) + i16(i8_2)));
check("vpsubsb*ymm", 32, i8_sat(i16(i8_1) - i16(i8_2)));
check("vpaddusb*ymm", 32, u8(min(u16(u8_1) + u16(u8_2), max_u8)));
check("vpsubusb*ymm", 32, u8(max(i16(u8_1) - i16(u8_2), 0)));
check("vpaddw*ymm", 16, u16_1 + u16_2);
check("vpsubw*ymm", 16, u16_1 - u16_2);
check("vpaddsw*ymm", 16, i16_sat(i32(i16_1) + i32(i16_2)));
check("vpsubsw*ymm", 16, i16_sat(i32(i16_1) - i32(i16_2)));
check("vpaddusw*ymm", 16, u16(min(u32(u16_1) + u32(u16_2), max_u16)));
check("vpsubusw*ymm", 16, u16(max(i32(u16_1) - i32(u16_2), 0)));
check("vpaddd*ymm", 8, i32_1 + i32_2);
check("vpsubd*ymm", 8, i32_1 - i32_2);
check("vpmulhw*ymm", 16, i16((i32(i16_1) * i32(i16_2)) / (256 * 256)));
check("vpmulhw*ymm", 16, i16((i32(i16_1) * i32(i16_2)) >> cast<unsigned>(16)));
check("vpmulhw*ymm", 16, i16((i32(i16_1) * i32(i16_2)) >> cast<int>(16)));
check("vpmulhw*ymm", 16, i16((i32(i16_1) * i32(i16_2)) << cast<int>(-16)));
check("vpmullw*ymm", 16, i16_1 * i16_2);

check("vpmulhrsw*ymm", 16, i16((((i32(i16_1) * i32(i16_2)) + 16384)) / 32768));
check("vpmulhrsw*ymm", 16, i16_sat((((i32(i16_1) * i32(i16_2)) + 16384)) / 32768));

check("vpcmp*b*ymm", 32, select(u8_1 == u8_2, u8(1), u8(2)));
check("vpcmp*b*ymm", 32, select(u8_1 > u8_2, u8(1), u8(2)));
check("vpcmp*w*ymm", 16, select(u16_1 == u16_2, u16(1), u16(2)));
check("vpcmp*w*ymm", 16, select(u16_1 > u16_2, u16(1), u16(2)));
check("vpcmp*d*ymm", 8, select(u32_1 == u32_2, u32(1), u32(2)));
check("vpcmp*d*ymm", 8, select(u32_1 > u32_2, u32(1), u32(2)));

check("vpavgb*ymm", 32, u8((u16(u8_1) + u16(u8_2) + 1) / 2));
check("vpavgw*ymm", 16, u16((u32(u16_1) + u32(u16_2) + 1) / 2));
check("vpmaxsw*ymm", 16, max(i16_1, i16_2));
check("vpminsw*ymm", 16, min(i16_1, i16_2));
check("vpmaxub*ymm", 32, max(u8_1, u8_2));
check("vpminub*ymm", 32, min(u8_1, u8_2));
auto check_x86_fixed_point = [&](const std::string &suffix, const int m) {
check("vpaddb*" + suffix, 32 * m, u8_1 + u8_2);
check("vpsubb*" + suffix, 32 * m, u8_1 - u8_2);
check("vpaddsb*" + suffix, 32 * m, i8_sat(i16(i8_1) + i16(i8_2)));
check("vpsubsb*" + suffix, 32 * m, i8_sat(i16(i8_1) - i16(i8_2)));
check("vpaddusb*" + suffix, 32 * m, u8(min(u16(u8_1) + u16(u8_2), max_u8)));
check("vpsubusb*" + suffix, 32 * m, u8(max(i16(u8_1) - i16(u8_2), 0)));
check("vpaddw*" + suffix, 16 * m, u16_1 + u16_2);
check("vpsubw*" + suffix, 16 * m, u16_1 - u16_2);
check("vpaddsw*" + suffix, 16 * m, i16_sat(i32(i16_1) + i32(i16_2)));
check("vpsubsw*" + suffix, 16 * m, i16_sat(i32(i16_1) - i32(i16_2)));
check("vpaddusw*" + suffix, 16 * m, u16(min(u32(u16_1) + u32(u16_2), max_u16)));
check("vpsubusw*" + suffix, 16 * m, u16(max(i32(u16_1) - i32(u16_2), 0)));
check("vpaddd*" + suffix, 8 * m, i32_1 + i32_2);
check("vpsubd*" + suffix, 8 * m, i32_1 - i32_2);
check("vpmulhw*" + suffix, 16 * m, i16((i32(i16_1) * i32(i16_2)) / (256 * 256)));
check("vpmulhw*" + suffix, 16 * m, i16((i32(i16_1) * i32(i16_2)) >> cast<unsigned>(16)));
check("vpmulhw*" + suffix, 16 * m, i16((i32(i16_1) * i32(i16_2)) >> cast<int>(16)));
check("vpmulhw*" + suffix, 16 * m, i16((i32(i16_1) * i32(i16_2)) << cast<int>(-16)));
check("vpmullw*" + suffix, 16 * m, i16_1 * i16_2);

check("vpmulhrsw*" + suffix, 16 * m, i16((((i32(i16_1) * i32(i16_2)) + 16384)) / 32768));
check("vpmulhrsw*" + suffix, 16 * m, i16_sat((((i32(i16_1) * i32(i16_2)) + 16384)) / 32768));

check("vpcmp*b*" + suffix, 32 * m, select(u8_1 == u8_2, u8(1), u8(2)));
check("vpcmp*b*" + suffix, 32 * m, select(u8_1 > u8_2, u8(1), u8(2)));
check("vpcmp*w*" + suffix, 16 * m, select(u16_1 == u16_2, u16(1), u16(2)));
check("vpcmp*w*" + suffix, 16 * m, select(u16_1 > u16_2, u16(1), u16(2)));
check("vpcmp*d*" + suffix, 8 * m, select(u32_1 == u32_2, u32(1), u32(2)));
check("vpcmp*d*" + suffix, 8 * m, select(u32_1 > u32_2, u32(1), u32(2)));

check("vpavgb*" + suffix, 32 * m, u8((u16(u8_1) + u16(u8_2) + 1) / 2));
check("vpavgw*" + suffix, 16 * m, u16((u32(u16_1) + u32(u16_2) + 1) / 2));
check("vpmaxsw*" + suffix, 16 * m, max(i16_1, i16_2));
check("vpminsw*" + suffix, 16 * m, min(i16_1, i16_2));
check("vpmaxub*" + suffix, 32 * m, max(u8_1, u8_2));
check("vpminub*" + suffix, 32 * m, min(u8_1, u8_2));

check("vpabsb*" + suffix, 32 * m, abs(i8_1));
check("vpabsw*" + suffix, 16 * m, abs(i16_1));
check("vpabsd*" + suffix, 8 * m, abs(i32_1));
};
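// m scales the expected lane counts: m == 1 exercises the 256-bit ymm forms,
// m == 2 exercises the 512-bit zmm forms added by this change.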

check_x86_fixed_point("ymm", 1);

if (use_avx512) {
check_x86_fixed_point("zmm", 2);
}

check(use_avx512 ? "vpaddq*zmm" : "vpaddq*ymm", 8, i64_1 + i64_2);
check(use_avx512 ? "vpsubq*zmm" : "vpsubq*ymm", 8, i64_1 - i64_2);
check(use_avx512 ? "vpmullq" : "vpmuludq*ymm", 8, u64_1 * u64_2);

check("vpabsb*ymm", 32, abs(i8_1));
check("vpabsw*ymm", 16, abs(i16_1));
check("vpabsd*ymm", 8, abs(i32_1));

// llvm doesn't distinguish between signed and unsigned multiplies
// check("vpmuldq", 8, i64(i32_1) * i64(i32_2));
if (!use_avx512) {