Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Tensorflow support for DIGITS #1714

Merged
merged 5 commits into from
Jul 17, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Temporary files
*.swp
*~
.DS_Store
TAGS

# Compiled / optimized files
Expand All @@ -16,4 +17,16 @@ TAGS
/build/
/dist/
*.egg-info/

#Intellij files
.idea/

#vscode
.vscode/

#.project
.project
/.project

#.tb
.tb/
2 changes: 1 addition & 1 deletion .gjslintrc
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
--max_line_length=120
--exclude_directories=3rdparty
--exclude_directories=3rdparty,tb
--disable=0121,0220
14 changes: 8 additions & 6 deletions .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,17 @@ python: 2.7

env:
global:
- OPENBLAS_ROOT=~/openblas
- PROTOBUF_ROOT=~/protobuf
- CAFFE_ROOT=~/caffe
- TORCH_ROOT=~/torch
- OMP_NUM_THREADS=1
- OPENBLAS_MAIN_FREE=1
- secure: "WSqrE+PQm76DdoRLRGKTK6fRWfXZjIb0BWCZm3IgHgFO7OE6fcK2tBnpDNNw4XQjmo27FFWlEhxN32g18P84n5PvErHaH65IuS9Nv6FkLlPXZlVqGNxbPmEA4oTkD/6Y6kZyZWZtLh2+/1ijuzQAPnIy/4BEuL8pdO+PsoJ9hYM="
matrix:
- DIGITS_TEST_FRAMEWORK=caffe CAFFE_FORK=NVIDIA
- DIGITS_TEST_FRAMEWORK=caffe CAFFE_FORK=BVLC
- DIGITS_TEST_FRAMEWORK=torch
- DIGITS_TEST_FRAMEWORK=tensorflow
- DIGITS_TEST_FRAMEWORK=none
- DIGITS_TEST_FRAMEWORK=none WITH_PLUGINS=false

