Skip to content

Commit

Permalink
Merge pull request #36 from SuperFluffy/dgemm_fma
Browse files Browse the repository at this point in the history
Implement sgemm and dgemm using fma
  • Loading branch information
bluss authored Dec 7, 2018
2 parents a9c195b + 6fdfa19 commit 20932b3
Show file tree
Hide file tree
Showing 3 changed files with 319 additions and 24 deletions.
9 changes: 8 additions & 1 deletion .travis.yml
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
os: linux
language: rust
sudo: false
sudo: yes
dist: trusty

matrix:
include:
Expand All @@ -22,6 +24,11 @@ matrix:
env:
TARGET=x86_64-unknown-linux-gnu
MMNO_avx=1
MMNO_fma=1
- rust: nightly
env:
TARGET=x86_64-unknown-linux-gnu
MMNO_fma=1
- rust: nightly
env:
TARGET=aarch64-unknown-linux-gnu
Expand Down
77 changes: 56 additions & 21 deletions src/dgemm_kernel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,33 @@ macro_rules! loop_n {
($j:ident, $e:expr) => { loop4!($j, $e) };
}


/// Marker type whose `DgemmMultiplyAdd` implementation uses the fused
/// `_mm256_fmadd_pd` intrinsic. It carries no data and is only used as a
/// type parameter.
#[cfg(any(target_arch="x86", target_arch="x86_64"))]
struct KernelFma;
/// Marker type whose `DgemmMultiplyAdd` implementation uses a separate
/// `_mm256_mul_pd` followed by `_mm256_add_pd`. It carries no data and is
/// only used as a type parameter.
#[cfg(any(target_arch="x86", target_arch="x86_64"))]
struct KernelAvx;

/// Multiply-add primitive on 256-bit `f64` vectors, used to select between
/// instruction flavours (plain AVX vs. FMA) at compile time via a type
/// parameter.
#[cfg(any(target_arch="x86", target_arch="x86_64"))]
trait DgemmMultiplyAdd {
    /// Returns `a * b + c`, computed lane-wise over the four `f64` lanes.
    ///
    /// Parameters are named explicitly: anonymous trait-method parameters are
    /// deprecated and are rejected from the 2018 edition onwards.
    ///
    /// # Safety
    ///
    /// Implementations call x86 SIMD intrinsics, so the caller must ensure
    /// the running CPU supports the instructions the chosen implementation
    /// uses (e.g. by runtime feature detection before dispatching).
    unsafe fn multiply_add(a: __m256d, b: __m256d, c: __m256d) -> __m256d;
}

#[cfg(any(target_arch="x86", target_arch="x86_64"))]
impl DgemmMultiplyAdd for KernelAvx {
    /// Computes `a * b + c` lane-wise as two separate steps: an AVX multiply
    /// followed by an AVX add.
    ///
    /// # Safety
    ///
    /// Uses AVX intrinsics; the caller must ensure AVX is available.
    #[inline(always)]
    unsafe fn multiply_add(a: __m256d, b: __m256d, c: __m256d) -> __m256d {
        let product = _mm256_mul_pd(a, b);
        _mm256_add_pd(product, c)
    }
}

#[cfg(any(target_arch="x86", target_arch="x86_64"))]
impl DgemmMultiplyAdd for KernelFma {
/// Computes `a * b + c` lane-wise with the fused multiply-add intrinsic
/// `_mm256_fmadd_pd`.
///
/// # Safety
///
/// Uses an FMA intrinsic; the caller must ensure the CPU supports FMA.
#[inline(always)]
unsafe fn multiply_add(a: __m256d, b: __m256d, c: __m256d) -> __m256d {
_mm256_fmadd_pd(a, b, c)
}
}

