Skip to content

Commit

Permalink
[TUTORIAL] Move mobilenet to tutorial, fix precompute_prune (apache#35)
Browse files Browse the repository at this point in the history
* [TUTORIAL] Move mobilenet to tutorial, fix precompute_prune

* Some language improvements
  • Loading branch information
tqchen committed May 29, 2018
1 parent 12fa914 commit 34d7428
Show file tree
Hide file tree
Showing 15 changed files with 271 additions and 135 deletions.
2 changes: 2 additions & 0 deletions nnvm/docs/.gitignore
Original file line number Diff line number Diff line change
@@ -1,2 +1,4 @@
doxygen
_build
gen_modules
tutorials
4 changes: 4 additions & 0 deletions nnvm/docs/README.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
The documentation of nnvm is generated with recommonmark and sphinx.

- pip install sphinx>=1.5.5 sphinx-gallery sphinx_rtd_theme matplotlib Image recommonmark
- Build tvm first in the root folder.
30 changes: 28 additions & 2 deletions nnvm/docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import os, subprocess
import shlex
import recommonmark
import sphinx_gallery
from recommonmark.parser import CommonMarkParser
from recommonmark.transform import AutoStructify

Expand Down Expand Up @@ -50,7 +51,8 @@
'sphinx.ext.autosummary',
'sphinx.ext.intersphinx',
'sphinx.ext.napoleon',
'sphinx.ext.mathjax'
'sphinx.ext.mathjax',
'sphinx_gallery.gen_gallery',
]

# Add any paths that contain templates here, relative to this directory.
Expand Down Expand Up @@ -129,7 +131,7 @@
# Add any paths that contain custom static files (such as style sheets) here,
# relative to this directory. They are copied after the builtin static files,
# so a file named "default.css" will overwrite the builtin "default.css".
html_static_path = ['_static']
# html_static_path = ['_static']

# Output file base name for HTML help builder.
htmlhelp_basename = project + 'doc'
Expand Down Expand Up @@ -164,9 +166,17 @@ def run_doxygen(folder):
'numpy': ('http://docs.scipy.org/doc/numpy/', None),
'scipy': ('http://docs.scipy.org/doc/scipy/reference', None),
'matplotlib': ('http://matplotlib.org/', None),
'tvm': ('http://docs.tvmlang.org/', None),
}


from sphinx_gallery.sorting import ExplicitOrder

examples_dirs = ['../tutorials/']
gallery_dirs = ['tutorials']

subsection_order = ExplicitOrder([])

def generate_doxygen_xml(app):
"""Run the doxygen make commands if we're on the ReadTheDocs server"""
run_doxygen('..')
Expand All @@ -180,3 +190,19 @@ def setup(app):
'auto_doc_ref': True
}, True)
app.add_transform(AutoStructify)


sphinx_gallery_conf = {
'backreferences_dir': 'gen_modules/backreferences',
'doc_module': ('tvm', 'nnvm', 'numpy'),
'reference_url': {
'nnvm': None,
'tvm': 'http://docs.tvmlang.org',
'numpy': 'http://docs.scipy.org/doc/numpy-1.9.1'},
'examples_dirs': examples_dirs,
'gallery_dirs': gallery_dirs,
'subsection_order': subsection_order,
'find_mayavi_figures': False,
'filename_pattern': '.py',
'expected_failing_examples': []
}
4 changes: 2 additions & 2 deletions nnvm/docs/dev/index.rst
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
NNVM Design Note
================
Design Note
===========

In this part of documentation, we share the rationale for the specific choices made when designing NNVM.

Expand Down
1 change: 1 addition & 0 deletions nnvm/docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,5 @@ Contents

self
top
tutorials/index
dev/index
4 changes: 2 additions & 2 deletions nnvm/docs/top.rst
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
NNVM Core Tensor Operators
==========================
Core Tensor Operators
=====================

This page contains the list of core tensor operator primitives re-defined in NNVM.
The core tensor operator primitives(``nnvm.top``) covers typical workloads in deep learning.
Expand Down
117 changes: 0 additions & 117 deletions nnvm/example/mobilenet_inference_gpu.py

This file was deleted.

5 changes: 5 additions & 0 deletions nnvm/examples/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
NNVM Examples
=============
This folder contains example snippets of running NNVM Compilation.

- See also [Tutorials](tutorials) for tutorials with detailed explainations.
4 changes: 3 additions & 1 deletion nnvm/python/nnvm/testing/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
"""Utilities for testcase"""
"""Utilities for testing and benchmarks"""
from __future__ import absolute_import as _abs

from .config import ctx_list
from . import mobilenet
2 changes: 2 additions & 0 deletions nnvm/python/nnvm/testing/config.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
"""Configuration about tests"""
from __future__ import absolute_import as _abs

import os
import tvm

Expand Down
125 changes: 125 additions & 0 deletions nnvm/python/nnvm/testing/mobilenet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
"""Helper utility to get mobilenet workload for testing."""
# pylint: disable=invalid-name
from __future__ import absolute_import as _abs

import numpy as np
import tvm
from .. compiler import graph_util
from .. import graph
from .. import symbol as sym

def conv_block(data, name, channels,
kernel_size=(3, 3), strides=(1, 1), padding=(1, 1),
epsilon=1e-5):
"""Helper function to construct conv-bn-relu"""
# convolution + bn + relu
conv = sym.conv2d(data=data, channels=channels,
kernel_size=kernel_size, strides=strides,
padding=padding, use_bias=False,
layout="NCHW", name=name + "_conv")
bn = sym.batch_norm(data=conv, epsilon=epsilon, name=name + "_bn")
act = sym.relu(data=bn, name=name + "_relu")
return act

