Skip to content

Commit

Permalink
[PYTORCH]Tensor creation ops support (apache#5347)
Browse files Browse the repository at this point in the history
  • Loading branch information
siju-samuel authored and dpankratz committed Apr 24, 2020
1 parent db65d9b commit fa429e3
Show file tree
Hide file tree
Showing 2 changed files with 247 additions and 8 deletions.
110 changes: 102 additions & 8 deletions python/tvm/relay/frontend/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,12 +348,25 @@ def _impl(inputs, input_types):
msg = "Data type %s could not be parsed in ones op" % (type(data))
raise AssertionError(msg)

dtype_map = {6: "float32", 3: "int32"}
dtype_id = inputs[1]
assert dtype_id in dtype_map, "Unsupported dtype %d" % dtype_id
return _op.full(_expr.const(1), shape, dtype=dtype_map[dtype_id])
dtype = _convert_data_type(_convert_dtype_value(inputs[1]))

return _op.full(_expr.const(1), shape, dtype=dtype)
return _impl

def _ones_like():
def _impl(inputs, input_types):
data = inputs[0]
out = _op.ones_like(data)

# If the input and the output datatype is different, do a cast
dtype = _convert_data_type(_convert_dtype_value(inputs[1]))
if input_types[0] not in dtype:
out = _op.cast(out, dtype)

return out
return _impl


def _zeros():
def _impl(inputs, input_types):
data = inputs[0]
Expand All @@ -369,12 +382,88 @@ def _impl(inputs, input_types):
msg = "Data type %s could not be parsed in zeros op" % (type(data))
raise AssertionError(msg)

dtype_map = {6: "float32", 3: "int32"}
dtype_id = inputs[1]
assert dtype_id in dtype_map, "Unsupported dtype %d" % dtype_id
return _op.full(_expr.const(0), shape, dtype=dtype_map[dtype_id])
dtype = _convert_data_type(_convert_dtype_value(inputs[1]))

return _op.full(_expr.const(0), shape, dtype=dtype)
return _impl


def _zeros_like():
def _impl(inputs, input_types):
data = inputs[0]
out = _op.zeros_like(data)

# If the input and the output datatype is different, do a cast
dtype = _convert_data_type(_convert_dtype_value(inputs[1]))
if input_types[0] not in dtype:
out = _op.cast(out, dtype)

return out
return _impl


def _full():
def _impl(inputs, input_types):
data = inputs[0]

fill_value = inputs[1]
import torch
if isinstance(data, _expr.Expr):
shape = _infer_shape(data)
elif isinstance(data, list):
shape = data
elif isinstance(data, (torch.Tensor, np.ndarray)):
shape = data.shape
else:
msg = "Data type %s could not be parsed in zeros op" % (type(data))
raise AssertionError(msg)

dtype = _convert_data_type(_convert_dtype_value(inputs[2]))

return _op.full(_expr.const(fill_value), shape, dtype=dtype)
return _impl

def _full_like():
def _impl(inputs, input_types):
data = inputs[0]
fill_value = inputs[1]

out = _op.full_like(data, _expr.const(fill_value))

# If the input and the output datatype is different, do a cast
dtype = _convert_data_type(_convert_dtype_value(inputs[2]))
if input_types[0] not in dtype:
out = _op.cast(out, dtype)

return out
return _impl


def _linspace():
def _impl(inputs, input_types):
start = inputs[0]
stop = inputs[1]
step = inputs[2]

# Find the spacing between values as step
if step != 1:
step = (stop - start) / (step - 1)
stop = stop + step
else:
stop = start + step

dtype = "float" if "float" in input_types[0:3] else _convert_dtype_value(inputs[3])
start = _create_typed_const(start, dtype)
stop = _create_typed_const(stop, dtype)
step = _create_typed_const(step, dtype)

return _op.transform.arange(start=start,
stop=stop,
step=step,
dtype=_convert_data_type(dtype))
return _impl


def _relu():
def _impl(inputs, input_types):
data = inputs[0]
Expand Down Expand Up @@ -1497,7 +1586,12 @@ def _get_convert_map(prelude):
"aten::div" : _elemwise("divide"),
"aten::div_" : _elemwise("divide"),
"aten::ones" : _ones(),
"aten::ones_like" : _ones_like(),
"aten::zeros" : _zeros(),
"aten::zeros_like" : _zeros_like(),
"aten::full" : _full(),
"aten::full_like" : _full_like(),
"aten::linspace" : _linspace(),
"aten::reciprocal" : _reciprocal(),
"aten::repeat" : _repeat(),
"aten::repeat_interleave" : _repeat_interleave(),
Expand Down
145 changes: 145 additions & 0 deletions tests/python/frontend/pytorch/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -1545,6 +1545,144 @@ def forward(self, *args):
verify_model(Round1().float().eval(), input_data=input_data)


def test_forward_ones():
torch.set_grad_enabled(False)

class Ones1(Module):
def forward(self, *args):
return torch.ones(2,3)

verify_model(Ones1().float().eval(), input_data=[])


def test_forward_ones_like():
torch.set_grad_enabled(False)
input_shape = [1, 3, 10, 10]

class OnesLike1(Module):
def forward(self, *args):
return torch.ones_like(args[0])

class OnesLike2(Module):
def forward(self, *args):
return torch.ones_like(args[0], dtype=torch.int8)

class OnesLike3(Module):
def forward(self, *args):
return torch.ones_like(args[0], dtype=torch.float)

input_data = torch.rand(input_shape).float()
verify_model(OnesLike1().float().eval(), input_data=input_data)
verify_model(OnesLike2().float().eval(), input_data=input_data)
verify_model(OnesLike3().float().eval(), input_data=input_data)


def test_forward_zeros():
torch.set_grad_enabled(False)

class Zeros1(Module):
def forward(self, *args):
return torch.zeros(2,3)

verify_model(Zeros1().float().eval(), input_data=[])


def test_forward_zeros_like():
torch.set_grad_enabled(False)
input_shape = [1, 3, 10, 10]

class ZerosLike1(Module):
def forward(self, *args):
return torch.zeros_like(args[0])

class ZerosLike2(Module):
def forward(self, *args):
return torch.zeros_like(args[0], dtype=torch.int32)

class ZerosLike3(Module):
def forward(self, *args):
return torch.zeros_like(args[0], dtype=torch.float)

input_data = torch.rand(input_shape).float()
verify_model(ZerosLike1().float().eval(), input_data=input_data)
verify_model(ZerosLike2().float().eval(), input_data=input_data)
verify_model(ZerosLike3().float().eval(), input_data=input_data)


def test_forward_full():
torch.set_grad_enabled(False)

class Full1(Module):
def forward(self, *args):
return torch.full((2,3), 3.14)

class Full2(Module):
def forward(self, *args):
return torch.full((1, 2,3), 1.0, dtype=torch.int32)

verify_model(Full1().float().eval(), input_data=[])
verify_model(Full2().float().eval(), input_data=[])


def test_forward_full_like():
torch.set_grad_enabled(False)
input_shape = [1, 3, 10, 10]

class FullLike1(Module):
def forward(self, *args):
return torch.full_like(args[0], 3.14)

class FullLike2(Module):
def forward(self, *args):
return torch.full_like(args[0], 22.22, dtype=torch.int32)

class FullLike3(Module):
def forward(self, *args):
return torch.full_like(args[0], 1.4, dtype=torch.float)

input_data = torch.rand(input_shape).float()
verify_model(FullLike1().float().eval(), input_data=input_data)
verify_model(FullLike2().float().eval(), input_data=input_data)
verify_model(FullLike3().float().eval(), input_data=input_data)

def test_forward_linspace():
torch.set_grad_enabled(False)

class Linspace1(Module):
def forward(self, *args):
return torch.linspace(5, 10)
class Linspace2(Module):
def forward(self, *args):
return torch.linspace(-10, 10, steps=5)
class Linspace3(Module):
def forward(self, *args):
return torch.linspace(start=-10, end=10, steps=5)
class Linspace4(Module):
def forward(self, *args):
return torch.linspace(start=-10, end=10, steps=1)
class Linspace5(Module):
def forward(self, *args):
return torch.linspace(1, 2, 1, dtype=torch.int32)
class Linspace6(Module):
def forward(self, *args):
return torch.linspace(start=1, end=6, steps=2)
class Linspace7(Module):
def forward(self, *args):
return torch.linspace(1, 4, dtype=torch.float32)
class Linspace8(Module):
def forward(self, *args):
return torch.linspace(1, 2, 1, dtype=torch.int16)

verify_model(Linspace1().float().eval())
verify_model(Linspace2().float().eval())
verify_model(Linspace3().float().eval())
verify_model(Linspace4().float().eval())
verify_model(Linspace5().float().eval())
verify_model(Linspace6().float().eval())
verify_model(Linspace7().float().eval())
verify_model(Linspace8().float().eval())


def test_forward_take():
torch.set_grad_enabled(False)
class Take1(Module):
Expand Down Expand Up @@ -1759,6 +1897,13 @@ def forward(self, *args):
test_forward_isfinite()
test_forward_isnan()
test_forward_isinf()
test_forward_ones()
test_forward_ones_like()
test_forward_zeros()
test_forward_zeros_like()
test_forward_full()
test_forward_full_like()
test_forward_linspace()
test_forward_arange()
test_forward_chunk()
test_forward_split()
Expand Down

0 comments on commit fa429e3

Please sign in to comment.