Skip to content
This repository has been archived by the owner on Jul 26, 2024. It is now read-only.

Commit

Permalink
try non fortran indexing inside kgemm
Browse files Browse the repository at this point in the history
  • Loading branch information
bmcdanie committed Jul 14, 2020
1 parent f964af3 commit b02f239
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 42 deletions.
29 changes: 8 additions & 21 deletions kgemm_nn.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,15 +37,15 @@ DEVICE_FUNCTION void kgemm_nn(int const mm, int const nn, int const kk,
#endif

auto A = [&](int const ia, int const ja) -> T const & {
return (A_[indx2f(ia, ja, ldA)]);
return (A_[indx2(ia, ja, ldA)]);
};

auto B = [&](int const ib, int const jb) -> T const & {
return (B_[indx2f(ib, jb, ldB)]);
return (B_[indx2(ib, jb, ldB)]);
};

auto C = [&](int const ic, int const jc) -> T & {
return (C_[indx2f(ic, jc, ldC)]);
return (C_[indx2(ic, jc, ldC)]);
};

// ---------------------------
Expand All @@ -54,28 +54,15 @@ DEVICE_FUNCTION void kgemm_nn(int const mm, int const nn, int const kk,


for (int ij0 = ij_start; ij0 < (mm * nn); ij0 += ij_size) {
int const i = (ij0 % mm) + 1;
int const j = ((ij0 - (i - 1)) / mm) + 1;
int const i = ij0 % mm;
int const j = (ij0 - i) / mm;

T cij = 0;
bool constexpr use_pointer = true;
if (use_pointer) {

int k = 1;
T const *Ap = &(A(i, k));
int64_t inc_A = &(A(i, k + 1)) - Ap;
T const *Bp = &(B(k, j));
int64_t inc_B = &(B(k + 1, j)) - Bp;
for (k = 0; k < kk; k++) {
cij += (*Ap) * (*Bp);
Ap += inc_A;
Bp += inc_B;
};
} else {
for (int k = 1; k <= kk; k++) {

for (int k = 0; k < kk; k++) {
cij += A(i, k) * B(k, j);
};
};


// ------------------
// store results to C
Expand Down
29 changes: 8 additions & 21 deletions kgemm_nt.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,43 +45,30 @@ DEVICE_FUNCTION void kgemm_nt(int const mm, int const nn, int const kk,
// ------------------------------------

auto A = [&](int const ia, int const ja) -> T const & {
return (A_[indx2f(ia, ja, ldA)]);
return (A_[indx2(ia, ja, ldA)]);
};

auto B = [&](int const ib, int const jb) -> T const & {
return (B_[indx2f(ib, jb, ldB)]);
return (B_[indx2(ib, jb, ldB)]);
};

auto C = [&](int const ic, int const jc) -> T & {
return (C_[indx2f(ic, jc, ldC)]);
return (C_[indx2(ic, jc, ldC)]);
};

// ---------------------------
// perform matrix calculations
// ---------------------------

for (int ij0 = ij_start; ij0 < (mm * nn); ij0 += ij_size) {
int const i = (ij0 % mm) + 1;
int const j = (ij0 - (i - 1)) / mm + 1;
int const i = ij0 % mm;
int const j = (ij0 - i) / mm;
T cij = 0;
bool constexpr use_pointer = true;
if (use_pointer) {
int k = 1;

T const *Ap = &(A(i, k));
int64_t const inc_A = &(A(i, k + 1)) - Ap;
T const *Bp = &(B(j, k));
int64_t const inc_B = &(B(j, k + 1)) - Bp;
for (k = 0; k < kk; k++) {
cij += (*Ap) * (*Bp);
Ap += inc_A;
Bp += inc_B;
};
} else {
for (int k = 1; k <= kk; k++) {

for (int k = 0; k < kk; k++) {
cij += A(i, k) * B(j, k);
};
};


// ------------------
// store results to C
Expand Down
10 changes: 10 additions & 0 deletions kroncommon.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,16 @@ float atomicAdd( float volatile *p, float dvalue)
#endif


static inline
HOST_FUNCTION DEVICE_FUNCTION
int indx2( int const i,
int const j,
int const ld )
{
return( i + j*ld );
}


static inline
HOST_FUNCTION DEVICE_FUNCTION
int indx2f( int const i,
Expand Down

0 comments on commit b02f239

Please sign in to comment.