Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CONTRIB] rocBLAS integration #751

Merged
merged 3 commits into from
Jan 3, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@ include make/contrib/cudnn.mk
include make/contrib/miopen.mk
include make/contrib/mps.mk
include make/contrib/cublas.mk
include make/contrib/rocblas.mk

ifdef ADD_CFLAGS
CFLAGS += $(ADD_CFLAGS)
Expand Down
3 changes: 3 additions & 0 deletions make/config.mk
Original file line number Diff line number Diff line change
Expand Up @@ -80,3 +80,6 @@ USE_MPS = 0

# Whether use cuBLAS
USE_CUBLAS = 0

# Whether use rocBlas
USE_ROCBLAS = 0
8 changes: 8 additions & 0 deletions make/contrib/rocblas.mk
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
ROCBLAS_CONTRIB_SRC = $(wildcard src/contrib/rocblas/*.cc)
ROCBLAS_CONTRIB_OBJ = $(patsubst src/%.cc, build/%.o, $(ROCBLAS_CONTRIB_SRC))

ifeq ($(USE_ROCBLAS), 1)
CFLAGS += -DTVM_USE_ROCBLAS=1
ADD_LDFLAGS += -lrocblas
RUNTIME_DEP += $(ROCBLAS_CONTRIB_OBJ)
endif
2 changes: 1 addition & 1 deletion python/tvm/contrib/cblas.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
def matmul(lhs, rhs, transa=False, transb=False):
"""Create an extern op that compute matrix mult of A and rhs with CrhsLAS

This function serves as an example on how to calle external libraries.
This function serves as an example on how to call external libraries.

Parameters
----------
Expand Down
32 changes: 32 additions & 0 deletions python/tvm/contrib/rocblas.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
"""External function interface to rocBLAS libraries."""
from __future__ import absolute_import as _abs

from .. import api as _api
from .. import intrin as _intrin

def matmul(lhs, rhs, transa=False, transb=False):
"""Create an extern op that compute matrix mult of A and rhs with rocBLAS

Parameters
----------
lhs : Tensor
The left matrix operand
rhs : Tensor
The right matrix operand
transa : bool
Whether transpose lhs
transb : bool
Whether transpose rhs

Returns
-------
C : Tensor
The result tensor.
"""
n = lhs.shape[1] if transa else lhs.shape[0]
m = rhs.shape[0] if transb else rhs.shape[1]
return _api.extern(
(n, m), [lhs, rhs],
lambda ins, outs: _intrin.call_packed(
"tvm.contrib.rocblas.matmul",
ins[0], ins[1], outs[0], transa, transb), name="C")
76 changes: 76 additions & 0 deletions src/contrib/rocblas/rocblas.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
/*!
* Copyright (c) 2017 by Contributors
* \file Use external rocblas library call.
*/
#include <tvm/runtime/registry.h>
#include <tvm/runtime/util.h>
#include <dmlc/logging.h>
#include "rocblas.h"

namespace tvm {
namespace contrib {

using namespace runtime;

#ifndef CHECK_ROCBLAS_ERROR
#define CHECK_ROCBLAS_ERROR(error) \
if (error != rocblas_status_success) { \
fprintf(stderr, "rocBLAS error: "); \
if (error == rocblas_status_invalid_handle) fprintf(stderr, "rocblas_status_invalid_handle"); \
if (error == rocblas_status_not_implemented) fprintf(stderr, " rocblas_status_not_implemented"); \
if (error == rocblas_status_invalid_pointer) fprintf(stderr, "rocblas_status_invalid_pointer"); \
if (error == rocblas_status_invalid_size) fprintf(stderr, "rocblas_status_invalid_size"); \
if (error == rocblas_status_memory_error) fprintf(stderr, "rocblas_status_memory_error"); \
if (error == rocblas_status_internal_error) fprintf(stderr, "rocblas_status_internal_error"); \
fprintf(stderr, "\n"); \
exit(EXIT_FAILURE); \
}
#endif


// matrix multiplication for row major
TVM_REGISTER_GLOBAL("tvm.contrib.rocblas.matmul")
.set_body([](TVMArgs args, TVMRetValue *ret) {
DLTensor* A = args[0];
DLTensor* B = args[1];
DLTensor* C = args[2];
bool transa = args[3];
bool transb = args[4];
// call gemm for simple compact code.
CHECK_EQ(A->ndim, 2);
CHECK_EQ(B->ndim, 2);
CHECK_EQ(C->ndim, 2);
CHECK(C->strides == nullptr);
CHECK(B->strides == nullptr);
CHECK(A->strides == nullptr);
CHECK(TypeMatch(A->dtype, kDLFloat, 32));
CHECK(TypeMatch(B->dtype, kDLFloat, 32));
CHECK(TypeMatch(C->dtype, kDLFloat, 32));

rocblas_handle handle;
CHECK_ROCBLAS_ERROR(rocblas_create_handle(&handle));
float alpha = 1.0;
float beta = 0.0;
float *A_ptr = reinterpret_cast<float*>(static_cast<char*>(B->data) + B->byte_offset);
float *B_ptr = reinterpret_cast<float*>(static_cast<char*>(A->data) + A->byte_offset);
float *C_ptr = reinterpret_cast<float*>(static_cast<char*>(C->data) + C->byte_offset);

CHECK_ROCBLAS_ERROR(rocblas_sgemm(handle,
transb ? rocblas_operation_transpose : rocblas_operation_none,
transa ? rocblas_operation_transpose : rocblas_operation_none,
transb ? B->shape[0] : B->shape[1],
transa ? A->shape[1] : A->shape[0],
transb ? B->shape[1] : B->shape[0],
&alpha,
A_ptr,
B->shape[1],
B_ptr,
A->shape[1],
&beta,
C_ptr,
C->shape[1]));

CHECK_ROCBLAS_ERROR(rocblas_destroy_handle(handle));
});
} // namespace contrib
} // namespace tvm
33 changes: 33 additions & 0 deletions tests/python/contrib/test_rocblas.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import tvm
import numpy as np
from tvm.contrib import rocblas

def test_matmul_add():
n = 1024
l = 128
m = 235
A = tvm.placeholder((n, l), name='A')
B = tvm.placeholder((l, m), name='B')
C = rocblas.matmul(A, B)
s = tvm.create_schedule(C.op)

def verify(target="rocm"):
if not tvm.module.enabled(target):
print("skip because %s is not enabled..." % target)
return
if not tvm.get_global_func("tvm.contrib.rocblas.matmul", True):
print("skip because extern function is not avalable")
return
ctx = tvm.rocm(0)
f = tvm.build(s, [A, B, C], target)
a = tvm.nd.array(np.random.uniform(size=(n, l)).astype(A.dtype), ctx)
b = tvm.nd.array(np.random.uniform(size=(l, m)).astype(B.dtype), ctx)
c = tvm.nd.array(np.zeros((n, m), dtype=C.dtype), ctx)
f(a, b, c)
np.testing.assert_allclose(
c.asnumpy(), np.dot(a.asnumpy(), b.asnumpy()), rtol=1e-5)
verify()


if __name__ == "__main__":
test_matmul_add()
1 change: 1 addition & 0 deletions topi/python/topi/rocm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@
from __future__ import absolute_import as _abs

from .conv2d import *
from .dense import *
66 changes: 66 additions & 0 deletions topi/python/topi/rocm/dense.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
# pylint: disable=invalid-name, unused-variable
"""Schedule for dense operator"""
from __future__ import absolute_import as _abs
import tvm
from tvm.contrib import rocblas
import topi
from ..nn.dense import dense, dense_default
from .. import tag
from .. import generic

@dense.register("rocm")
def dense_rocm(data, weight, bias=None):
"""Dense operator for rocm backend.

Parameters
----------
data : tvm.Tensor
2-D with shape [batch, in_dim]

weight : tvm.Tensor
2-D with shape [out_dim, in_dim]

bias : tvm.Tensor, optional
1-D with shape [out_dim]

Returns
-------
output : tvm.Tensor
2-D with shape [batch, out_dim]
"""
assert len(data.shape) == 2 and len(weight.shape) == 2, \
"only support 2-dim dense"
if bias is not None:
assert len(bias.shape) == 1
batch, in_dim = data.shape
out_dim, _ = weight.shape
target = tvm.target.current_target()
if "rocblas" in target.libs:
matmul = rocblas.matmul(data, weight, False, True)
if bias is not None:
matmul = tvm.compute((batch, out_dim), \
lambda i, j: matmul[i, j] + bias[j], \
tag=tag.BROADCAST)
return matmul
return dense_default(data, weight, bias)


@generic.schedule_dense.register(["rocm"])
def schedule_dense(outs):
"""Schedule for dense operator.

Parameters
----------
outs: Array of Tensor
The computation graph description of dense
in the format of an array of tensors.

Returns
-------
s: Schedule
The computation schedule for dense.
"""
target = tvm.target.current_target()
if target.target_name == "rocm" and "rocblas" in target.libs:
return generic.schedule_extern(outs)
return topi.cuda.schedule_dense(outs)