Skip to content

Commit

Permalink
[tvmc] Introduce 'run' subcommand (part 4/4) (#6578)
Browse files Browse the repository at this point in the history
* [tvmc] Introduce 'run' subcommand (part 4/4)

 * Add 'tvmc run' subcommand to execute compiled modules
 * Include options to locally or remotelly using RPC
 * Include support to cpu and gpu devices


Co-authored-by: Marcus Shawcroft <[email protected]>
Co-authored-by: Matthew Barrett <[email protected]>

* adjust based on code review comments

* make test fixture to safely skip environments without tflite

* make --help option more clear

* improve error message to show expected inputs

* code-review adjusts

* update doc-string to default zeros->random

Co-authored-by: Marcus Shawcroft <[email protected]>
Co-authored-by: Matthew Barrett <[email protected]>
  • Loading branch information
3 people authored Oct 2, 2020
1 parent b754bec commit 5db80f0
Show file tree
Hide file tree
Showing 7 changed files with 667 additions and 7 deletions.
1 change: 1 addition & 0 deletions python/tvm/driver/tvmc/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,4 @@

from . import autotuner
from . import compiler
from . import runner
35 changes: 35 additions & 0 deletions python/tvm/driver/tvmc/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
import logging
import os.path

from urllib.parse import urlparse

import tvm

from tvm import relay
Expand Down Expand Up @@ -102,3 +104,36 @@ def target_from_cli(target):
logger.debug("creating target from input: %s", target)

return tvm.target.Target(target)


def tracker_host_port_from_cli(rpc_tracker_str):
"""Extract hostname and (optional) port from strings
like "1.2.3.4:9090" or "4.3.2.1".
Used as a helper function to cover --rpc-tracker
command line argument, in different subcommands.
Parameters
----------
rpc_tracker_str : str
hostname (or IP address) and port of the RPC tracker,
in the format 'hostname[:port]'.
Returns
-------
rpc_hostname : str or None
hostname or IP address, extracted from input.
rpc_port : int or None
port number extracted from input (9090 default).
"""

rpc_hostname = rpc_port = None

if rpc_tracker_str:
parsed_url = urlparse("//%s" % rpc_tracker_str)
rpc_hostname = parsed_url.hostname
rpc_port = parsed_url.port or 9090
logger.info("RPC tracker hostname: %s", rpc_hostname)
logger.info("RPC tracker port: %s", rpc_port)

return rpc_hostname, rpc_port
4 changes: 2 additions & 2 deletions python/tvm/driver/tvmc/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,8 +190,6 @@ def compile_model(
target_host = target_host or ""

if tuning_records and os.path.exists(tuning_records):
# TODO (@leandron) a new PR will introduce the 'tune' subcommand
# the is used to generate the tuning records file
logger.debug("tuning records file provided: %s", tuning_records)
with autotvm.apply_history_best(tuning_records):
with tvm.transform.PassContext(opt_level=3):
Expand All @@ -212,6 +210,8 @@ def compile_model(
source = str(mod) if source_type == "relay" else lib.get_source(source_type)
dumps[source_type] = source

# TODO we need to update this return to use the updated graph module APIs
# as these getter functions will be deprecated in the next release (@leandron)
return graph_module.get_json(), graph_module.get_lib(), graph_module.get_params(), dumps


Expand Down
Loading

0 comments on commit 5db80f0

Please sign in to comment.