diff --git a/tests/python/sparsetir/bench_rgcn.py b/tests/python/sparsetir/bench_rgcn.py
index f0a2d73129bfb..b73e929c567f7 100644
--- a/tests/python/sparsetir/bench_rgcn.py
+++ b/tests/python/sparsetir/bench_rgcn.py
@@ -1,3 +1,5 @@
+from sys import base_exec_prefix
+from termios import ECHOE
 from dgl.heterograph import DGLHeteroGraph
 import tvm
 import tvm.testing
@@ -13,6 +15,19 @@
 from sparse_tir_scripts import rgcn_forward
 
 
+def get_dataset_by_name(name: str):
+    if name == 'aifb':
+        return AIFBDataset()
+    elif name == 'mutag':
+        return MUTAGDataset()
+    elif name == 'bgs':
+        return BGSDataset()
+    elif name == 'am':
+        return AMDataset()
+    else:
+        raise KeyError("Unknown dataset {}.".format(name))
+
+
 class TorchOpTimer(object):
     def __enter__(self):
         self.start_event = th.cuda.Event(enable_timing=True)
@@ -23,10 +38,10 @@ def __enter__(self):
     def __exit__(self, type, value, traceback):
         self.end_event.record()
         th.cuda.synchronize()  # Wait for the events to be recorded!
-        self.time = self.start_event.elapsed_time(self.end_event) / 1e3
+        self.time = self.start_event.elapsed_time(self.end_event)
 
 
-def prepare_hetero_graph_simplified(g):
+def prepare_hetero_graph_simplified(g: dgl.DGLHeteroGraph):
     ntype_pointer = np.cumsum(
         [0] + [g.number_of_nodes(ntype) for ntype in g.ntypes])
 
@@ -38,7 +53,7 @@ def prepare_hetero_graph_simplified(g):
     return{"ntype_node_pointer": th.IntTensor(ntype_pointer).cuda(),
            "etype_edge_pointer": th.IntTensor(etype_pointer).cuda()}
 
-def prepare_graph(g, ntype=None):
+def prepare_graph(g: dgl.DGLHeteroGraph, ntype=None):
     # Todo: long int kernel support, multiple devices
     g.node_id = g.nodes(ntype).type(th.IntTensor).cuda()
 
@@ -65,16 +80,15 @@ def prepare_graph(g, ntype=None):
     return g
 
 
-def test_rgcn(g: DGLHeteroGraph):
-    feat_size = 16
+def test_rgcn(g: DGLHeteroGraph, feat_size: int):
     g = g.to(0)
-    feat = th.rand(g.num_src_nodes(), feat_size).to(0)
-    y = th.zeros(g.num_dst_nodes(), feat_size).to(0)
+    feat = th.rand(g.num_src_nodes(), feat_size).to(0) / 100
+    out = th.zeros(g.num_dst_nodes(), feat_size).to(0) / 100
     weight = th.rand(g.num_rels, feat_size, feat_size).to(0)
     indptr, indices, eid = g.adj_sparse(fmt='csc')
     etype = g.edata[dgl.ETYPE][eid]
 
-    # naive
+    # dgl-bmm
 
     cold_start = 3
     total = 10
@@ -85,38 +99,24 @@ def msg_func(edges):
         W = weight[edges.data[dgl.ETYPE]]
         return {'msg': W @ h}
 
