From c263f22ffdd6cf85401dbc128a17710f3e36b768 Mon Sep 17 00:00:00 2001 From: Y Date: Fri, 16 Jul 2021 14:08:58 +0800 Subject: [PATCH] [TVMC][FIX] Compiler supports input with a slash (#8481) --- python/tvm/driver/tvmc/common.py | 2 +- tests/python/driver/tvmc/test_tvmc_common.py | 21 ++++++++++++++++++++ 2 files changed, 22 insertions(+), 1 deletion(-) diff --git a/python/tvm/driver/tvmc/common.py b/python/tvm/driver/tvmc/common.py index d3a62b508135..15c09753d46f 100644 --- a/python/tvm/driver/tvmc/common.py +++ b/python/tvm/driver/tvmc/common.py @@ -396,7 +396,7 @@ def parse_shape_string(inputs_string): """ # Create a regex pattern that extracts each separate input mapping. - pattern = r"\w+\:\s*\[\-?\d+(?:\,\s*\-?\d+)*\]" + pattern = r"(?:\w+\/)?\w+\:\s*\[\-?\d+(?:\,\s*\-?\d+)*\]" input_mappings = re.findall(pattern, inputs_string) if not input_mappings: raise argparse.ArgumentTypeError( diff --git a/tests/python/driver/tvmc/test_tvmc_common.py b/tests/python/driver/tvmc/test_tvmc_common.py index cb6b82a32937..31fa688ad717 100644 --- a/tests/python/driver/tvmc/test_tvmc_common.py +++ b/tests/python/driver/tvmc/test_tvmc_common.py @@ -177,6 +177,11 @@ def test_shape_parser(): shape_dict = tvmc.common.parse_shape_string(shape_string) # Convert to strings to allow comparison with Any. assert str(shape_dict) == "{'input': [?, 3, 224, 224]}" + # Check that multiple valid gpu inputs are parsed correctly. + shape_string = "gpu_0/data_0:[1, -1,224,224] gpu_1/data_1:[7, 7]" + shape_dict = tvmc.common.parse_shape_string(shape_string) + expected = "{'gpu_0/data_0': [1, ?, 224, 224], 'gpu_1/data_1': [7, 7]}" + assert str(shape_dict) == expected # Check that invalid pattern raises expected error. shape_string = "input:[a,10]" @@ -186,6 +191,22 @@ def test_shape_parser(): shape_string = "input:5,10 input2:10,10" with pytest.raises(argparse.ArgumentTypeError): tvmc.common.parse_shape_string(shape_string) + # Check that input with a invalid slash raises error. + shape_string = "gpu_0/data_0:5,10 /:10,10" + with pytest.raises(argparse.ArgumentTypeError): + tvmc.common.parse_shape_string(shape_string) + # Check that input with a invalid slash raises error. + shape_string = "gpu_0/data_0:5,10 data/:10,10" + with pytest.raises(argparse.ArgumentTypeError): + tvmc.common.parse_shape_string(shape_string) + # Check that input with a invalid slash raises error. + shape_string = "gpu_0/data_0:5,10 /data:10,10" + with pytest.raises(argparse.ArgumentTypeError): + tvmc.common.parse_shape_string(shape_string) + # Check that input with invalid slashes raises error. + shape_string = "gpu_0/invalid/data_0:5,10 data_1:10,10" + with pytest.raises(argparse.ArgumentTypeError): + tvmc.common.parse_shape_string(shape_string) def test_target_from_cli__error_duplicate():