Skip to content

Commit

Permalink
🖊️ code review
Browse files Browse the repository at this point in the history
  • Loading branch information
jvdd committed Feb 26, 2023
1 parent 84b6518 commit d53e09c
Show file tree
Hide file tree
Showing 3 changed files with 195 additions and 138 deletions.
117 changes: 31 additions & 86 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,13 @@ trait DTypeInfo {
const NB_BITS: usize;
}

/// Macro for implementing DTypeInfo for the passed data types (uints, ints, floats)
macro_rules! impl_nb_bits {
($($t:ty)*) => ($(
impl DTypeInfo for $t {
const NB_BITS: usize = std::mem::size_of::<$t>() * 8;
// $data_type is the data type (e.g. i32)
// you can pass multiple types (separated by whitespace, not commas) to this macro
($($data_type:ty)*) => ($(
impl DTypeInfo for $data_type {
const NB_BITS: usize = std::mem::size_of::<$data_type>() * 8;
}
)*)
}
Expand All @@ -65,107 +68,47 @@ impl_nb_bits!(f32 f64);
#[cfg(feature = "half")]
impl_nb_bits!(f16);

// use once_cell::sync::Lazy;

// #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
// static AVX512BW_DETECTED: Lazy<bool> = Lazy::new(|| is_x86_feature_detected!("avx512bw"));
// #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
// static AVX512F_DETECTED: Lazy<bool> = Lazy::new(|| is_x86_feature_detected!("avx512f"));
// #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
// static AVX2_DETECTED: Lazy<bool> = Lazy::new(|| is_x86_feature_detected!("avx2"));
// #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
// static AVX_DETECTED: Lazy<bool> = Lazy::new(|| is_x86_feature_detected!("avx"));
// #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
// static SSE_DETECTED: Lazy<bool> = Lazy::new(|| is_x86_feature_detected!("sse4.1"));
// #[cfg(target_arch = "arm")]
// static NEON_DETECTED: Lazy<bool> = Lazy::new(|| std::arch::is_arm_feature_detected!("neon"));

// macro_rules! impl_argminmax {
// ($($t:ty),*) => {
// $(
// impl ArgMinMax for ArrayView1<'_, $t> {
// fn argminmax(self) -> (usize, usize) {
// #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
// {
// if *AVX512BW_DETECTED & (<$t>::NB_BITS <= 16) {
// // BW (ByteWord) instructions are needed for 16-bit avx512
// return unsafe { AVX512::argminmax(self) }
// } else if *AVX512F_DETECTED { // TODO: check if avx512bw is included in avx512f
// return unsafe { AVX512::argminmax(self) }
// } else if *AVX2_DETECTED {
// return unsafe { AVX2::argminmax(self) }
// } else if *AVX_DETECTED & (<$t>::NB_BITS >= 32) & (<$t>::IS_FLOAT == true) {
// // f32 and f64 do not require avx2
// return unsafe { AVX2::argminmax(self) }
// // SKIP SSE4.2 bc scalar is faster or equivalent for 64 bit numbers
// // // } else if is_x86_feature_detected!("sse4.2") & (<$t>::NB_BITS == 64) & (<$t>::IS_FLOAT == false) {
// // // SSE4.2 is needed for comparing 64-bit integers
// // return unsafe { SSE::argminmax(self) }
// } else if *SSE_DETECTED & (<$t>::NB_BITS < 64) {
// // Scalar is faster for 64-bit numbers
// return unsafe { SSE::argminmax(self) }
// }
// }
// #[cfg(target_arch = "aarch64")]
// {
// if *NEON_DETECTED & (<$t>::NB_BITS < 64) {
// // We miss some NEON instructions for 64-bit numbers
// return unsafe { NEON::argminmax(self) }
// }
// }
// #[cfg(target_arch = "arm")]
// {
// if *NEON_DETECTED & (<$t>::NB_BITS < 64) {
// // TODO: requires v7?
// // We miss some NEON instructions for 64-bit numbers
// return unsafe { NEON::argminmax(self) }
// }
// }
// SCALAR::argminmax(self)
// }
// }
// )*
// };
// }

// ------------------------------ &[T] ------------------------------

/// Macro for implementing ArgMinMax for signed and unsigned integers
macro_rules! impl_argminmax_non_float {
($($t:ty),*) => {
// $int_type is the integer data type of the array (e.g. i32)
// you can pass multiple types (separated by commas) to this macro
($($int_type:ty),*) => {
$(
impl ArgMinMax for &[$t] {
impl ArgMinMax for &[$int_type] {
fn argminmax(&self) -> (usize, usize) {
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
{
if is_x86_feature_detected!("sse4.1") & (<$t>::NB_BITS == 8) {
if is_x86_feature_detected!("sse4.1") & (<$int_type>::NB_BITS == 8) {
// 8-bit numbers are best handled by SSE4.1
return unsafe { SSE::argminmax(self) }
} else if is_x86_feature_detected!("avx512bw") & (<$t>::NB_BITS <= 16) {
} else if is_x86_feature_detected!("avx512bw") & (<$int_type>::NB_BITS <= 16) {
// BW (Byte and Word) instructions are needed for 8- or 16-bit AVX-512 operations
return unsafe { AVX512::argminmax(self) }
} else if is_x86_feature_detected!("avx512f") { // TODO: check if avx512bw is included in avx512f
return unsafe { AVX512::argminmax(self) }
} else if is_x86_feature_detected!("avx2") {
return unsafe { AVX2::argminmax(self) }
// SKIP SSE4.2 bc scalar is faster or equivalent for 64 bit numbers
// // } else if is_x86_feature_detected!("sse4.2") & (<$t>::NB_BITS == 64) & (<$t>::IS_FLOAT == false) {
// // } else if is_x86_feature_detected!("sse4.2") & (<$int_type>::NB_BITS == 64) & (<$int_type>::IS_FLOAT == false) {
// // SSE4.2 is needed for comparing 64-bit integers
// return unsafe { SSE::argminmax(self) }
} else if is_x86_feature_detected!("sse4.1") & (<$t>::NB_BITS < 64) {
} else if is_x86_feature_detected!("sse4.1") & (<$int_type>::NB_BITS < 64) {
// Scalar is faster for 64-bit numbers
return unsafe { SSE::argminmax(self) }
}
}
#[cfg(target_arch = "aarch64")]
{
if std::arch::is_aarch64_feature_detected!("neon") & (<$t>::NB_BITS < 64) {
if std::arch::is_aarch64_feature_detected!("neon") & (<$int_type>::NB_BITS < 64) {
// We miss some NEON instructions for 64-bit numbers
return unsafe { NEON::argminmax(self) }
}
}
#[cfg(target_arch = "arm")]
{
if std::arch::is_arm_feature_detected!("neon") & (<$t>::NB_BITS < 64) {
if std::arch::is_arm_feature_detected!("neon") & (<$int_type>::NB_BITS < 64) {
// TODO: requires v7?
// We miss some NEON instructions for 64-bit numbers
return unsafe { NEON::argminmax(self) }
Expand All @@ -183,41 +126,43 @@ macro_rules! impl_argminmax_non_float {
};
}

// Macro for implementing ArgMinMax for floats (f32, f64)
/// Macro for implementing ArgMinMax for floats
macro_rules! impl_argminmax_float {
($($t:ty),*) => {
// $float_type is the float data type of the array (e.g. f32)
// you can pass multiple types (separated by commas) to this macro
($($float_type:ty),*) => {
$(
impl ArgMinMax for &[$t] {
impl ArgMinMax for &[$float_type] {
fn nanargminmax(&self) -> (usize, usize) {
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
{
if is_x86_feature_detected!("sse4.1") & (<$t>::NB_BITS == 8) {
if is_x86_feature_detected!("sse4.1") & (<$float_type>::NB_BITS == 8) {
// 8-bit numbers are best handled by SSE4.1
return unsafe { SSE::argminmax(self) }
} else if is_x86_feature_detected!("avx512bw") & (<$t>::NB_BITS <= 16) {
} else if is_x86_feature_detected!("avx512bw") & (<$float_type>::NB_BITS <= 16) {
// BW (Byte and Word) instructions are needed for 8- or 16-bit AVX-512 operations
return unsafe { AVX512::argminmax(self) }
} else if is_x86_feature_detected!("avx512f") { // TODO: check if avx512bw is included in avx512f
return unsafe { AVX512::argminmax(self) }
} else if is_x86_feature_detected!("avx2") {
return unsafe { AVX2::argminmax(self) }
// SKIP SSE4.2 bc scalar is faster or equivalent for 64 bit numbers
} else if is_x86_feature_detected!("sse4.1") & (<$t>::NB_BITS < 64) {
} else if is_x86_feature_detected!("sse4.1") & (<$float_type>::NB_BITS < 64) {
// Scalar is faster for 64-bit numbers
// TODO: double check this (observed different things for new float implementation)
return unsafe { SSE::argminmax(self) }
}
}
#[cfg(target_arch = "aarch64")]
{
if std::arch::is_aarch64_feature_detected!("neon") & (<$t>::NB_BITS < 64) {
if std::arch::is_aarch64_feature_detected!("neon") & (<$float_type>::NB_BITS < 64) {
// We miss some NEON instructions for 64-bit numbers
return unsafe { NEON::argminmax(self) }
}
}
#[cfg(target_arch = "arm")]
{
if std::arch::is_arm_feature_detected!("neon") & (<$t>::NB_BITS < 64) {
if std::arch::is_arm_feature_detected!("neon") & (<$float_type>::NB_BITS < 64) {
// TODO: requires v7?
// We miss some NEON instructions for 64-bit numbers
return unsafe { NEON::argminmax(self) }
Expand All @@ -228,29 +173,29 @@ macro_rules! impl_argminmax_float {
fn argminmax(&self) -> (usize, usize) {
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
{
if <$t>::NB_BITS <= 16 {
if <$float_type>::NB_BITS <= 16 {
// TODO: f16 IgnoreNaN is not yet SIMD-optimized
// do nothing (defaults to scalar)
} else if is_x86_feature_detected!("avx512f") {
return unsafe { AVX512IgnoreNaN::argminmax(self) }
} else if is_x86_feature_detected!("avx") {
// f32 and f64 do not require avx2
return unsafe { AVX2IgnoreNaN::argminmax(self) }
} else if is_x86_feature_detected!("sse4.1") & (<$t>::NB_BITS < 64) {
} else if is_x86_feature_detected!("sse4.1") & (<$float_type>::NB_BITS < 64) {
// Scalar is faster for 64-bit numbers
return unsafe { SSEIgnoreNaN::argminmax(self) }
}
}
#[cfg(target_arch = "aarch64")]
{
if std::arch::is_aarch64_feature_detected!("neon") & (<$t>::NB_BITS < 64) {
if std::arch::is_aarch64_feature_detected!("neon") & (<$float_type>::NB_BITS < 64) {
// We miss some NEON instructions for 64-bit numbers
return unsafe { NEONIgnoreNaN::argminmax(self) }
}
}
#[cfg(target_arch = "arm")]
{
if std::arch::is_arm_feature_detected!("neon") & (<$t>::NB_BITS < 64) {
if std::arch::is_arm_feature_detected!("neon") & (<$float_type>::NB_BITS < 64) {
// TODO: requires v7?
// We miss some NEON instructions for 64-bit numbers
return unsafe { NEONIgnoreNaN::argminmax(self) }
Expand Down
56 changes: 53 additions & 3 deletions src/simd/config.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
// https://github.com/rust-lang/portable-simd/blob/master/beginners-guide.md#target-features

/// SIMD instruction set trait - used to store the register size and get the lane size
/// for a given datatype
pub trait SIMDInstructionSet {
/// The size of the register in bits
const REGISTER_SIZE: usize;

// Set the const lanesize for each datatype
Expand All @@ -14,38 +17,85 @@ pub trait SIMDInstructionSet {
}
}

// ----------------------------- x86_64 / x86 -----------------------------
// ----------------------------------- x86_64 / x86 ------------------------------------

/// SSE instruction set - implemented for all:
/// - ints (see the `simd_i*.rs` files)
/// - uints (see the `simd_u*.rs` files)
/// - floats, returning NaNs (see the `simd_f*_return_nan.rs` files)
pub struct SSE;
/// SSE instruction set, NaN-ignoring variant - implemented for:
/// - floats, ignoring NaNs (see the `simd_f*_ignore_nan.rs` files)
pub struct SSEIgnoreNaN;

impl SIMDInstructionSet for SSE {
/// SSE register size is 128 bits.
///
/// See <https://en.wikipedia.org/wiki/Streaming_SIMD_Extensions#Registers>.
const REGISTER_SIZE: usize = 128;
}

pub struct AVX2; // for f32 and f64 AVX is enough
/// AVX2 instruction set - implemented for all:
/// - ints (see the `simd_i*.rs` files)
/// - uints (see the `simd_u*.rs` files)
/// - floats, returning NaNs (see the `simd_f*_return_nan.rs` files)
pub struct AVX2;

/// AVX(2) instruction set, NaN-ignoring variant - implemented for:
/// - floats, ignoring NaNs (see the `simd_f*_ignore_nan.rs` files)
///
/// Important remark: AVX is enough for f32 and f64!
/// For f16 we need AVX2, but this is not yet implemented (TODO).
///
/// Note: this struct does not implement the `SIMDInstructionSet` trait.
pub struct AVX2IgnoreNaN;

impl SIMDInstructionSet for AVX2 {
/// AVX(2) register size is 256 bits.
///
/// - AVX: <https://en.wikipedia.org/wiki/Advanced_Vector_Extensions#Advanced_Vector_Extensions>
/// - AVX2: <https://en.wikipedia.org/wiki/Advanced_Vector_Extensions#AVX2>
const REGISTER_SIZE: usize = 256;
}

/// AVX512 instruction set - implemented for all:
/// - ints (see the `simd_i*.rs` files)
/// - uints (see the `simd_u*.rs` files)
/// - floats, returning NaNs (see the `simd_f*_return_nan.rs` files)
pub struct AVX512;

/// AVX512 instruction set, NaN-ignoring variant - implemented for:
/// - floats, ignoring NaNs (see the `simd_f*_ignore_nan.rs` files)
///
/// Note: this struct does not implement the `SIMDInstructionSet` trait.
pub struct AVX512IgnoreNaN;

impl SIMDInstructionSet for AVX512 {
/// AVX512 register size is 512 bits.
///
/// See <https://en.wikipedia.org/wiki/Advanced_Vector_Extensions#AVX-512>.
const REGISTER_SIZE: usize = 512;
}

// ----------------------------- aarch64 / arm -----------------------------
// ----------------------------------- aarch64 / arm -----------------------------------

/// NEON instruction set - implemented for all:
/// - ints (see the `simd_i*.rs` files)
/// - uints (see the `simd_u*.rs` files)
/// - floats, returning NaNs (see the `simd_f*_return_nan.rs` files)
pub struct NEON;

/// NEON instruction set, NaN-ignoring variant - implemented for:
/// - floats, ignoring NaNs (see the `simd_f*_ignore_nan.rs` files)
///
/// Note: this struct does not implement the `SIMDInstructionSet` trait.
pub struct NEONIgnoreNaN;

impl SIMDInstructionSet for NEON {
/// NEON register size is 128 bits.
///
/// See <https://en.wikipedia.org/wiki/ARM_architecture#Advanced_SIMD_(Neon)>.
const REGISTER_SIZE: usize = 128;
}

// --------------------------------------- Tests ---------------------------------------

#[cfg(test)]
mod tests {
use super::*;
Expand Down
Loading

0 comments on commit d53e09c

Please sign in to comment.