Skip to content

Commit

Permalink
[TVMC][FIX] Compiler supports input with a slash (#8481)
Browse files Browse the repository at this point in the history
  • Loading branch information
leeexyz authored Jul 16, 2021
1 parent 5bb01ef commit c263f22
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 1 deletion.
2 changes: 1 addition & 1 deletion python/tvm/driver/tvmc/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -396,7 +396,7 @@ def parse_shape_string(inputs_string):
"""

# Create a regex pattern that extracts each separate input mapping.
pattern = r"\w+\:\s*\[\-?\d+(?:\,\s*\-?\d+)*\]"
pattern = r"(?:\w+\/)?\w+\:\s*\[\-?\d+(?:\,\s*\-?\d+)*\]"
input_mappings = re.findall(pattern, inputs_string)
if not input_mappings:
raise argparse.ArgumentTypeError(
Expand Down
21 changes: 21 additions & 0 deletions tests/python/driver/tvmc/test_tvmc_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,11 @@ def test_shape_parser():
shape_dict = tvmc.common.parse_shape_string(shape_string)
# Convert to strings to allow comparison with Any.
assert str(shape_dict) == "{'input': [?, 3, 224, 224]}"
# Check that multiple valid gpu inputs are parsed correctly.
shape_string = "gpu_0/data_0:[1, -1,224,224] gpu_1/data_1:[7, 7]"
shape_dict = tvmc.common.parse_shape_string(shape_string)
expected = "{'gpu_0/data_0': [1, ?, 224, 224], 'gpu_1/data_1': [7, 7]}"
assert str(shape_dict) == expected

# Check that invalid pattern raises expected error.
shape_string = "input:[a,10]"
Expand All @@ -186,6 +191,22 @@ def test_shape_parser():
shape_string = "input:5,10 input2:10,10"
with pytest.raises(argparse.ArgumentTypeError):
tvmc.common.parse_shape_string(shape_string)
# Check that input with a invalid slash raises error.
shape_string = "gpu_0/data_0:5,10 /:10,10"
with pytest.raises(argparse.ArgumentTypeError):
tvmc.common.parse_shape_string(shape_string)
# Check that input with a invalid slash raises error.
shape_string = "gpu_0/data_0:5,10 data/:10,10"
with pytest.raises(argparse.ArgumentTypeError):
tvmc.common.parse_shape_string(shape_string)
# Check that input with a invalid slash raises error.
shape_string = "gpu_0/data_0:5,10 /data:10,10"
with pytest.raises(argparse.ArgumentTypeError):
tvmc.common.parse_shape_string(shape_string)
# Check that input with invalid slashes raises error.
shape_string = "gpu_0/invalid/data_0:5,10 data_1:10,10"
with pytest.raises(argparse.ArgumentTypeError):
tvmc.common.parse_shape_string(shape_string)


def test_target_from_cli__error_duplicate():
Expand Down

0 comments on commit c263f22

Please sign in to comment.