Skip to content

Commit

Permalink
[CONTRIB] rocBLAS integration (apache#751)
Browse files Browse the repository at this point in the history
* rocblas integration

* fix include

* fix lint
  • Loading branch information
masahi authored and sergei-mironov committed Aug 8, 2018
1 parent a3a4b12 commit eba7c13
Show file tree
Hide file tree
Showing 9 changed files with 221 additions and 1 deletion.
1 change: 1 addition & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@ include make/contrib/cudnn.mk
include make/contrib/miopen.mk
include make/contrib/mps.mk
include make/contrib/cublas.mk
include make/contrib/rocblas.mk

ifdef ADD_CFLAGS
CFLAGS += $(ADD_CFLAGS)
Expand Down
3 changes: 3 additions & 0 deletions make/config.mk
Original file line number Diff line number Diff line change
Expand Up @@ -80,3 +80,6 @@ USE_MPS = 0

# Whether use cuBLAS
USE_CUBLAS = 0

# Whether use rocBlas
USE_ROCBLAS = 0
8 changes: 8 additions & 0 deletions make/contrib/rocblas.mk
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
ROCBLAS_CONTRIB_SRC = $(wildcard src/contrib/rocblas/*.cc)
ROCBLAS_CONTRIB_OBJ = $(patsubst src/%.cc, build/%.o, $(ROCBLAS_CONTRIB_SRC))

ifeq ($(USE_ROCBLAS), 1)
CFLAGS += -DTVM_USE_ROCBLAS=1
ADD_LDFLAGS += -lrocblas
RUNTIME_DEP += $(ROCBLAS_CONTRIB_OBJ)
endif
2 changes: 1 addition & 1 deletion python/tvm/contrib/cblas.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
def matmul(lhs, rhs, transa=False, transb=False):
"""Create an extern op that compute matrix mult of A and rhs with CrhsLAS
This function serves as an example on how to calle external libraries.
This function serves as an example on how to call external libraries.
Parameters
----------
Expand Down
32 changes: 32 additions & 0 deletions python/tvm/contrib/rocblas.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
"""External function interface to rocBLAS libraries."""
from __future__ import absolute_import as _abs

from .. import api as _api
from .. import intrin as _intrin

def matmul(lhs, rhs, transa=False, transb=False):
"""Create an extern op that compute matrix mult of A and rhs with rocBLAS
Parameters
----------
lhs : Tensor
The left matrix operand
rhs : Tensor
The right matrix operand
transa : bool
Whether transpose lhs
transb : bool
Whether transpose rhs
Returns
-------
C : Tensor
The result tensor.
"""
n = lhs.shape[1] if transa else lhs.shape[0]
m = rhs.shape[0] if transb else rhs.shape[1]
return _api.extern(
(n, m), [lhs, rhs],
lambda ins, outs: _intrin.call_packed(
"tvm.contrib.rocblas.matmul",
ins[0], ins[1], outs[0], transa, transb), name="C")
76 changes: 76 additions & 0 deletions src/contrib/rocblas/rocblas.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
/*!
* Copyright (c) 2017 by Contributors
* \file Use external rocblas library call.
*/
#include <tvm/runtime/registry.h>
#include <tvm/runtime/util.h>
#include <dmlc/logging.h>
#include "rocblas.h"

namespace tvm {
namespace contrib {

using namespace runtime;

#ifndef CHECK_ROCBLAS_ERROR
#define CHECK_ROCBLAS_ERROR(error) \
if (error != rocblas_status_success) { \
fprintf(stderr, "rocBLAS error: "); \
if (error == rocblas_status_invalid_handle) fprintf(stderr, "rocblas_status_invalid_handle"); \
if (error == rocblas_status_not_implemented) fprintf(stderr, " rocblas_status_not_implemented"); \
if (error == rocblas_status_invalid_pointer) fprintf(stderr, "rocblas_status_invalid_pointer"); \
if (error == rocblas_status_invalid_size) fprintf(stderr, "rocblas_status_invalid_size"); \
if (error == rocblas_status_memory_error) fprintf(stderr, "rocblas_status_memory_error"); \
if (error == rocblas_status_internal_error) fprintf(stderr, "rocblas_status_internal_error"); \
fprintf(stderr, "\n"); \
exit(EXIT_FAILURE); \
}
#endif


// matrix multiplication for row major
TVM_REGISTER_GLOBAL("tvm.contrib.rocblas.matmul")
.set_body([](TVMArgs args, TVMRetValue *ret) {
DLTensor* A = args[0];
DLTensor* B = args[1];
DLTensor* C = args[2];
bool transa = args[3];
bool transb = args[4];
// call gemm for simple compact code.
CHECK_EQ(A->ndim, 2);
CHECK_EQ(B->ndim, 2);
CHECK_EQ(C->ndim, 2);
CHECK(C->strides == nullptr);
CHECK(B->strides == nullptr);
CHECK(A->strides == nullptr);
CHECK(TypeMatch(A->dtype, kDLFloat, 32));
CHECK(TypeMatch(B->dtype, kDLFloat, 32));
CHECK(TypeMatch(C->dtype, kDLFloat, 32));

rocblas_handle handle;
CHECK_ROCBLAS_ERROR(rocblas_create_handle(&handle));
float alpha = 1.0;
float beta = 0.0;
float *A_ptr = reinterpret_cast<float*>(static_cast<char*>(B->data) + B->byte_offset);
float *B_ptr = reinterpret_cast<float*>(static_cast<char*>(A->data) + A->byte_offset);
float *C_ptr = reinterpret_cast<float*>(static_cast<char*>(C->data) + C->byte_offset);

CHECK_ROCBLAS_ERROR(rocblas_sgemm(handle,
transb ? rocblas_operation_transpose : rocblas_operation_none,
transa ? rocblas_operation_transpose : rocblas_operation_none,
transb ? B->shape[0] : B->shape[1],
transa ? A->shape[1] : A->shape[0],
transb ? B->shape[1] : B->shape[0],
&alpha,
A_ptr,
B->shape[1],
B_ptr,
A->shape[1],
&beta,
C_ptr,
C->shape[1]));

CHECK_ROCBLAS_ERROR(rocblas_destroy_handle(handle));
});
} // namespace contrib
} // namespace tvm
33 changes: 33 additions & 0 deletions tests/python/contrib/test_rocblas.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import tvm
import numpy as np
from tvm.contrib import rocblas

def test_matmul_add():
n = 1024
l = 128
m = 235
A = tvm.placeholder((n, l), name='A')
B = tvm.placeholder((l, m), name='B')
C = rocblas.matmul(A, B)
s = tvm.create_schedule(C.op)

def verify(target="rocm"):
if not tvm.module.enabled(target):
print("skip because %s is not enabled..." % target)
return
if not tvm.get_global_func("tvm.contrib.rocblas.matmul", True):
print("skip because extern function is not avalable")
return
ctx = tvm.rocm(0)
f = tvm.build(s, [A, B, C], target)
a = tvm.nd.array(np.random.uniform(size=(n, l)).astype(A.dtype), ctx)
b = tvm.nd.array(np.random.uniform(size=(l, m)).astype(B.dtype), ctx)
c = tvm.nd.array(np.zeros((n, m), dtype=C.dtype), ctx)
f(a, b, c)
np.testing.assert_allclose(
c.asnumpy(), np.dot(a.asnumpy(), b.asnumpy()), rtol=1e-5)
verify()


if __name__ == "__main__":
test_matmul_add()
1 change: 1 addition & 0 deletions topi/python/topi/rocm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@
from __future__ import absolute_import as _abs

from .conv2d import *
from .dense import *
66 changes: 66 additions & 0 deletions topi/python/topi/rocm/dense.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
# pylint: disable=invalid-name, unused-variable
"""Schedule for dense operator"""
from __future__ import absolute_import as _abs
import tvm
from tvm.contrib import rocblas
import topi
from ..nn.dense import dense, dense_default
from .. import tag
from .. import generic

@dense.register("rocm")
def dense_rocm(data, weight, bias=None):
"""Dense operator for rocm backend.
Parameters
----------
data : tvm.Tensor
2-D with shape [batch, in_dim]
weight : tvm.Tensor
2-D with shape [out_dim, in_dim]
bias : tvm.Tensor, optional
1-D with shape [out_dim]
Returns
-------
output : tvm.Tensor
2-D with shape [batch, out_dim]
"""
assert len(data.shape) == 2 and len(weight.shape) == 2, \
"only support 2-dim dense"
if bias is not None:
assert len(bias.shape) == 1
batch, in_dim = data.shape
out_dim, _ = weight.shape
target = tvm.target.current_target()
if "rocblas" in target.libs:
matmul = rocblas.matmul(data, weight, False, True)
if bias is not None:
matmul = tvm.compute((batch, out_dim), \
lambda i, j: matmul[i, j] + bias[j], \
tag=tag.BROADCAST)
return matmul
return dense_default(data, weight, bias)


@generic.schedule_dense.register(["rocm"])
def schedule_dense(outs):
"""Schedule for dense operator.
Parameters
----------
outs: Array of Tensor
The computation graph description of dense
in the format of an array of tensors.
Returns
-------
s: Schedule
The computation schedule for dense.
"""
target = tvm.target.current_target()
if target.target_name == "rocm" and "rocblas" in target.libs:
return generic.schedule_extern(outs)
return topi.cuda.schedule_dense(outs)

0 comments on commit eba7c13

Please sign in to comment.