From 34d74282ec0adce60eda4298b82e411e3dd17543 Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Sun, 24 Sep 2017 14:17:06 -0700 Subject: [PATCH] [TUTORIAL] Move mobilenet to tutorial, fix precompute_prune (#35) * [TUTORIAL] Move mobilenet to tutorial, fix precompute_prune * Some language improvements --- nnvm/docs/.gitignore | 2 + nnvm/docs/README.txt | 4 + nnvm/docs/conf.py | 30 +++++- nnvm/docs/dev/index.rst | 4 +- nnvm/docs/index.rst | 1 + nnvm/docs/top.rst | 4 +- nnvm/example/mobilenet_inference_gpu.py | 117 -------------------- nnvm/examples/README.md | 5 + nnvm/python/nnvm/testing/__init__.py | 4 +- nnvm/python/nnvm/testing/config.py | 2 + nnvm/python/nnvm/testing/mobilenet.py | 125 ++++++++++++++++++++++ nnvm/src/compiler/precompute_prune.cc | 11 +- nnvm/tests/python/compiler/test_build.py | 12 ++- nnvm/tutorials/README.txt | 3 + nnvm/tutorials/mobilenet_inference_gpu.py | 82 ++++++++++++++ 15 files changed, 271 insertions(+), 135 deletions(-) create mode 100644 nnvm/docs/README.txt delete mode 100644 nnvm/example/mobilenet_inference_gpu.py create mode 100644 nnvm/examples/README.md create mode 100644 nnvm/python/nnvm/testing/mobilenet.py create mode 100644 nnvm/tutorials/README.txt create mode 100644 nnvm/tutorials/mobilenet_inference_gpu.py diff --git a/nnvm/docs/.gitignore b/nnvm/docs/.gitignore index 024fbfbe7bd0..d5d021127425 100644 --- a/nnvm/docs/.gitignore +++ b/nnvm/docs/.gitignore @@ -1,2 +1,4 @@ doxygen _build +gen_modules +tutorials diff --git a/nnvm/docs/README.txt b/nnvm/docs/README.txt new file mode 100644 index 000000000000..8b8c750822be --- /dev/null +++ b/nnvm/docs/README.txt @@ -0,0 +1,4 @@ +The documentation of nnvm is generated with recommonmark and sphinx. + +- pip install sphinx>=1.5.5 sphinx-gallery sphinx_rtd_theme matplotlib Image recommonmark +- Build tvm first in the root folder. diff --git a/nnvm/docs/conf.py b/nnvm/docs/conf.py index af089466e00a..5175167185cd 100644 --- a/nnvm/docs/conf.py +++ b/nnvm/docs/conf.py @@ -15,6 +15,7 @@ import os, subprocess import shlex import recommonmark +import sphinx_gallery from recommonmark.parser import CommonMarkParser from recommonmark.transform import AutoStructify @@ -50,7 +51,8 @@ 'sphinx.ext.autosummary', 'sphinx.ext.intersphinx', 'sphinx.ext.napoleon', - 'sphinx.ext.mathjax' + 'sphinx.ext.mathjax', + 'sphinx_gallery.gen_gallery', ] # Add any paths that contain templates here, relative to this directory. @@ -129,7 +131,7 @@ # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". -html_static_path = ['_static'] +# html_static_path = ['_static'] # Output file base name for HTML help builder. htmlhelp_basename = project + 'doc' @@ -164,9 +166,17 @@ def run_doxygen(folder): 'numpy': ('http://docs.scipy.org/doc/numpy/', None), 'scipy': ('http://docs.scipy.org/doc/scipy/reference', None), 'matplotlib': ('http://matplotlib.org/', None), + 'tvm': ('http://docs.tvmlang.org/', None), } +from sphinx_gallery.sorting import ExplicitOrder + +examples_dirs = ['../tutorials/'] +gallery_dirs = ['tutorials'] + +subsection_order = ExplicitOrder([]) + def generate_doxygen_xml(app): """Run the doxygen make commands if we're on the ReadTheDocs server""" run_doxygen('..') @@ -180,3 +190,19 @@ def setup(app): 'auto_doc_ref': True }, True) app.add_transform(AutoStructify) + + +sphinx_gallery_conf = { + 'backreferences_dir': 'gen_modules/backreferences', + 'doc_module': ('tvm', 'nnvm', 'numpy'), +'reference_url': { + 'nnvm': None, + 'tvm': 'http://docs.tvmlang.org', + 'numpy': 'http://docs.scipy.org/doc/numpy-1.9.1'}, + 'examples_dirs': examples_dirs, + 'gallery_dirs': gallery_dirs, + 'subsection_order': subsection_order, + 'find_mayavi_figures': False, + 'filename_pattern': '.py', + 'expected_failing_examples': [] +} diff --git a/nnvm/docs/dev/index.rst b/nnvm/docs/dev/index.rst index ecee6889d071..0647c9cce586 100644 --- a/nnvm/docs/dev/index.rst +++ b/nnvm/docs/dev/index.rst @@ -1,5 +1,5 @@ -NNVM Design Note -================ +Design Note +=========== In this part of documentation, we share the rationale for the specific choices made when designing NNVM. diff --git a/nnvm/docs/index.rst b/nnvm/docs/index.rst index 9011bacc9e5e..14db719029ea 100644 --- a/nnvm/docs/index.rst +++ b/nnvm/docs/index.rst @@ -10,4 +10,5 @@ Contents self top + tutorials/index dev/index diff --git a/nnvm/docs/top.rst b/nnvm/docs/top.rst index adbafdc99b6b..89af46509d02 100644 --- a/nnvm/docs/top.rst +++ b/nnvm/docs/top.rst @@ -1,5 +1,5 @@ -NNVM Core Tensor Operators -========================== +Core Tensor Operators +===================== This page contains the list of core tensor operator primitives re-defined in NNVM. The core tensor operator primitives(``nnvm.top``) covers typical workloads in deep learning. diff --git a/nnvm/example/mobilenet_inference_gpu.py b/nnvm/example/mobilenet_inference_gpu.py deleted file mode 100644 index 4331ec4c0f02..000000000000 --- a/nnvm/example/mobilenet_inference_gpu.py +++ /dev/null @@ -1,117 +0,0 @@ -"""Forward propagation of MobileNet on GPU.""" -import numpy as np -import time -import os - -import tvm -import topi -import nnvm.symbol as sym -import nnvm.compiler -import nnvm.runtime -from tvm.contrib import nvcc - -TASK="mobilenet" - -target = 'cuda' -ctx = tvm.gpu(0) - -@tvm.register_func -def tvm_callback_cuda_compile(code): - ptx = nvcc.compile_cuda(code, target="ptx", options=["-arch=sm_60"]) - return ptx - -def write_code(code, fname): - with open(fname, "w") as f: - f.write(code) - -@tvm.register_func -def tvm_callback_cuda_postproc(code): - if not os.path.exists("perf"): - os.mkdir("perf") - write_code(code, "perf/%s_generated.cu" % TASK) - return code - -dtype = 'float32' -epsilon = 1e-10 + 1e-5 - -def conv_block(data, name, channels, kernel_size=(3,3), strides=(1,1), padding=(1,1)): - # convolution + bn + relu - conv = sym.conv2d(data=data, channels=channels, kernel_size=kernel_size, strides=strides, - padding=padding, use_bias=False, layout='NCHW', name=name + '_conv') - bn = sym.batch_norm(data=conv, epsilon=epsilon, name=name + '_bn') - act = sym.relu(data=bn, name=name + '_relu') - return act - -def separable_conv_block(data, name, depthwise_channels, pointwise_channels, kernel_size=(3,3), downsample=False, padding=(1,1)): - if downsample: - strides = (2,2) - else: - strides = (1,1) - # depthwise convolution + bn + relu - conv1 = sym.conv2d(data=data, channels=depthwise_channels, groups=depthwise_channels, kernel_size=kernel_size, strides=strides, - padding=padding, use_bias=False, layout='NCHW', name=name + '_conv1') - bn1 = sym.batch_norm(data=conv1, epsilon=epsilon, name=name + '_bn1') - act1 = sym.relu(data=bn1, name=name + '_relu1') - # pointwise convolution + bn + relu - conv2 = sym.conv2d(data=act1, channels=pointwise_channels, kernel_size=(1,1), strides=(1,1), - padding=(0,0), use_bias=False, layout='NCHW', name=name + '_conv2') - bn2 = sym.batch_norm(data=conv2, epsilon=epsilon, name=name + '_bn2') - act2 = sym.relu(data=bn2, name=name + '_relu2') - return act2 - -def mobile_net(num_classes=1000, alpha=1.0, is_shallow=False): - data = sym.Variable("data") - body = conv_block(data, 'conv_block_1', int(32*alpha), strides=(2,2)) - body = separable_conv_block(body, 'separable_conv_block_1', int(32*alpha), int(64*alpha)) - body = separable_conv_block(body, 'separable_conv_block_2', int(64*alpha), int(128*alpha), downsample=True) - body = separable_conv_block(body, 'separable_conv_block_3', int(128*alpha), int(128*alpha)) - body = separable_conv_block(body, 'separable_conv_block_4', int(128*alpha), int(256*alpha), downsample=True) - body = separable_conv_block(body, 'separable_conv_block_5', int(256*alpha), int(256*alpha)) - body = separable_conv_block(body, 'separable_conv_block_6', int(256*alpha), int(512*alpha), downsample=True) - if is_shallow: - body = separable_conv_block(body, 'separable_conv_block_7', int(512*alpha), int(1024*alpha), downsample=True) - body = separable_conv_block(body, 'separable_conv_block_8', int(1024*alpha), int(1024*alpha)) - else: - for i in range(7, 12): - body = separable_conv_block(body, 'separable_conv_block_%d' % i, int(512*alpha), int(512*alpha)) - body = separable_conv_block(body, 'separable_conv_block_12', int(512*alpha), int(1024*alpha), downsample=True) - body = separable_conv_block(body, 'separable_conv_block_13', int(1024*alpha), int(1024*alpha)) - pool = sym.global_avg_pool2d(data=body, name='pool') - flatten = sym.flatten(data=pool, name='flatten') - fc = sym.dense(data=flatten, units=num_classes, use_bias=False, name='fc') - softmax = sym.softmax(data=fc, name='softmax') - return softmax - - -batch_size = 1 -num_classes = 1000 -image_shape = (3,224,224) -data_shape = (batch_size,) + image_shape -out_shape = (batch_size, num_classes) - -net = mobile_net(num_classes=num_classes, alpha=1.0, is_shallow=False) - -# build graph -with nnvm.compiler.build_config(opt_level=2): - graph, lib, _ = nnvm.compiler.build(net, target, {'data': data_shape}) -# prepare params -params = {} -names = graph.index.input_names -shapes = [graph.json_attr("shape")[graph.index.entry_id(x)] for x in names] -for i in range(len(names)): - params[names[i]] = tvm.nd.array(np.random.uniform(-0.1, 0.1, size=shapes[i]).astype(dtype), ctx=ctx) -# create runtime module -module = nnvm.runtime.create(graph, lib, ctx) -# set input -module.set_input(**params) -# run -print("run") -module.run() -ctx.sync() -start = time.time() -for i in range(1000): - module.run() - ctx.sync() -print("average time cost of 1000 runs = %g ms" % ((time.time() - start))) -# get output -out = module.get_output(0, tvm.nd.empty(out_shape, dtype)) diff --git a/nnvm/examples/README.md b/nnvm/examples/README.md new file mode 100644 index 000000000000..123007b552ea --- /dev/null +++ b/nnvm/examples/README.md @@ -0,0 +1,5 @@ +NNVM Examples +============= +This folder contains example snippets of running NNVM Compilation. + +- See also [Tutorials](tutorials) for tutorials with detailed explainations. diff --git a/nnvm/python/nnvm/testing/__init__.py b/nnvm/python/nnvm/testing/__init__.py index 6dd015d872ea..27aaad8de09f 100644 --- a/nnvm/python/nnvm/testing/__init__.py +++ b/nnvm/python/nnvm/testing/__init__.py @@ -1,3 +1,5 @@ -"""Utilities for testcase""" +"""Utilities for testing and benchmarks""" +from __future__ import absolute_import as _abs from .config import ctx_list +from . import mobilenet diff --git a/nnvm/python/nnvm/testing/config.py b/nnvm/python/nnvm/testing/config.py index a96e4b4ea8e1..0eab3e6b3389 100644 --- a/nnvm/python/nnvm/testing/config.py +++ b/nnvm/python/nnvm/testing/config.py @@ -1,4 +1,6 @@ """Configuration about tests""" +from __future__ import absolute_import as _abs + import os import tvm diff --git a/nnvm/python/nnvm/testing/mobilenet.py b/nnvm/python/nnvm/testing/mobilenet.py new file mode 100644 index 000000000000..4a08380312d5 --- /dev/null +++ b/nnvm/python/nnvm/testing/mobilenet.py @@ -0,0 +1,125 @@ +"""Helper utility to get mobilenet workload for testing.""" +# pylint: disable=invalid-name +from __future__ import absolute_import as _abs + +import numpy as np +import tvm +from .. compiler import graph_util +from .. import graph +from .. import symbol as sym + +def conv_block(data, name, channels, + kernel_size=(3, 3), strides=(1, 1), padding=(1, 1), + epsilon=1e-5): + """Helper function to construct conv-bn-relu""" + # convolution + bn + relu + conv = sym.conv2d(data=data, channels=channels, + kernel_size=kernel_size, strides=strides, + padding=padding, use_bias=False, + layout="NCHW", name=name + "_conv") + bn = sym.batch_norm(data=conv, epsilon=epsilon, name=name + "_bn") + act = sym.relu(data=bn, name=name + "_relu") + return act + +def separable_conv_block(data, name, depthwise_channels, + pointwise_channels, kernel_size=(3, 3), + downsample=False, padding=(1, 1), + epsilon=1e-5): + """Helper function to get a separable conv block""" + if downsample: + strides = (2, 2) + else: + strides = (1, 1) + # depthwise convolution + bn + relu + conv1 = sym.conv2d(data=data, channels=depthwise_channels, + groups=depthwise_channels, kernel_size=kernel_size, strides=strides, + padding=padding, use_bias=False, layout="NCHW", name=name + "_conv1") + bn1 = sym.batch_norm(data=conv1, epsilon=epsilon, name=name + "_bn1") + act1 = sym.relu(data=bn1, name=name + "_relu1") + # pointwise convolution + bn + relu + conv2 = sym.conv2d(data=act1, channels=pointwise_channels, kernel_size=(1, 1), strides=(1, 1), + padding=(0, 0), use_bias=False, layout="NCHW", name=name + "_conv2") + bn2 = sym.batch_norm(data=conv2, epsilon=epsilon, name=name + "_bn2") + act2 = sym.relu(data=bn2, name=name + "_relu2") + return act2 + +def mobile_net(num_classes=1000, alpha=1.0, is_shallow=False): + """Function to construct a MobileNet""" + data = sym.Variable("data") + body = conv_block(data, "conv_block_1", int(32*alpha), strides=(2, 2)) + body = separable_conv_block(body, "separable_conv_block_1", + int(32*alpha), int(64*alpha)) + body = separable_conv_block(body, "separable_conv_block_2", + int(64*alpha), int(128*alpha), downsample=True) + body = separable_conv_block(body, "separable_conv_block_3", + int(128*alpha), int(128*alpha)) + body = separable_conv_block(body, "separable_conv_block_4", + int(128*alpha), int(256*alpha), downsample=True) + body = separable_conv_block(body, "separable_conv_block_5", + int(256*alpha), int(256*alpha)) + body = separable_conv_block(body, "separable_conv_block_6", + int(256*alpha), int(512*alpha), downsample=True) + if is_shallow: + body = separable_conv_block(body, "separable_conv_block_7", + int(512*alpha), int(1024*alpha), downsample=True) + body = separable_conv_block(body, "separable_conv_block_8", + int(1024*alpha), int(1024*alpha)) + else: + for i in range(7, 12): + body = separable_conv_block(body, "separable_conv_block_%d" % i, + int(512*alpha), int(512*alpha)) + body = separable_conv_block(body, "separable_conv_block_12", + int(512*alpha), int(1024*alpha), downsample=True) + body = separable_conv_block(body, "separable_conv_block_13", + int(1024*alpha), int(1024*alpha)) + pool = sym.global_avg_pool2d(data=body, name="pool") + flatten = sym.flatten(data=pool, name="flatten") + fc = sym.dense(data=flatten, units=num_classes, use_bias=False, name="fc") + softmax = sym.softmax(data=fc, name="softmax") + return softmax + + +def get_workload(batch_size, num_classes=1000, image_shape=(3, 224, 224), dtype="float32"): + """Get benchmark workload for mobilenet + + Parameters + ---------- + batch_size : int + The batch size used in the model + + num_classes : int, optional + Number of claseses + + image_shape : tuple, optional + The input image shape + + dtype : str, optional + The data type + + Returns + ------- + net : nnvm.Symbol + The computational graph + + params : dict of str to NDArray + The parameters. + """ + image_shape = (3, 224, 224) + data_shape = (batch_size,) + image_shape + net = mobile_net(num_classes=num_classes, alpha=1.0, is_shallow=False) + params = {} + g = graph.create(net) + input_shapes, _ = graph_util.infer_shape(g, data=data_shape) + shape_dict = dict(zip(g.index.input_names, input_shapes)) + for k, v in shape_dict.items(): + if k == "data": + continue + # Specially generate non-negative parameters. + if k.endswith("gamma"): + init = np.random.uniform(0.9, 1, size=v) + elif k.endswith("var"): + init = np.random.uniform(0.9, 1, size=v) + else: + init = np.random.uniform(-0.1, 0.1, size=v) + params[k] = tvm.nd.array(init.astype(dtype), ctx=tvm.cpu(0)) + return net, params diff --git a/nnvm/src/compiler/precompute_prune.cc b/nnvm/src/compiler/precompute_prune.cc index a0159757c398..a56a0398679c 100644 --- a/nnvm/src/compiler/precompute_prune.cc +++ b/nnvm/src/compiler/precompute_prune.cc @@ -44,17 +44,17 @@ nnvm::Graph PrecomputePrune(nnvm::Graph src) { } else { // scan again to find edge nodes, skip variables for (auto& e : n->inputs) { - if (!e.node->is_variable() && pruned.count(e.node.get())) { + if (pruned.count(e.node.get())) { if (!entry_var.count(e)) { nnvm::NodePtr var = nnvm::Node::Create(); - var->attrs.name = e.node->attrs.name + "_output" + std::to_string(e.index); + var->attrs.name = e.node->attrs.name; + if (e.node->num_outputs() != 1) { + var->attrs.name += "_output" + std::to_string(e.index); + } entry_var.emplace(e, var); CHECK(!unique_name.count(var->attrs.name)); unique_name.insert(var->attrs.name); } - // TODO(ziheng): this pass now mutates the original graph structure - // This might not be a good thing, change to copy the structure instead - // e = nnvm::NodeEntry{entry_var.at(e), 0, 0}; } } @@ -67,7 +67,6 @@ nnvm::Graph PrecomputePrune(nnvm::Graph src) { output_names.reserve(entry_var.size()); for (auto kv : entry_var) { - if (kv.first.node->is_variable()) continue; pre_graph.outputs.emplace_back(kv.first); output_names.emplace_back(kv.second->attrs.name); } diff --git a/nnvm/tests/python/compiler/test_build.py b/nnvm/tests/python/compiler/test_build.py index 379975d2d6a4..59220a7ca63e 100644 --- a/nnvm/tests/python/compiler/test_build.py +++ b/nnvm/tests/python/compiler/test_build.py @@ -55,26 +55,28 @@ def test_run(): def test_precompute_prune(): x = sym.Variable("x") + 1 + a = sym.Variable("a") y = sym.Variable("y") - z = y + x + z = y + x + a shape = (10, 10) dtype = tvm.float32 nx = tvm.nd.array(np.random.uniform(size=shape).astype(dtype)) + na = tvm.nd.array(np.random.uniform(size=shape).astype(dtype)) ny = tvm.nd.array(np.random.uniform(size=shape).astype(dtype)) - params = {"x": nx} + params = {"x": nx, "a": na} graph, lib, params = nnvm.compiler.build( z, "llvm", shape={"y": ny.shape}, params=params) - assert graph.index.num_nodes == 3 + assert graph.index.num_nodes == 4 m = nnvm.runtime.create(graph, lib, tvm.cpu(0)) params["y"] = ny res = tvm.nd.empty(shape) m.run(**params) out = m.get_output(0, out=res) np.testing.assert_allclose( - res.asnumpy(), nx.asnumpy() + 1 + ny.asnumpy()) + res.asnumpy(), nx.asnumpy() + 1 + ny.asnumpy() + na.asnumpy()) if __name__ == "__main__": + test_precompute_prune() test_compile() test_run() - test_precompute_prune() diff --git a/nnvm/tutorials/README.txt b/nnvm/tutorials/README.txt new file mode 100644 index 000000000000..72f772fa6feb --- /dev/null +++ b/nnvm/tutorials/README.txt @@ -0,0 +1,3 @@ +Tutorials +========= +This page contains the tutorials about NNVM. diff --git a/nnvm/tutorials/mobilenet_inference_gpu.py b/nnvm/tutorials/mobilenet_inference_gpu.py new file mode 100644 index 000000000000..9343316b3896 --- /dev/null +++ b/nnvm/tutorials/mobilenet_inference_gpu.py @@ -0,0 +1,82 @@ +""" +Compile MobileNet Inference on GPU +================================== +**Author**: `Yuwei Hu `_ + +This is an example of using NNVM to compile MobileNet model and deploy its inference on GPU. + +To begin with, we import nnvm(for compilation) and TVM(for deployment). +""" +import tvm +import nnvm.compiler +import nnvm.runtime +import nnvm.testing +from tvm.contrib import nvcc + +###################################################################### +# Register the NVCC Compiler Option +# --------------------------------- +# NNVM optimizes the graph and relies on TVM to generate fast +# GPU code, to get the maximum performance, we need to enable +# nvcc's compiler hook. This gives better performance than nvrtc mode. + +@tvm.register_func +def tvm_callback_cuda_compile(code): + ptx = nvcc.compile_cuda(code, target="ptx", options=["-arch=sm_52"]) + return ptx + +###################################################################### +# Prepare the Benchmark +# --------------------- +# We construct a standard imagenet inference benchmark. +# We use nnvm's testing utility to produce the model description and random parameters that so the example does not +# depend on a specific front-end framework. +# +# .. note:: +# +# In a typical workflow, we can get this pair from :any:`nnvm.frontend` +# +target = "cuda" +ctx = tvm.gpu(0) +batch_size = 1 +num_classes = 1000 +image_shape = (3, 224, 224) +data_shape = (batch_size,) + image_shape +out_shape = (batch_size, num_classes) +net, params = nnvm.testing.mobilenet.get_workload( + batch_size=1, image_shape=image_shape) + +###################################################################### +# Compile The Graph +# ----------------- +# NNVM needs two things to compile a deep learning model: +# +# - net which is the graph representation of the computation +# - params a dictionary of str to parameters. +# +# To compile the graph, we call the build function with the graph +# configuration and parameters. +# When parameters are provided, NNVM will pre-compute certain part of the graph if possible, +# the new parameter set returned as the third return value. + +graph, lib, params = nnvm.compiler.build( + net, target, shape={"data": data_shape}, params=params) + +###################################################################### +# Run the Compiled Module +# ----------------------- +# +# To deploy the module, we call :any:`nnvm.runtime.create` passing in the graph the lib and context. +# Thanks to TVM, we can deploy the compiled module to many platforms and languages. +# The deployment module is designed to contain minimum dependencies. +# This example runs on the same machine. + +module = nnvm.runtime.create(graph, lib, ctx) +# set input +module.set_input(**params) +# run +module.run() +# get output +out = module.get_output(0, tvm.nd.empty(out_shape)) +# Convert to numpy +out.asnumpy()