Skip to content

Commit

Permalink
[ROCm] Fix dense autotvm template registration (#3136)
Browse files Browse the repository at this point in the history
* Fix rocm dense autotvm template

* suppres lint warning
  • Loading branch information
masahi authored and vinx13 committed May 5, 2019
1 parent 094fc68 commit 31ba013
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 6 deletions.
1 change: 1 addition & 0 deletions topi/python/topi/cuda/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from .reduction import schedule_reduce
from .softmax import schedule_softmax
from .injective import schedule_injective, schedule_elemwise, schedule_broadcast
from .dense import schedule_dense
from .pooling import schedule_pool, schedule_global_pool
from .extern import schedule_extern
from .nn import schedule_lrn, schedule_l2_normalize
Expand Down
13 changes: 7 additions & 6 deletions topi/python/topi/rocm/dense.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,19 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, unused-variable
# pylint: disable=invalid-name, unused-variable, unused-argument
"""Schedule for dense operator"""
from __future__ import absolute_import as _abs
import tvm
from tvm import autotvm
from tvm.contrib import rocblas
import topi
from ..nn.dense import dense, dense_default
from .. import tag
from .. import generic

@dense.register("rocm")
def dense_rocm(data, weight, bias=None, out_dtype=None):
@autotvm.register_topi_compute(dense, "rocm", "direct")
def dense_rocm(cfg, data, weight, bias=None, out_dtype=None):
"""Dense operator for rocm backend.
Parameters
Expand Down Expand Up @@ -67,8 +68,8 @@ def dense_rocm(data, weight, bias=None, out_dtype=None):
return dense_default(data, weight, bias, out_dtype)


@generic.schedule_dense.register(["rocm"])
def schedule_dense(outs):
@autotvm.register_topi_schedule(generic.schedule_dense, "rocm", "direct")
def schedule_dense(cfg, outs):
"""Schedule for dense operator.
Parameters
Expand All @@ -85,4 +86,4 @@ def schedule_dense(outs):
target = tvm.target.current_target()
if target.target_name == "rocm" and "rocblas" in target.libs:
return generic.schedule_extern(outs)
return topi.cuda.schedule_dense(outs)
return topi.cuda.schedule_dense(cfg, outs)

0 comments on commit 31ba013

Please sign in to comment.