Skip to content

Commit

Permalink
Make remote db creation async op
Browse files Browse the repository at this point in the history
  • Loading branch information
ska278 committed May 15, 2024
1 parent c15619e commit 5772fcc
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 13 deletions.
3 changes: 2 additions & 1 deletion python/dgl/_sparse_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,8 @@ def _gspmm(gidx, op, reduce_op, u, e):
v_shp = (gidx.num_nodes(dsttype),) + infer_broadcast_shape(
op, u_shp[1:], e_shp[1:]
)
v = F.zeros(v_shp, dtype, ctx)
#v = F.zeros(v_shp, dtype, ctx)
v = F.empty(v_shp, dtype, ctx)
use_cmp = reduce_op in ["max", "min"]
arg_u, arg_e = None, None
idtype = getattr(F, gidx.dtype)
Expand Down
1 change: 1 addition & 0 deletions python/dgl/backend/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ def data_type_dict():
"""Returns a dictionary from data type string to the data type.
The dictionary should include at least:
bfloat8
bfloat16
float16
float32
Expand Down
4 changes: 3 additions & 1 deletion python/dgl/backend/pytorch/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import scipy # Weird bug in new pytorch when import scipy after import torch
import torch as th
from torch.utils import dlpack

from ... import ndarray as nd
from ...function.base import TargetCode
from ...utils import version
Expand All @@ -18,6 +17,7 @@

def data_type_dict():
return {
"bfloat8": th.float8_e5m2,
"bfloat16": th.bfloat16,
"float16": th.float16,
"float32": th.float32,
Expand Down Expand Up @@ -432,6 +432,8 @@ def zerocopy_from_numpy(np_array):
def zerocopy_to_dgl_ndarray(data):
    """Convert a torch tensor into a DGL NDArray through DLPack.

    Parameters
    ----------
    data : torch.Tensor
        Tensor to export. Boolean tensors are first converted with
        ``Tensor.byte()``; every other dtype (including
        ``th.float8_e5m2``) is exported unchanged.

    Returns
    -------
    The object produced by ``nd.from_dlpack`` for the exported capsule.
    """
    if data.dtype == th.bool:
        # NOTE(review): presumably bool is not accepted on the DLPack path,
        # hence the cast to uint8 -- confirm against the DGL NDArray dtypes.
        data = data.byte()
    # float8_e5m2 previously had its own branch that returned this exact
    # expression; the single return below covers it with identical behavior.
    # `.contiguous()` is applied before export, as in the original code.
    return nd.from_dlpack(dlpack.to_dlpack(data.contiguous()))


Expand Down
15 changes: 4 additions & 11 deletions python/dgl/distgnn/iels.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,21 +120,14 @@ def iels_init(self, pb, args):
print("sn_onid size: ", self.part_sn_onid.size())
print("sn_gnid size: ", self.part_sn_gnid.size())

if debug:
rr = self.create_remote_sn_db_commence()
assert th.equal(self.onid_map, pb.onid_map) == True
assert th.equal(self.pid_map, pb.pid_map) == True
else:
assert pb.onid_map.shape[0] == self.N
assert pb.pid_map.shape[0] == self.N
self.onid_map = pb.onid_map
self.pid_map = pb.pid_map
rr = self.create_remote_sn_db_commence()
assert th.equal(self.onid_map, pb.onid_map) == True
assert th.equal(self.pid_map, pb.pid_map) == True

## local feats db - orig node id to index in the feat table
self.create_local_sn_db(self.part_sn_onid, self.part_sn_gnid) ## at partition level
## database to know in which partition the remote nodes are residing
if debug:
self.create_remote_sn_db(rr)
self.create_remote_sn_db(rr)

## remote embedding buffering data structures
self.enable_iec = args.enable_iec
Expand Down

0 comments on commit 5772fcc

Please sign in to comment.