impl GemmKernel for Gemm {
type Elem = T;

Expand Down Expand Up @@ -83,7 +110,9 @@ pub unsafe fn kernel(k: usize, alpha: T, a: *const T, b: *const T,
// dispatch to specific compiled versions
#[cfg(any(target_arch="x86", target_arch="x86_64"))]
{
if is_x86_feature_detected_!("avx") {
if is_x86_feature_detected_!("fma") {
return kernel_target_fma(k, alpha, a, b, beta, c, rsc, csc);
} else if is_x86_feature_detected_!("avx") {
return kernel_target_avx(k, alpha, a, b, beta, c, rsc, csc);
} else if is_x86_feature_detected_!("sse2") {
return kernel_target_sse2(k, alpha, a, b, beta, c, rsc, csc);
Expand All @@ -95,25 +124,35 @@ pub unsafe fn kernel(k: usize, alpha: T, a: *const T, b: *const T,
#[inline]
#[target_feature(enable="avx")]
#[cfg(any(target_arch="x86", target_arch="x86_64"))]
pub unsafe fn kernel_target_avx(k: usize, alpha: T, a: *const T, b: *const T,
unsafe fn kernel_target_fma(k: usize, alpha: T, a: *const T, b: *const T,
beta: T, c: *mut T, rsc: isize, csc: isize)
{
kernel_x86_avx::<KernelFma>(k, alpha, a, b, beta, c, rsc, csc)
}

/// AVX-specialised entry point: runs the shared x86 AVX kernel body with
/// `KernelAvx` (separate multiply and add) as the multiply-add implementation.
///
/// # Safety
///
/// The caller must ensure the running CPU supports AVX (the dispatcher checks
/// this with runtime feature detection). The pointer and stride arguments are
/// forwarded unchanged to `kernel_x86_avx` and must satisfy its requirements.
#[inline]
#[target_feature(enable="avx")]
#[cfg(any(target_arch="x86", target_arch="x86_64"))]
unsafe fn kernel_target_avx(k: usize, alpha: T, a: *const T, b: *const T,
                            beta: T, c: *mut T, rsc: isize, csc: isize)
{
    // `kernel_x86_avx` is generic over the multiply-add strategy, so the
    // strategy type must be named explicitly here.
    kernel_x86_avx::<KernelAvx>(k, alpha, a, b, beta, c, rsc, csc)
}

#[inline]
#[target_feature(enable="sse2")]
#[cfg(any(target_arch="x86", target_arch="x86_64"))]
pub unsafe fn kernel_target_sse2(k: usize, alpha: T, a: *const T, b: *const T,
unsafe fn kernel_target_sse2(k: usize, alpha: T, a: *const T, b: *const T,
beta: T, c: *mut T, rsc: isize, csc: isize)
{
kernel_fallback_impl(k, alpha, a, b, beta, c, rsc, csc)
}

#[inline(always)]
#[cfg(any(target_arch="x86", target_arch="x86_64"))]
unsafe fn kernel_x86_avx(k: usize, alpha: T, a: *const T, b: *const T,
unsafe fn kernel_x86_avx<DMA>(k: usize, alpha: T, a: *const T, b: *const T,
beta: T, c: *mut T, rsc: isize, csc: isize)
where DMA: DgemmMultiplyAdd
{
debug_assert_ne!(k, 0);

Expand Down Expand Up @@ -201,15 +240,15 @@ unsafe fn kernel_x86_avx(k: usize, alpha: T, a: *const T, b: *const T,
// ab_7 || a4 b3 | a5 b2 | a6 b1 | a7 b0

// Add and multiply in one go
ab[0] = _mm256_add_pd(ab[0], _mm256_mul_pd(a_0123, b_0123));
ab[1] = _mm256_add_pd(ab[1], _mm256_mul_pd(a_0123, b_1032));
ab[2] = _mm256_add_pd(ab[2], _mm256_mul_pd(a_0123, b_2301));
ab[3] = _mm256_add_pd(ab[3], _mm256_mul_pd(a_0123, b_3210));
ab[0] = DMA::multiply_add(a_0123, b_0123, ab[0]);
ab[1] = DMA::multiply_add(a_0123, b_1032, ab[1]);
ab[2] = DMA::multiply_add(a_0123, b_2301, ab[2]);
ab[3] = DMA::multiply_add(a_0123, b_3210, ab[3]);

ab[4] = _mm256_add_pd(ab[4], _mm256_mul_pd(a_4567, b_0123));
ab[5] = _mm256_add_pd(ab[5], _mm256_mul_pd(a_4567, b_1032));
ab[6] = _mm256_add_pd(ab[6], _mm256_mul_pd(a_4567, b_2301));
ab[7] = _mm256_add_pd(ab[7], _mm256_mul_pd(a_4567, b_3210));
ab[4] = DMA::multiply_add(a_4567, b_0123, ab[4]);
ab[5] = DMA::multiply_add(a_4567, b_1032, ab[5]);
ab[6] = DMA::multiply_add(a_4567, b_2301, ab[6]);
ab[7] = DMA::multiply_add(a_4567, b_3210, ab[7]);

if !is_last {
a = a.add(MR);
Expand Down Expand Up @@ -589,9 +628,6 @@ unsafe fn kernel_x86_avx(k: usize, alpha: T, a: *const T, b: *const T,

// Compute α (A B)
// _mm256_set1_pd and _mm256_broadcast_sd seem to achieve the same thing.
let alpha_v = _mm256_broadcast_sd(&alpha);
loop_m!(i, ab[i] = _mm256_mul_pd(alpha_v, ab[i]));

macro_rules! c {
($i:expr, $j:expr) =>
(c.offset(rsc * $i as isize + csc * $j as isize));
Expand All @@ -601,9 +637,6 @@ unsafe fn kernel_x86_avx(k: usize, alpha: T, a: *const T, b: *const T,
let mut cv = [_mm256_setzero_pd(); MR];

if beta != 0. {
// _mm256_set1_pd and _mm256_broadcast_sd seem to achieve the same thing.
let beta_v = _mm256_broadcast_sd(&beta);

// Read C
if rsc == 1 {
loop4!(i, cv[i] = _mm256_loadu_pd(c![0, i]));
Expand All @@ -626,11 +659,14 @@ unsafe fn kernel_x86_avx(k: usize, alpha: T, a: *const T, b: *const T,
));
}
// Compute β C
// _mm256_set1_pd and _mm256_broadcast_sd seem to achieve the same thing.
let beta_v = _mm256_broadcast_sd(&beta);
loop_m!(i, cv[i] = _mm256_mul_pd(cv[i], beta_v));
}

// Compute (α A B) + (β C)
loop_m!(i, cv[i] = _mm256_add_pd(cv[i], ab[i]));
let alpha_v = _mm256_broadcast_sd(&alpha);
loop_m!(i, cv[i] = DMA::multiply_add(alpha_v, ab[i], cv[i]));

if rsc == 1 {
loop4!(i, _mm256_storeu_pd(c![0, i], cv[i]));
Expand Down Expand Up @@ -774,7 +810,6 @@ mod tests {
}

test_arch_kernels_x86! {
"avx", kernel_target_avx,
"sse2", kernel_target_sse2
}
}
Expand Down
Loading

0 comments on commit 20932b3

Please sign in to comment.