
[Unittest] Correctness test for benchmarking scripts #49

Merged · 1 commit · Jan 21, 2022
110 changes: 57 additions & 53 deletions tests/python/sparsetir/bench_rgcn.py
@@ -1,3 +1,5 @@
+from sys import base_exec_prefix
+from termios import ECHOE
from dgl.heterograph import DGLHeteroGraph
import tvm
import tvm.testing
@@ -13,6 +15,19 @@
from sparse_tir_scripts import rgcn_forward


+def get_dataset_by_name(name: str):
+    if name == 'aifb':
+        return AIFBDataset()
+    elif name == 'mutag':
+        return MUTAGDataset()
+    elif name == 'bgs':
+        return BGSDataset()
+    elif name == 'am':
+        return AMDataset()
+    else:
+        raise KeyError("Unknown dataset {}.".format(name))


class TorchOpTimer(object):
    def __enter__(self):
        self.start_event = th.cuda.Event(enable_timing=True)
@@ -23,10 +38,10 @@ def __enter__(self):
    def __exit__(self, type, value, traceback):
        self.end_event.record()
        th.cuda.synchronize() # Wait for the events to be recorded!
-        self.time = self.start_event.elapsed_time(self.end_event) / 1e3
+        self.time = self.start_event.elapsed_time(self.end_event)


-def prepare_hetero_graph_simplified(g):
+def prepare_hetero_graph_simplified(g: dgl.DGLHeteroGraph):
    ntype_pointer = np.cumsum(
        [0] + [g.number_of_nodes(ntype) for ntype in g.ntypes])

@@ -38,7 +53,7 @@ def prepare_hetero_graph_simplified(g):
    return {"ntype_node_pointer": th.IntTensor(ntype_pointer).cuda(), "etype_edge_pointer": th.IntTensor(etype_pointer).cuda()}


-def prepare_graph(g, ntype=None):
+def prepare_graph(g: dgl.DGLHeteroGraph, ntype=None):
    # Todo: long int kernel support, multiple devices
    g.node_id = g.nodes(ntype).type(th.IntTensor).cuda()

@@ -65,16 +80,15 @@ def prepare_graph(g, ntype=None):
    return g


-def test_rgcn(g: DGLHeteroGraph):
-    feat_size = 16
+def test_rgcn(g: DGLHeteroGraph, feat_size: int):
    g = g.to(0)
-    feat = th.rand(g.num_src_nodes(), feat_size).to(0)
-    y = th.zeros(g.num_dst_nodes(), feat_size).to(0)
+    feat = th.rand(g.num_src_nodes(), feat_size).to(0) / 100
+    out = th.zeros(g.num_dst_nodes(), feat_size).to(0) / 100
    weight = th.rand(g.num_rels, feat_size, feat_size).to(0)
    indptr, indices, eid = g.adj_sparse(fmt='csc')
    etype = g.edata[dgl.ETYPE][eid]

-    # naive
+    # dgl-bmm

    cold_start = 3
    total = 10
@@ -85,38 +99,24 @@ def msg_func(edges):
        W = weight[edges.data[dgl.ETYPE]]
        return {'msg': W @ h}

-    for epoch in range(10):
-        with TorchOpTimer() as timer:
-            g.srcdata['feat'] = feat.unsqueeze(-1)
-            g.update_all(
-                msg_func,
-                fn.sum('msg', 'y')
-            )
-            y = g.dstdata['y'].sum(1)
-        if epoch >= cold_start:
-            accum += timer.time
-
-    print("dgl low-mem:\t\t", accum / (total - cold_start))
-
-    # broadcast
-
-    cold_start = 3
-    total = 10
-    accum = 0
-
-    for epoch in range(10):
-        with TorchOpTimer() as timer:
-            g.srcdata['feat'] = feat.unsqueeze(-1)
-            g.edata['W'] = weight[g.edata[dgl.ETYPE]]
-            g.update_all(
-                fn.u_mul_e('feat', 'W', 'msg'),
-                fn.sum('msg', 'y')
-            )
-            y = g.dstdata['y'].sum(1)
-        if epoch >= cold_start:
-            accum += timer.time
-
-    print("dgl high-mem:\t\t", accum / (total - cold_start))
+    try:
+        for epoch in range(10):
+            with TorchOpTimer() as timer:
+                g.srcdata['feat'] = feat.unsqueeze(-1)
+                g.update_all(
+                    msg_func,
+                    fn.sum('msg', 'y')
+                )
+                y_dgl = g.dstdata['y'].squeeze(-1)
+            if epoch >= cold_start:
+                accum += timer.time
+        print("dgl-bmm:\t\t {}ms".format(accum / (total - cold_start)))
+    except RuntimeError as err:
+        print("dgl-bmm: OOM")
+        y_dgl = None
+    except BaseException as err:
+        print(err)
+        raise

    # tir
    mod = tvm.IRModule.from_expr(rgcn_forward)
@@ -137,14 +137,12 @@ def msg_func(edges):
    sch.bind(i, "blockIdx.x")
    sch.bind(f_out, "threadIdx.y")
    sch.bind(f_in, "threadIdx.x")
-    # sch.cache_read(inner, 4, "shared")
    f = tvm.build(sch.mod, target='cuda')
-    # print(f.imported_modules[0].get_source())

    E = tvm.nd.array(etype.cpu().numpy().astype("int32"), device=tvm.cuda(0))
    W = tvm.nd.array(weight.view(-1).cpu().numpy().astype("float32"), device=tvm.cuda(0))
    X = tvm.nd.array(feat.view(-1).cpu().numpy().astype("float32"), device=tvm.cuda(0))
-    Y = tvm.nd.array(y.view(-1).cpu().numpy().astype("float32"), device=tvm.cuda(0))
+    Y = tvm.nd.array(out.view(-1).cpu().numpy().astype("float32"), device=tvm.cuda(0))
    indptr = tvm.nd.array(indptr.cpu().numpy().astype("int32"), device=tvm.cuda(0))
    indices = tvm.nd.array(indices.cpu().numpy().astype("int32"), device=tvm.cuda(0))

@@ -158,16 +156,22 @@ def msg_func(edges):
        if epoch >= cold_start:
            accum += timer.time

-    print("sparse-tir:\t\t", accum / (total - cold_start))
+    print("sparse-tir:\t\t {}ms".format(accum / (total - cold_start)))

+    if y_dgl is not None:
+        tvm.testing.assert_allclose(y_dgl.view(-1).cpu().numpy(), Y.numpy(), rtol=1e-5)


if __name__ == "__main__":
-    dataset = AIFBDataset()
-    g = dataset[0]
-    type_pointers = prepare_hetero_graph_simplified(g)
-    g = prepare_graph(dgl.to_homogeneous(g))
-    g.ntype_pointer = type_pointers['ntype_node_pointer']
-    g.etype_pointer = type_pointers['etype_edge_pointer']
-    g.num_ntypes = max(g.ndata[dgl.NTYPE]).item() + 1
-    g.num_rels = max(g.edata[dgl.ETYPE]).item() + 1
-    test_rgcn(g)
+    for feat_size in [4, 8, 16, 32, 64]:
+        for name in ['aifb', 'mutag', 'bgs', 'am']:
+            print('dataset {}:'.format(name))
+            dataset = get_dataset_by_name(name)
+            g = dataset[0]
+            type_pointers = prepare_hetero_graph_simplified(g)
+            g = prepare_graph(dgl.to_homogeneous(g))
+            g.ntype_pointer = type_pointers['ntype_node_pointer']
+            g.etype_pointer = type_pointers['etype_edge_pointer']
+            g.num_ntypes = max(g.ndata[dgl.NTYPE]).item() + 1
+            g.num_rels = max(g.edata[dgl.ETYPE]).item() + 1
+            test_rgcn(g, feat_size)
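
For reference, the pattern this PR adds to the benchmark (time a GPU kernel with CUDA events, keep the candidate output instead of discarding it, then compare against a reference) can be reproduced in isolation. The sketch below is not part of the diff: it is a minimal illustration assuming only PyTorch and a CUDA device, where the bmm-based candidate, the per-relation loop reference, and all sizes (num_edges, num_rels) are made-up stand-ins for the DGL and SparseTIR RGCN kernels.

import torch as th


def bench_and_check(num_edges=1000, num_rels=4, feat_size=16, rtol=1e-5):
    dev = th.device("cuda:0")
    # toy inputs standing in for the RGCN edge types, relation weights, and features
    etype = th.randint(num_rels, (num_edges,), device=dev)
    weight = th.rand(num_rels, feat_size, feat_size, device=dev) / 100
    feat = th.rand(num_edges, feat_size, device=dev) / 100

    # time the candidate with CUDA events, the same way TorchOpTimer does
    start = th.cuda.Event(enable_timing=True)
    end = th.cuda.Event(enable_timing=True)
    start.record()
    out = th.bmm(weight[etype], feat.unsqueeze(-1)).squeeze(-1)  # per-edge W_r @ h
    end.record()
    th.cuda.synchronize()  # wait for both events before reading elapsed_time
    print("bmm:\t\t {}ms".format(start.elapsed_time(end)))

    # reference: explicit loop over relations
    ref = th.zeros_like(out)
    for r in range(num_rels):
        mask = etype == r
        ref[mask] = feat[mask] @ weight[r].T
    # correctness check, analogous to the assert_allclose added in test_rgcn
    assert th.allclose(out, ref, rtol=rtol, atol=1e-6)


if __name__ == "__main__":
    if th.cuda.is_available():
        bench_and_check()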