forked from whchung/mlir_roundtrip
-
Notifications
You must be signed in to change notification settings - Fork 0
/
matmul_lib.cpp
56 lines (51 loc) · 1.75 KB
/
matmul_lib.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
#include <cstdint>

#include <cblas.h>
#include <mlir_runner_utils.h>
/// Writes `f` into every element of a 1-D strided f32 memref view.
///
/// Exported with C linkage under the MLIR linalg library-call naming scheme
/// (presumably the target of a `linalg.fill` lowered to a library call —
/// TODO confirm against the pipeline that emits the call).
///
/// @param X  view descriptor; element i (0 <= i < sizes[0]) is stored at
///           data[offset + i * strides[0]].
/// @param f  value stored into every element.
extern "C" void linalg_fill_viewsxf32_f32(StridedMemRefType<float, 1> *X,
                                          float f) {
  // Use the descriptor's own 64-bit index type: an `unsigned` counter wraps
  // (and the loop never terminates) once sizes[0] exceeds UINT_MAX.
  float *const base = X->data + X->offset;
  const std::int64_t size = X->sizes[0];
  const std::int64_t stride = X->strides[0];
  for (std::int64_t i = 0; i < size; ++i)
    base[i * stride] = f;
}
static void sgemm(const enum CBLAS_ORDER Order,
const enum CBLAS_TRANSPOSE TransA,
const enum CBLAS_TRANSPOSE TransB,
const int M, const int N, const int K,
const float alpha,
const float *A, const int lda,
const float *B, const int ldb,
const float beta,
float *C, const int ldc) {
for (int m = 0; m < M; ++m) {
auto *pA = A + m * lda;
auto *pC = C + m * ldc;
for (int n = 0; n < N; ++n) {
float c = pC[n];
float res = 0.0f;
for (int k = 0; k < K; ++k) {
auto *pB = B + k * ldb;
res += pA[k] * pB[n];
}
pC[n] = alpha * c + beta * res;
}
}
}
/// Library-call implementation of a 2-D f32 matmul over strided memref views.
/// It accumulates into C:
///
///   C[m][n] += sum_k A[m][k] * B[k][n]
///
/// Sizes are taken as M = C->sizes[0], N = C->sizes[1], K = A->sizes[1].
/// Shape agreement between A, B and C is assumed and not checked.
///
/// When every operand's innermost dimension is contiguous (strides[1] == 1),
/// the work goes to sgemm as row-major matrices with leading dimension
/// strides[0]. Any other layout, such as a transposed or strided subview,
/// cannot be described by a single leading dimension. Those views use a loop
/// that honours both strides, with the same accumulation order and update.
extern "C" void
linalg_matmul_viewsxsxf32_viewsxsxf32_viewsxsxf32(
    StridedMemRefType<float, 2> *A, StridedMemRefType<float, 2> *B,
    StridedMemRefType<float, 2> *C) {
  const bool innerContiguous =
      A->strides[1] == 1 && B->strides[1] == 1 && C->strides[1] == 1;

  if (innerContiguous) {
    sgemm(CBLAS_ORDER::CblasRowMajor,
          CBLAS_TRANSPOSE::CblasNoTrans,
          CBLAS_TRANSPOSE::CblasNoTrans,
          C->sizes[0], C->sizes[1], A->sizes[1],
          1.0f,
          A->data + A->offset, A->strides[0],
          B->data + B->offset, B->strides[0],
          1.0f,
          C->data + C->offset, C->strides[0]);
    return;
  }

  // General strided fallback: C[m][n] = C[m][n] + dot(A[m][:], B[:][n]),
  // with the dot product accumulated in increasing k order from 0.
  const std::int64_t M = C->sizes[0];
  const std::int64_t N = C->sizes[1];
  const std::int64_t K = A->sizes[1];
  const float *a = A->data + A->offset;
  const float *b = B->data + B->offset;
  float *c = C->data + C->offset;
  for (std::int64_t m = 0; m < M; ++m) {
    for (std::int64_t n = 0; n < N; ++n) {
      float res = 0.0f;
      for (std::int64_t k = 0; k < K; ++k)
        res += a[m * A->strides[0] + k * A->strides[1]] *
               b[k * B->strides[0] + n * B->strides[1]];
      float &out = c[m * C->strides[0] + n * C->strides[1]];
      out = out + res;
    }
  }
}