Expand Down Expand Up @@ -73,7 +75,6 @@ matrix:
cache:
apt: true
directories:
- $OPENBLAS_ROOT
- $PROTOBUF_ROOT
- $CAFFE_ROOT
- $TORCH_ROOT
Expand All @@ -85,6 +86,7 @@ addons:
- cmake
- cython
- git
- gfortran
- graphviz
- libboost-filesystem-dev
- libboost-python-dev
Expand All @@ -95,6 +97,7 @@ addons:
- libhdf5-serial-dev
- libleveldb-dev
- liblmdb-dev
- libopenblas-dev
- libopencv-dev
- libsnappy-dev
- python-dev
Expand Down Expand Up @@ -125,15 +128,14 @@ before_install:
install:
- mkdir -p ~/.config/matplotlib
- echo "backend:agg" > ~/.config/matplotlib/matplotlibrc
- ./scripts/travis/install-openblas.sh $OPENBLAS_ROOT
- ./scripts/travis/install-protobuf.sh $PROTOBUF_ROOT
- ./scripts/travis/install-caffe.sh $CAFFE_ROOT
- if [ "$DIGITS_TEST_FRAMEWORK" == "torch" ]; then ./scripts/travis/install-torch.sh $TORCH_ROOT; else unset TORCH_ROOT; fi
- pip install -r ./requirements.txt
- if [ "$DIGITS_TEST_FRAMEWORK" == "torch" ]; then travis_wait ./scripts/travis/install-torch.sh $TORCH_ROOT; else unset TORCH_ROOT; fi
- pip install -r ./requirements.txt --force-reinstall
- if [ "$DIGITS_TEST_FRAMEWORK" == "tensorflow" ]; then travis_wait ./scripts/travis/install-tensorflow.sh; fi
- pip install -r ./requirements_test.txt
- pip install -e .
- if [ "$WITH_PLUGINS" != "false" ]; then find ./plugins/*/* -maxdepth 0 -type d | xargs -n1 pip install -e; fi

script:
- ./digits-test -v

1 change: 1 addition & 0 deletions MANIFEST.in
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ recursive-include digits/templates *
recursive-include digits/static *
recursive-include digits/standard-networks *
recursive-include digits/tools/torch *
recursive-include digits/tools/tensorflow *
recursive-include digits/extensions *.css
recursive-include digits/extensions *.html
recursive-include digits/extensions *.js
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
[![Build Status](https://travis-ci.org/NVIDIA/DIGITS.svg?branch=master)](https://travis-ci.org/NVIDIA/DIGITS)

DIGITS (the **D**eep Learning **G**PU **T**raining **S**ystem) is a webapp for training deep learning models.
The currently supported frameworks are: Caffe, Torch, and Tensorflow.

# Installation

Expand All @@ -18,6 +19,7 @@ Once you have installed DIGITS, visit [docs/GettingStarted.md](docs/GettingStart

Then, take a look at some of the other documentation at [docs/](docs/) and [examples/](examples/):

* [Getting started with TensorFlow](docs/GettingStartedTensorflow.md)
* [Getting started with Torch](docs/GettingStartedTorch.md)
* [Fine-tune a pretrained model](examples/fine-tuning/README.md)
* [Train an autoencoder network](examples/autoencoder/README.md)
Expand Down
4 changes: 2 additions & 2 deletions digits-lint
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@ set -e

echo "=== Checking for Python lint ..."
if which flake8 >/dev/null 2>&1; then
python2 `which flake8` .
python2 `which flake8` --exclude ./digits/jobs .
else
python2 -m flake8 .
python2 -m flake8 --exclude ./digits/jobs .
fi

echo "=== Checking for JavaScript lint ..."
Expand Down
1 change: 1 addition & 0 deletions digits/config/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
torch,
server_name,
store_option,
tensorflow,
)


Expand Down
28 changes: 28 additions & 0 deletions digits/config/tensorflow.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# Copyright (c) 2016, NVIDIA CORPORATION. All rights reserved.
from __future__ import absolute_import
from . import option_list


def test_tf_import():
"""
Tests if tensorflow can be imported, returns if it went okay and optional error.
"""
try:
import tensorflow # noqa
return True
except ImportError:
return False

tf_enabled = test_tf_import()

if not tf_enabled:
print('Tensorflow support disabled.')

if tf_enabled:
option_list['tensorflow'] = {
'enabled': True
}
else:
option_list['tensorflow'] = {
'enabled': False
}
4 changes: 3 additions & 1 deletion digits/dataset/images/classification/forms.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,16 @@ class ImageClassificationDatasetForm(ImageDatasetForm):
backend = wtforms.SelectField('DB backend',
choices=[
('lmdb', 'LMDB'),
('hdf5', 'HDF5'),
('hdf5', 'HDF5')
],
default='lmdb',
)

def validate_backend(form, field):
if field.data == 'lmdb':
form.compression.data = 'none'
elif field.data == 'tfrecords':
form.compression.data = 'none'
elif field.data == 'hdf5':
form.encoding.data = 'none'

Expand Down
3 changes: 2 additions & 1 deletion digits/dataset/tasks/create_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,9 @@ def __init__(self, input_file, db_name, backend, image_dims, **kwargs):
self.input_file = input_file
self.db_name = db_name
self.backend = backend
if backend == 'hdf5':
if backend == 'hdf5' or backend == 'tfrecords':
# the list of hdf5 files is stored in a textfile
# tfrecords can be sharded as well
self.textfile = os.path.join(self.db_name, 'list.txt')
self.image_dims = image_dims
if image_dims[2] == 3:
Expand Down
29 changes: 29 additions & 0 deletions digits/dataset/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from . import images as dataset_images
from . import generic
from digits import extensions
from digits.utils.routing import job_from_request, request_wants_json
from digits.webapp import scheduler

Expand Down Expand Up @@ -54,3 +55,31 @@ def summary():
return generic.views.summary(job)
else:
raise werkzeug.exceptions.BadRequest('Invalid job type')


@blueprint.route('/inference-form/<extension_id>/<job_id>', methods=['GET'])
def inference_form(extension_id, job_id):
"""
Returns a rendering of an inference form
"""
inference_form_html = ""

if extension_id != "all-default":
extension_class = extensions.data.get_extension(extension_id)
if not extension_class:
raise RuntimeError("Unable to find data extension with ID=%s"
% job_id.dataset.extension_id)
job = scheduler.get_job(job_id)
if hasattr(job, 'extension_userdata'):
extension_userdata = job.extension_userdata
else:
extension_userdata = {}
extension_userdata.update({'is_inference_db': True})
extension = extension_class(**extension_userdata)

form = extension.get_inference_form()
if form:
template, context = extension.get_inference_template(form)
inference_form_html = flask.render_template_string(template, **context)

return inference_form_html
12 changes: 12 additions & 0 deletions digits/extensions/view/imageOutput/config_template.html
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,20 @@
{{ form.channel_order(class='form-control') }}
</div>

<div class="form-group{{mark_errors([form.data_order])}}">
{{ form.data_order.label }}
{{ form.data_order.tooltip }}
{{ form.data_order(class='form-control') }}
</div>

<div class="form-group{{mark_errors([form.pixel_conversion])}}">
{{ form.pixel_conversion.label }}
{{ form.pixel_conversion.tooltip }}
{{ form.pixel_conversion(class='form-control') }}
</div>

<div class="form-group{{mark_errors([form.show_input])}}">
{{ form.show_input.label }}
{{ form.show_input.tooltip }}
{{ form.show_input(class='form-control') }}
</div>
22 changes: 22 additions & 0 deletions digits/extensions/view/imageOutput/forms.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,18 @@ class ConfigForm(Form):
'is ignored in the case of a grayscale image)'
)

data_order = utils.forms.SelectField(
'Data order',
choices=[
('chw', 'CHW'),
('hwc', 'HWC'),
],
default='chw',
tooltip="Set the order of the data. For Caffe and Torch models this "
"is often NCHW, for Tensorflow it's NHWC."
"N=Batch Size, W=Width, H=Height, C=Channels"
)

pixel_conversion = utils.forms.SelectField(
'Pixel conversion',
choices=[
Expand All @@ -33,3 +45,13 @@ class ConfigForm(Form):
tooltip='Select method to convert pixel values to the target bit '
'range'
)

show_input = utils.forms.SelectField(
'Show input as image',
choices=[
('yes', 'Yes'),
('no', 'No'),
],
default='no',
tooltip='Show input as image'
)
24 changes: 21 additions & 3 deletions digits/extensions/view/imageOutput/view.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,9 @@ def __init__(self, dataset, **kwargs):

# view options
self.channel_order = kwargs['channel_order'].upper()
self.data_order = kwargs['data_order'].upper()
self.normalize = (kwargs['pixel_conversion'] == 'normalize')
self.show_input = (kwargs['show_input'] == 'yes')

@staticmethod
def get_config_form():
Expand Down Expand Up @@ -69,15 +71,31 @@ def get_view_template(self, data):
- context is a dictionary of context variables to use for rendering
the form
"""
return self.view_template, {'image': digits.utils.image.embed_image_html(data)}
return self.view_template, {'image_input': digits.utils.image.embed_image_html(data[0]),
'image_output': digits.utils.image.embed_image_html(data[1])}

@override
def process_data(self, input_id, input_data, output_data):
"""
Process one inference and return data to visualize
"""
# assume the only output is a CHW image
data = output_data[output_data.keys()[0]].astype('float32')

if self.show_input:
data_input = input_data.astype('float32')
image_input = self.process_image(self.data_order, data_input)
else:
image_input = None

data_output = output_data[output_data.keys()[0]].astype('float32')
image_output = self.process_image(self.data_order, data_output)

return [image_input, image_output]

def process_image(self, data_order, data):
if data_order == 'HWC':
data = (data.transpose((2, 0, 1)))

# assume CHW at this point
channels = data.shape[0]
if channels == 3 and self.channel_order == 'BGR':
data = data[[2, 1, 0], ...] # BGR to RGB
Expand Down
5 changes: 4 additions & 1 deletion digits/extensions/view/imageOutput/view_template.html
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
{# Copyright (c) 2016-2017, NVIDIA CORPORATION. All rights reserved. #}

<img src="{{image}}" style="max-width:100%;" />
{% if image_input %}
<img src="{{image_input}}" style="max-width:100%;" />
{% endif %}
<img src="{{image_output}}" style="max-width:100%;" />
7 changes: 7 additions & 0 deletions digits/extensions/view/rawData/header_template.html
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
{# Copyright (c) 2016, NVIDIA CORPORATION. All rights reserved. #}

{% from "helper.html" import print_flashes %}
{% from "helper.html" import print_errors %}
{% from "helper.html" import mark_errors %}

{{data}}
9 changes: 9 additions & 0 deletions digits/frameworks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,20 @@
'TorchFramework',
]

if config_value('tensorflow')['enabled']:
from .tensorflow_framework import TensorflowFramework
__all__.append('TensorflowFramework')

#
# create framework instances
#

# torch is optional
torch = TorchFramework() if config_value('torch')['enabled'] else None

# tensorflow is optional
tensorflow = TensorflowFramework() if config_value('tensorflow')['enabled'] else None

# caffe is mandatory
caffe = CaffeFramework()

Expand All @@ -35,6 +42,8 @@ def get_frameworks():
frameworks = [caffe]
if torch:
frameworks.append(torch)
if tensorflow:
frameworks.append(tensorflow)
return frameworks


Expand Down
Loading