Add is_floating_point() test and better type support in `verify_model_vm()` (apache#7134)

* Add div_ and is_floating_point operators

* Add handling of exprs to op, update tests

* add test + supporting functions

* Revert whitespace changes

* Properly assign dtype to random integers

* Reformat with black

* Switched default dtype logic, removed extra line
TylerADavis authored and Tushar Dey committed Jan 20, 2021
1 parent 3ce88a6 commit 80e92f0
Showing 1 changed file with 80 additions and 5 deletions.
85 changes: 80 additions & 5 deletions tests/python/frontend/pytorch/test_forward.py
@@ -1889,9 +1889,10 @@ def _get_default_vm_targets():
     return [tgt for (tgt, _) in tvm.testing.enabled_targets()]


-def verify_script_model(pt_model, ishapes, targets):
+def verify_script_model(pt_model, ishapes, targets, idtype=None):
     script_module = torch.jit.script(pt_model)
-    verify_model_vm(script_module, ishapes, targets=targets)
+
+    verify_model_vm(script_module, ishapes, idtype=idtype, targets=targets)


 def verify_trace_model(pt_model, idata, targets):
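
Editorial aside (not part of the commit): the hunk above adds an optional `idtype` keyword, so callers can pin the input dtype instead of always getting `torch.float` inputs. A minimal sketch of the new call pattern, assuming it runs inside tests/python/frontend/pytorch/test_forward.py where these helpers are defined; the `Identity` module is a hypothetical toy model invented for this note:

    import torch
    from torch.nn import Module

    class Identity(Module):  # hypothetical model, for illustration only
        def forward(self, arg):
            return arg

    # Inputs are now generated as int32 rather than the former float-only default:
    verify_script_model(Identity(), [(2, 3)], ["llvm"], idtype=torch.int32)
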
@@ -1900,10 +1901,60 @@ def verify_trace_model(pt_model, idata, targets):
     verify_model_vm(traced_model, ishapes, idata=idata, targets=targets)


-def verify_model_vm(input_model, ishapes, idtype=torch.float, idata=None, targets=["llvm"]):
+def convert_pt_to_tvm_type(idtype):
+    """Accepts a PyTorch dtype and returns the corresponding TVM dtype string."""
+    # TVM does not support PyTorch complex dtypes
+    if idtype == torch.float64:
+        curr_dtype = "float64"
+    elif idtype == torch.float32:
+        curr_dtype = "float32"
+    elif idtype == torch.float16:
+        curr_dtype = "float16"
+    elif idtype == torch.bfloat16:
+        curr_dtype = "bfloat16"
+    elif idtype == torch.int64:
+        curr_dtype = "int64"
+    elif idtype == torch.int32:
+        curr_dtype = "int32"
+    elif idtype == torch.int16:
+        curr_dtype = "int16"
+    elif idtype == torch.int8:
+        curr_dtype = "int8"
+    elif idtype == torch.uint8:
+        curr_dtype = "uint8"
+    elif idtype == torch.bool:
+        curr_dtype = "bool"
+    else:
+        raise NotImplementedError("Unsupported dtype: {}".format(idtype))
+    return curr_dtype
+
+
+def verify_model_vm(input_model, ishapes, idtype=None, idata=None, targets=["llvm"]):
+    if not idtype:
+        idtype = torch.float
+
     input_names = ["i{}".format(idx) for idx, ish in enumerate(ishapes)]
-    input_shapes = list(zip(input_names, ishapes))
-    input_data = idata if idata else [torch.randn(shape, dtype=idtype) for shape in ishapes]
+    tvm_dtype = convert_pt_to_tvm_type(idtype)
+    input_dtypes = [tvm_dtype] * len(input_names)
+    input_shapes = list(zip(input_names, list(zip(ishapes, input_dtypes))))
+
+    if idata:
+        input_data = idata
+    # If no input data is provided, generate random data of the specified dtype
+    else:
+        if idtype == torch.bool:
+            input_data = [
+                torch.Tensor.bool(torch.randint(low=0, high=2, size=shape)) for shape in ishapes
+            ]
+        # A PyTorch dtype can be float, complex, int, or bool. Complex is not
+        # supported, so a dtype that is neither float nor bool must be int.
+        elif not idtype.is_floating_point:
+            input_data = [
+                torch.randint(low=0, high=10, size=shape, dtype=idtype) for shape in ishapes
+            ]
+        else:
+            input_data = [torch.randn(shape, dtype=idtype) for shape in ishapes]

     # Compile via VM
     mod, params = relay.frontend.from_pytorch(input_model, input_shapes)

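Editorial aside: the if/elif ladder in `convert_pt_to_tvm_type()` above is faithful to the commit. For comparison, the same mapping can be expressed as a dict lookup; `convert_pt_to_tvm_type_alt` and `_PT_TO_TVM_DTYPE` below are names invented for this sketch, not part of the commit:

    import torch

    # PyTorch -> TVM dtype strings, mirroring the branches above.
    _PT_TO_TVM_DTYPE = {
        torch.float64: "float64",
        torch.float32: "float32",
        torch.float16: "float16",
        torch.bfloat16: "bfloat16",
        torch.int64: "int64",
        torch.int32: "int32",
        torch.int16: "int16",
        torch.int8: "int8",
        torch.uint8: "uint8",
        torch.bool: "bool",
    }

    def convert_pt_to_tvm_type_alt(idtype):
        """Dict-based equivalent of convert_pt_to_tvm_type()."""
        try:
            return _PT_TO_TVM_DTYPE[idtype]
        except KeyError:
            raise NotImplementedError("Unsupported dtype: {}".format(idtype))

A dict keeps the mapping declarative and turns an unsupported dtype into a simple key miss; the commit's explicit branches behave identically.
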
@@ -2950,6 +3001,29 @@ def forward(self, *args):
     )


+@tvm.testing.uses_gpu
+def test_forward_is_floating_point():
+    torch.set_grad_enabled(False)
+
+    class IsFloatingPoint(Module):
+        def forward(self, arg):
+            # `torch.jit.trace` cannot accept something that outputs
+            # a Bool, so `torch.jit.script` will be used instead
+            return torch.is_floating_point(arg)
+
+    targets = _get_default_vm_targets()
+    verify_script_model(IsFloatingPoint(), [(1, 1)], targets, idtype=torch.float64)
+    verify_script_model(IsFloatingPoint(), [(1, 1)], targets, idtype=torch.float32)
+    verify_script_model(IsFloatingPoint(), [(1, 1)], targets, idtype=torch.float16)
+    # todo(dvisnty): Run the test for bfloat16 when full bfloat16 support is implemented
+    # verify_script_model(IsFloatingPoint(), [(1,1)], targets, idtype=torch.bfloat16)
+    verify_script_model(IsFloatingPoint(), [(1, 1)], targets, idtype=torch.int64)
+    verify_script_model(IsFloatingPoint(), [(1, 1)], targets, idtype=torch.int32)
+    verify_script_model(IsFloatingPoint(), [(1, 1)], targets, idtype=torch.int16)
+    verify_script_model(IsFloatingPoint(), [(1, 1)], targets, idtype=torch.int8)
+    verify_script_model(IsFloatingPoint(), [(1, 1)], targets, idtype=torch.uint8)
+
+
 @tvm.testing.uses_gpu
 def test_forward_traced_function():
     def fn(t1, t2):
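
The reference behavior the new test compares against can be confirmed in plain PyTorch, independent of TVM. A standalone sanity sketch, using only the dtypes exercised above:

    import torch

    # torch.is_floating_point() is True exactly for floating-point tensors.
    for dt in (torch.float64, torch.float32, torch.float16):
        assert torch.is_floating_point(torch.zeros((1, 1), dtype=dt))
    for dt in (torch.int64, torch.int32, torch.int16, torch.int8, torch.uint8):
        assert not torch.is_floating_point(torch.zeros((1, 1), dtype=dt))
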
@@ -3425,6 +3499,7 @@ def test_fn(x, weights=None):
     test_forward_addcdiv()
     test_forward_addcmul()
     test_forward_true_divide()
+    test_forward_is_floating_point()
     test_forward_clone()
     test_forward_softplus()
     test_forward_softsign()
