From d5c40736856263e764e276325e5d6b5524df8b71 Mon Sep 17 00:00:00 2001 From: sayantn Date: Wed, 19 Jun 2024 20:07:00 +0530 Subject: [PATCH] Added runtime detection Cannot do a `cupid` test because they don't support `amx`. --- crates/std_detect/src/detect/arch/x86.rs | 15 ++++++++ crates/std_detect/src/detect/os/x86.rs | 11 ++++++ crates/std_detect/tests/cpu-detection.rs | 7 +++- crates/std_detect/tests/x86-specific.rs | 48 ++++++++++++++---------- 4 files changed, 61 insertions(+), 20 deletions(-) diff --git a/crates/std_detect/src/detect/arch/x86.rs b/crates/std_detect/src/detect/arch/x86.rs index 8867c59b11..ef7091ec7b 100644 --- a/crates/std_detect/src/detect/arch/x86.rs +++ b/crates/std_detect/src/detect/arch/x86.rs @@ -84,6 +84,11 @@ features! { /// * `"avxneconvert"` /// * `"avxvnniint8"` /// * `"avxvnniint16"` + /// * `"amx-tile"` + /// * `"amx-int8"` + /// * `"amx-bf16"` + /// * `"amx-fp16"` + /// * `"amx-complex"` /// * `"f16c"` /// * `"fma"` /// * `"bmi1"` @@ -196,6 +201,16 @@ features! { /// AVX-VNNI_INT8 (VNNI with 16-bit Integers) @FEATURE: #[unstable(feature = "avx512_target_feature", issue = "44839")] avxvnniint8: "avxvnniint8"; /// AVX-VNNI_INT16 (VNNI with 8-bit integers) + @FEATURE: #[unstable(feature = "x86_amx_intrinsics", issue = "126622")] amx_tile: "amx-tile"; + /// AMX (Advanced Matrix Extensions) - Tile load/store + @FEATURE: #[unstable(feature = "x86_amx_intrinsics", issue = "126622")] amx_int8: "amx-int8"; + /// AMX-INT8 (Operations on 8-bit integers) + @FEATURE: #[unstable(feature = "x86_amx_intrinsics", issue = "126622")] amx_bf16: "amx-bf16"; + /// AMX-BF16 (BFloat16 Operations) + @FEATURE: #[unstable(feature = "x86_amx_intrinsics", issue = "126622")] amx_fp16: "amx-fp16"; + /// AMX-FP16 (Float16 Operations) + @FEATURE: #[unstable(feature = "x86_amx_intrinsics", issue = "126622")] amx_complex: "amx-complex"; + /// AMX-COMPLEX (Complex number Operations) @FEATURE: #[stable(feature = "simd_x86", since = "1.27.0")] f16c: "f16c"; /// F16C (Conversions between IEEE-754 `binary16` and `binary32` formats) @FEATURE: #[stable(feature = "simd_x86", since = "1.27.0")] fma: "fma"; diff --git a/crates/std_detect/src/detect/os/x86.rs b/crates/std_detect/src/detect/os/x86.rs index c0d86d4b3a..4a0e2c9b18 100644 --- a/crates/std_detect/src/detect/os/x86.rs +++ b/crates/std_detect/src/detect/os/x86.rs @@ -164,6 +164,7 @@ pub(crate) fn detect_features() -> cache::Initializer { // * SSE -> `XCR0.SSE[1]` // * AVX -> `XCR0.AVX[2]` // * AVX-512 -> `XCR0.AVX-512[7:5]`. + // * AMX -> `XCR0.AMX[18:17]` // // by setting the corresponding bits of `XCR0` to `1`. // @@ -174,6 +175,8 @@ pub(crate) fn detect_features() -> cache::Initializer { let os_avx_support = xcr0 & 6 == 6; // Test `XCR0.AVX-512[7:5]` with the mask `0b1110_0000 == 0xe0`: let os_avx512_support = xcr0 & 0xe0 == 0xe0; + // Test `XCR0.AMX[18:17]` with the mask `0b110_0000_0000_0000_0000 == 0x60000` + let os_amx_support = xcr0 & 0x60000 == 0x60000; // Only if the OS and the CPU support saving/restoring the AVX // registers we enable `xsave` support: @@ -240,6 +243,14 @@ pub(crate) fn detect_features() -> cache::Initializer { enable(extended_features_edx, 8, Feature::avx512vp2intersect); enable(extended_features_edx, 23, Feature::avx512fp16); enable(extended_features_eax_leaf_1, 5, Feature::avx512bf16); + + if os_amx_support { + enable(extended_features_edx, 24, Feature::amx_tile); + enable(extended_features_edx, 25, Feature::amx_int8); + enable(extended_features_edx, 22, Feature::amx_bf16); + enable(extended_features_eax_leaf_1, 21, Feature::amx_fp16); + enable(extended_features_edx_leaf_1, 8, Feature::amx_complex); + } } } } diff --git a/crates/std_detect/tests/cpu-detection.rs b/crates/std_detect/tests/cpu-detection.rs index 6cf74a6721..47bb499dee 100644 --- a/crates/std_detect/tests/cpu-detection.rs +++ b/crates/std_detect/tests/cpu-detection.rs @@ -5,7 +5,7 @@ #![cfg_attr(target_arch = "powerpc64", feature(stdarch_powerpc_feature_detection))] #![cfg_attr( any(target_arch = "x86", target_arch = "x86_64"), - feature(sha512_sm_x86) + feature(sha512_sm_x86, x86_amx_intrinsics) )] #![allow(clippy::unwrap_used, clippy::use_debug, clippy::print_stdout)] @@ -259,6 +259,11 @@ fn x86_all() { println!("xsaveopt: {:?}", is_x86_feature_detected!("xsaveopt")); println!("xsaves: {:?}", is_x86_feature_detected!("xsaves")); println!("xsavec: {:?}", is_x86_feature_detected!("xsavec")); + println!("amx-bf16: {:?}", is_x86_feature_detected!("amx-bf16")); + println!("amx-tile: {:?}", is_x86_feature_detected!("amx-tile")); + println!("amx-int8: {:?}", is_x86_feature_detected!("amx-int8")); + println!("amx-fp16: {:?}", is_x86_feature_detected!("amx-fp16")); + println!("amx-complex: {:?}", is_x86_feature_detected!("amx-complex")); } #[test] diff --git a/crates/std_detect/tests/x86-specific.rs b/crates/std_detect/tests/x86-specific.rs index 74326f4a5a..611e41c941 100644 --- a/crates/std_detect/tests/x86-specific.rs +++ b/crates/std_detect/tests/x86-specific.rs @@ -1,6 +1,11 @@ #![cfg(any(target_arch = "x86", target_arch = "x86_64"))] #![allow(internal_features)] -#![feature(stdarch_internal, avx512_target_feature, sha512_sm_x86)] +#![feature( + stdarch_internal, + avx512_target_feature, + sha512_sm_x86, + x86_amx_intrinsics +)] extern crate cupid; #[macro_use] @@ -27,34 +32,34 @@ fn dump() { println!("sha512: {:?}", is_x86_feature_detected!("sha512")); println!("sm3: {:?}", is_x86_feature_detected!("sm3")); println!("sm4: {:?}", is_x86_feature_detected!("sm4")); - println!("avx512f {:?}", is_x86_feature_detected!("avx512f")); - println!("avx512cd {:?}", is_x86_feature_detected!("avx512cd")); - println!("avx512er {:?}", is_x86_feature_detected!("avx512er")); - println!("avx512pf {:?}", is_x86_feature_detected!("avx512pf")); - println!("avx512bw {:?}", is_x86_feature_detected!("avx512bw")); - println!("avx512dq {:?}", is_x86_feature_detected!("avx512dq")); - println!("avx512vl {:?}", is_x86_feature_detected!("avx512vl")); - println!("avx512_ifma {:?}", is_x86_feature_detected!("avx512ifma")); + println!("avx512f: {:?}", is_x86_feature_detected!("avx512f")); + println!("avx512cd: {:?}", is_x86_feature_detected!("avx512cd")); + println!("avx512er: {:?}", is_x86_feature_detected!("avx512er")); + println!("avx512pf: {:?}", is_x86_feature_detected!("avx512pf")); + println!("avx512bw: {:?}", is_x86_feature_detected!("avx512bw")); + println!("avx512dq: {:?}", is_x86_feature_detected!("avx512dq")); + println!("avx512vl: {:?}", is_x86_feature_detected!("avx512vl")); + println!("avx512_ifma: {:?}", is_x86_feature_detected!("avx512ifma")); println!("avx512vbmi {:?}", is_x86_feature_detected!("avx512vbmi")); println!( - "avx512_vpopcntdq {:?}", + "avx512_vpopcntdq: {:?}", is_x86_feature_detected!("avx512vpopcntdq") ); - println!("avx512vbmi2 {:?}", is_x86_feature_detected!("avx512vbmi2")); - println!("gfni {:?}", is_x86_feature_detected!("gfni")); - println!("vaes {:?}", is_x86_feature_detected!("vaes")); - println!("vpclmulqdq {:?}", is_x86_feature_detected!("vpclmulqdq")); - println!("avx512vnni {:?}", is_x86_feature_detected!("avx512vnni")); + println!("avx512vbmi2: {:?}", is_x86_feature_detected!("avx512vbmi2")); + println!("gfni: {:?}", is_x86_feature_detected!("gfni")); + println!("vaes: {:?}", is_x86_feature_detected!("vaes")); + println!("vpclmulqdq: {:?}", is_x86_feature_detected!("vpclmulqdq")); + println!("avx512vnni: {:?}", is_x86_feature_detected!("avx512vnni")); println!( - "avx512bitalg {:?}", + "avx512bitalg: {:?}", is_x86_feature_detected!("avx512bitalg") ); - println!("avx512bf16 {:?}", is_x86_feature_detected!("avx512bf16")); + println!("avx512bf16: {:?}", is_x86_feature_detected!("avx512bf16")); println!( - "avx512vp2intersect {:?}", + "avx512vp2intersect: {:?}", is_x86_feature_detected!("avx512vp2intersect") ); - println!("avx512fp16 {:?}", is_x86_feature_detected!("avx512fp16")); + println!("avx512fp16: {:?}", is_x86_feature_detected!("avx512fp16")); println!("fma: {:?}", is_x86_feature_detected!("fma")); println!("abm: {:?}", is_x86_feature_detected!("abm")); println!("bmi: {:?}", is_x86_feature_detected!("bmi1")); @@ -82,6 +87,11 @@ fn dump() { "avxvnniint16: {:?}", is_x86_feature_detected!("avxvnniint16") ); + println!("amx-bf16: {:?}", is_x86_feature_detected!("amx-bf16")); + println!("amx-tile: {:?}", is_x86_feature_detected!("amx-tile")); + println!("amx-int8: {:?}", is_x86_feature_detected!("amx-int8")); + println!("amx-fp16: {:?}", is_x86_feature_detected!("amx-fp16")); + println!("amx-complex: {:?}", is_x86_feature_detected!("amx-complex")); } #[cfg(feature = "std_detect_env_override")]