-    for epoch in range(10):
-        with TorchOpTimer() as timer:
-            g.srcdata['feat'] = feat.unsqueeze(-1)
-            g.update_all(
-                msg_func,
-                fn.sum('msg', 'y')
-            )
-            y = g.dstdata['y'].sum(1)
-        if epoch >= cold_start:
-            accum += timer.time
-
-    print("dgl low-mem:\t\t", accum / (total - cold_start))
-
-    # broadcast
-
-    cold_start = 3
-    total = 10
-    accum = 0
-
-    for epoch in range(10):
-        with TorchOpTimer() as timer:
-            g.srcdata['feat'] = feat.unsqueeze(-1)
-            g.edata['W'] = weight[g.edata[dgl.ETYPE]]
-            g.update_all(
-                fn.u_mul_e('feat', 'W', 'msg'),
-                fn.sum('msg', 'y')
-            )
-            y = g.dstdata['y'].sum(1)
-        if epoch >= cold_start:
-            accum += timer.time
-
-    print("dgl high-mem:\t\t", accum / (total - cold_start))
+    try:
+        for epoch in range(10):
+            with TorchOpTimer() as timer:
+                g.srcdata['feat'] = feat.unsqueeze(-1)
+                g.update_all(
+                    msg_func,
+                    fn.sum('msg', 'y')
+                )
+                y_dgl = g.dstdata['y'].squeeze(-1)
+            if epoch >= cold_start:
+                accum += timer.time
+        print("dgl-bmm:\t\t {}ms".format(accum / (total - cold_start)))
+    except RuntimeError as err:
+        print("dgl-bmm: OOM")
+        y_dgl = None
+    except BaseException as err:
+        print(err)
+        raise
 
     # tir
     mod = tvm.IRModule.from_expr(rgcn_forward)
@@ -137,14 +137,12 @@ def msg_func(edges):
     sch.bind(i, "blockIdx.x")
     sch.bind(f_out, "threadIdx.y")
     sch.bind(f_in, "threadIdx.x")
-    # sch.cache_read(inner, 4, "shared")
     f = tvm.build(sch.mod, target='cuda')
-    # print(f.imported_modules[0].get_source())
 
     E = tvm.nd.array(etype.cpu().numpy().astype("int32"), device=tvm.cuda(0))
     W = tvm.nd.array(weight.view(-1).cpu().numpy().astype("float32"), device=tvm.cuda(0))
     X = tvm.nd.array(feat.view(-1).cpu().numpy().astype("float32"), device=tvm.cuda(0))
-    Y = tvm.nd.array(y.view(-1).cpu().numpy().astype("float32"), device=tvm.cuda(0))
+    Y = tvm.nd.array(out.view(-1).cpu().numpy().astype("float32"), device=tvm.cuda(0))
     indptr = tvm.nd.array(indptr.cpu().numpy().astype("int32"), device=tvm.cuda(0))
     indices = tvm.nd.array(indices.cpu().numpy().astype("int32"), device=tvm.cuda(0))
 
@@ -158,16 +156,22 @@ def msg_func(edges):
         if epoch >= cold_start:
             accum += timer.time
 
-    print("sparse-tir:\t\t", accum / (total - cold_start))
+    print("sparse-tir:\t\t {}ms".format(accum / (total - cold_start)))
+
+    if y_dgl is not None:
+        tvm.testing.assert_allclose(y_dgl.view(-1).cpu().numpy(), Y.numpy(), rtol=1e-5)
 
 
 if __name__ == "__main__":
-    dataset = AIFBDataset()
-    g = dataset[0]
-    type_pointers = prepare_hetero_graph_simplified(g)
-    g = prepare_graph(dgl.to_homogeneous(g))
-    g.ntype_pointer = type_pointers['ntype_node_pointer']
-    g.etype_pointer = type_pointers['etype_edge_pointer']
-    g.num_ntypes = max(g.ndata[dgl.NTYPE]).item() + 1
-    g.num_rels = max(g.edata[dgl.ETYPE]).item() + 1
-    test_rgcn(g)
+    for feat_size in [4, 8, 16, 32, 64]:
+        for name in ['aifb', 'mutag', 'bgs', 'am']:
+            print('dataset {}:'.format(name))
+            dataset = get_dataset_by_name(name)
+            g = dataset[0]
+            type_pointers = prepare_hetero_graph_simplified(g)
+            g = prepare_graph(dgl.to_homogeneous(g))
+            g.ntype_pointer = type_pointers['ntype_node_pointer']
+            g.etype_pointer = type_pointers['etype_edge_pointer']
+            g.num_ntypes = max(g.ndata[dgl.NTYPE]).item() + 1
+            g.num_rels = max(g.edata[dgl.ETYPE]).item() + 1
+            test_rgcn(g, feat_size)