def separable_conv_block(data, name, depthwise_channels,
pointwise_channels, kernel_size=(3, 3),
downsample=False, padding=(1, 1),
epsilon=1e-5):
"""Helper function to get a separable conv block"""
if downsample:
strides = (2, 2)
else:
strides = (1, 1)
# depthwise convolution + bn + relu
conv1 = sym.conv2d(data=data, channels=depthwise_channels,
groups=depthwise_channels, kernel_size=kernel_size, strides=strides,
padding=padding, use_bias=False, layout="NCHW", name=name + "_conv1")
bn1 = sym.batch_norm(data=conv1, epsilon=epsilon, name=name + "_bn1")
act1 = sym.relu(data=bn1, name=name + "_relu1")
# pointwise convolution + bn + relu
conv2 = sym.conv2d(data=act1, channels=pointwise_channels, kernel_size=(1, 1), strides=(1, 1),
padding=(0, 0), use_bias=False, layout="NCHW", name=name + "_conv2")
bn2 = sym.batch_norm(data=conv2, epsilon=epsilon, name=name + "_bn2")
act2 = sym.relu(data=bn2, name=name + "_relu2")
return act2

def mobile_net(num_classes=1000, alpha=1.0, is_shallow=False):
"""Function to construct a MobileNet"""
data = sym.Variable("data")
body = conv_block(data, "conv_block_1", int(32*alpha), strides=(2, 2))
body = separable_conv_block(body, "separable_conv_block_1",
int(32*alpha), int(64*alpha))
body = separable_conv_block(body, "separable_conv_block_2",
int(64*alpha), int(128*alpha), downsample=True)
body = separable_conv_block(body, "separable_conv_block_3",
int(128*alpha), int(128*alpha))
body = separable_conv_block(body, "separable_conv_block_4",
int(128*alpha), int(256*alpha), downsample=True)
body = separable_conv_block(body, "separable_conv_block_5",
int(256*alpha), int(256*alpha))
body = separable_conv_block(body, "separable_conv_block_6",
int(256*alpha), int(512*alpha), downsample=True)
if is_shallow:
body = separable_conv_block(body, "separable_conv_block_7",
int(512*alpha), int(1024*alpha), downsample=True)
body = separable_conv_block(body, "separable_conv_block_8",
int(1024*alpha), int(1024*alpha))
else:
for i in range(7, 12):
body = separable_conv_block(body, "separable_conv_block_%d" % i,
int(512*alpha), int(512*alpha))
body = separable_conv_block(body, "separable_conv_block_12",
int(512*alpha), int(1024*alpha), downsample=True)
body = separable_conv_block(body, "separable_conv_block_13",
int(1024*alpha), int(1024*alpha))
pool = sym.global_avg_pool2d(data=body, name="pool")
flatten = sym.flatten(data=pool, name="flatten")
fc = sym.dense(data=flatten, units=num_classes, use_bias=False, name="fc")
softmax = sym.softmax(data=fc, name="softmax")
return softmax


def get_workload(batch_size, num_classes=1000, image_shape=(3, 224, 224), dtype="float32"):
"""Get benchmark workload for mobilenet
Parameters
----------
batch_size : int
The batch size used in the model
num_classes : int, optional
Number of claseses
image_shape : tuple, optional
The input image shape
dtype : str, optional
The data type
Returns
-------
net : nnvm.Symbol
The computational graph
params : dict of str to NDArray
The parameters.
"""
image_shape = (3, 224, 224)
data_shape = (batch_size,) + image_shape
net = mobile_net(num_classes=num_classes, alpha=1.0, is_shallow=False)
params = {}
g = graph.create(net)
input_shapes, _ = graph_util.infer_shape(g, data=data_shape)
shape_dict = dict(zip(g.index.input_names, input_shapes))
for k, v in shape_dict.items():
if k == "data":
continue
# Specially generate non-negative parameters.
if k.endswith("gamma"):
init = np.random.uniform(0.9, 1, size=v)
elif k.endswith("var"):
init = np.random.uniform(0.9, 1, size=v)
else:
init = np.random.uniform(-0.1, 0.1, size=v)
params[k] = tvm.nd.array(init.astype(dtype), ctx=tvm.cpu(0))
return net, params
11 changes: 5 additions & 6 deletions nnvm/src/compiler/precompute_prune.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,17 +44,17 @@ nnvm::Graph PrecomputePrune(nnvm::Graph src) {
} else {
// scan again to find edge nodes, skip variables
for (auto& e : n->inputs) {
if (!e.node->is_variable() && pruned.count(e.node.get())) {
if (pruned.count(e.node.get())) {
if (!entry_var.count(e)) {
nnvm::NodePtr var = nnvm::Node::Create();
var->attrs.name = e.node->attrs.name + "_output" + std::to_string(e.index);
var->attrs.name = e.node->attrs.name;
if (e.node->num_outputs() != 1) {
var->attrs.name += "_output" + std::to_string(e.index);
}
entry_var.emplace(e, var);
CHECK(!unique_name.count(var->attrs.name));
unique_name.insert(var->attrs.name);
}
// TODO(ziheng): this pass now mutates the original graph structure
// This might not be a good thing, change to copy the structure instead
//
e = nnvm::NodeEntry{entry_var.at(e), 0, 0};
}
}
Expand All @@ -67,7 +67,6 @@ nnvm::Graph PrecomputePrune(nnvm::Graph src) {
output_names.reserve(entry_var.size());

for (auto kv : entry_var) {
if (kv.first.node->is_variable()) continue;
pre_graph.outputs.emplace_back(kv.first);
output_names.emplace_back(kv.second->attrs.name);
}
Expand Down
Loading

0 comments on commit 34d7428

Please sign in to comment.