GAN support and implementation for TensorFlow by Greg (#6)
* bAbI data plug-in

Add utils

Add inference form to bAbI dataset

Allow inference without answer

Allow unknown words in bAbI data plug-in

Fix bAbI plugin Lint errors

* Tensorflow integration updates

Use TFRecords for TF inference

TF: Don't rescale inputs

Fix some TF classification tests

Remove unnecessary print

Fix TF imports when uninstalled

Fix mean image scale

Fix generic model tests

Fix Torch single image inference

Fix inference

TMP TF Lint

Revert changes in digits-lint script

Lint: ignore tensorflow standard examples

More Lint fixes

* Add gradient hook

* Add memn2n model

* Update memn2n with gradient hooks

* GAN example

* Make batch size variable

* Training/inference paths

* Small update to TF 0.12

* Snapshot names, float inference, restore all vars

* Do not restore global_step or optimizer variables

* Add TB link

* Update GAN network

* Dynamically select inference form

* TF inference: convert images to float

* Update GAN z-gen network

* Small update to model view layout

* Add GAN plug-ins

* Update GAN plug-in to create CelebA dataset

* Add ability to show input in ImageOutput extension

* Add all data to raw data view extension

* Add model for CelebA dataset

* Update GAN data plug-in

* Update all losses in one session

* Remove conversion to .png in GAN data plug-in

* TF Slim Lenet example

Divide input by 255

* Update GAN data plug-in

* Fix TF model snapshot

* Reduce scheduler delays to speed up inference

* Update GAN plugins

* Fix TF tests

* Add API to LmdbReader (used by gan_features.py)

* Save animated gif

* Add GAN walk-through

* Update GAN walkthrough with embeddings video

* Fix GAN view for list encoding

* Add animation task to GAN plugins

* Add view task to see image attributes

* Add comments to GAN models

* Update README

* Fix GAN features script

* GAN app

* Fix DIGITS inference

* Adjust GAN window size automatically

* Add attributes to GAN app

* Move gandisplay.py

* Remove wxpython 3.0 selection

* Fix call to model

* Adding disclaimer
gheinrich authored and ethantang95 committed Jul 10, 2017
1 parent ea25e1a commit b1f2044
Showing 108 changed files with 7,012 additions and 1,256 deletions.
1 change: 0 additions & 1 deletion digits/config/__init__.py
@@ -21,4 +21,3 @@ def config_value(option):
Return the current configuration value for the given option
"""
return option_list[option]

9 changes: 5 additions & 4 deletions digits/config/tensorflow.py
@@ -3,26 +3,27 @@

import os
import platform
from subprocess import Popen,PIPE
from subprocess import Popen, PIPE

from . import option_list

VARNAME_ENV_TFPY = 'TENSORFLOW_PYTHON'
DEFAULT_PYTHON_EXE = 'python2' # @TODO(tzaman) - use the python executable that was used to launch digits?
DEFAULT_PYTHON_EXE = 'python2' # @TODO(tzaman) - use the python executable that was used to launch digits?

if platform.system() == 'Darwin':
# DYLD_LIBRARY_PATH and LD_LIBRARY_PATH is sometimes stripped, and the cuda libraries might need it
if not "DYLD_LIBRARY_PATH" in os.environ:
if "DYLD_LIBRARY_PATH" not in os.environ:
if "CUDA_HOME" in os.environ:
os.environ["DYLD_LIBRARY_PATH"] = str(os.environ["CUDA_HOME"] + '/lib')


def test_tf_import(python_exe):
"""
Tests if tensorflow can be imported, returns if it went okay and optional error.
"""
p = Popen([python_exe, "-c", "import tensorflow"], stdout=PIPE, stderr=PIPE)
(out, err) = p.communicate()
return p.returncode==0, str(err)
return p.returncode == 0, str(err)

if VARNAME_ENV_TFPY in os.environ:
tf_python_exe = os.environ[VARNAME_ENV_TFPY]
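For context, a minimal, self-contained sketch of how the import probe above can be exercised: try the configured interpreter (optionally overridden through the TENSORFLOW_PYTHON environment variable, as in the module) and report why the import failed. The probe mirrors the diff; the standalone entry point is illustrative.

```python
# Illustrative sketch, not part of the commit: mirror the probe in
# digits/config/tensorflow.py and print a diagnostic when TensorFlow
# cannot be imported by the chosen interpreter.
import os
from subprocess import Popen, PIPE


def test_tf_import(python_exe):
    """Return (ok, stderr) after running `import tensorflow` in python_exe."""
    p = Popen([python_exe, "-c", "import tensorflow"], stdout=PIPE, stderr=PIPE)
    out, err = p.communicate()
    return p.returncode == 0, str(err)


if __name__ == '__main__':
    exe = os.environ.get('TENSORFLOW_PYTHON', 'python2')
    ok, err = test_tf_import(exe)
    if not ok:
        print('TensorFlow could not be imported with %s: %s' % (exe, err))
```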
29 changes: 29 additions & 0 deletions digits/dataset/views.py
@@ -6,6 +6,7 @@

from . import images as dataset_images
from . import generic
from digits import extensions
from digits.utils.routing import job_from_request, request_wants_json
from digits.webapp import scheduler

@@ -54,3 +55,31 @@ def summary():
return generic.views.summary(job)
else:
raise werkzeug.exceptions.BadRequest('Invalid job type')


@blueprint.route('/inference-form/<extension_id>/<job_id>', methods=['GET'])
def inference_form(extension_id, job_id):
"""
Returns a rendering of an inference form
"""
inference_form_html = ""

if extension_id != "all-default":
extension_class = extensions.data.get_extension(extension_id)
if not extension_class:
raise RuntimeError("Unable to find data extension with ID=%s"
% job_id.dataset.extension_id)
job = scheduler.get_job(job_id)
if hasattr(job, 'extension_userdata'):
extension_userdata = job.extension_userdata
else:
extension_userdata = {}
extension_userdata.update({'is_inference_db': True})
extension = extension_class(**extension_userdata)

form = extension.get_inference_form()
if form:
template, context = extension.get_inference_template(form)
inference_form_html = flask.render_template_string(template, **context)

return inference_form_html
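As a hedged illustration of the contract this route relies on (not the actual bAbI or GAN plug-ins added by this commit), a data extension only has to expose a WTForms form plus a template string and context; the route instantiates the extension with its stored userdata, marks it with is_inference_db=True, and renders the template through flask.render_template_string:

```python
# Hypothetical extension sketch; class, field, and template names are
# illustrative only.
import wtforms


class ExampleDataExtension(object):
    def __init__(self, **kwargs):
        # receives the job's extension_userdata plus {'is_inference_db': True}
        self.userdata = kwargs

    def get_inference_form(self):
        class InferenceForm(wtforms.Form):
            question = wtforms.StringField('Question')
        return InferenceForm()

    def get_inference_template(self, form):
        template = ('<div class="form-group">'
                    '{{ form.question.label }} {{ form.question() }}</div>')
        return template, {'form': form}
```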
6 changes: 6 additions & 0 deletions digits/extensions/view/imageOutput/config_template.html
@@ -23,3 +23,9 @@
{{ form.pixel_conversion.tooltip }}
{{ form.pixel_conversion(class='form-control') }}
</div>

<div class="form-group{{mark_errors([form.show_input])}}">
{{ form.show_input.label }}
{{ form.show_input.tooltip }}
{{ form.show_input(class='form-control') }}
</div>
10 changes: 10 additions & 0 deletions digits/extensions/view/imageOutput/forms.py
@@ -45,3 +45,13 @@ class ConfigForm(Form):
tooltip='Select method to convert pixel values to the target bit '
'range'
)

show_input = utils.forms.SelectField(
'Show input as image',
choices=[
('yes', 'Yes'),
('no', 'No'),
],
default='no',
tooltip='Show input as image'
)
18 changes: 15 additions & 3 deletions digits/extensions/view/imageOutput/view.py
@@ -31,6 +31,7 @@ def __init__(self, dataset, **kwargs):
self.channel_order = kwargs['channel_order'].upper()
self.data_order = kwargs['data_order'].upper()
self.normalize = (kwargs['pixel_conversion'] == 'normalize')
self.show_input = (kwargs['show_input'] == 'yes')

@staticmethod
def get_config_form():
@@ -70,17 +71,28 @@ def get_view_template(self, data):
- context is a dictionary of context variables to use for rendering
the form
"""
return self.view_template, {'image': digits.utils.image.embed_image_html(data)}
return self.view_template, {'image_input': digits.utils.image.embed_image_html(data[0]),
'image_output': digits.utils.image.embed_image_html(data[1])}

@override
def process_data(self, input_id, input_data, output_data):
"""
Process one inference and return data to visualize
"""

data = output_data[output_data.keys()[0]].astype('float32')
if self.show_input:
data_input = input_data.astype('float32')
image_input = self.process_image(self.data_order, data_input)
else:
image_input = None

data_output = output_data[output_data.keys()[0]].astype('float32')
image_output = self.process_image(self.data_order, data_output)

return [image_input, image_output]

if self.data_order == 'HWC':
def process_image(self, data_order, data):
if data_order == 'HWC':
data = (data.transpose((2, 0, 1)))

# assume CHW at this point
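The rest of the extracted helper is collapsed in this view. As an illustration of the kind of conversion such a helper performs (the HWC-to-CHW transpose from the snippet above plus an assumed normalization into a displayable 0-255 range), here is a self-contained sketch; it is not the actual DIGITS implementation:

```python
# Illustrative sketch only: turn an inference blob into an HWC uint8 image.
# The transpose matches the diff above; the normalization is an assumption
# about what the collapsed portion of process_image does.
import numpy as np


def process_image_sketch(data_order, data, normalize=True):
    data = np.asarray(data, dtype='float32')
    if data_order == 'HWC':
        data = data.transpose((2, 0, 1))
    # assume CHW at this point
    if normalize:
        data = data - data.min()
        if data.max() > 0:
            data = data * (255.0 / data.max())
    return data.transpose((1, 2, 0)).astype('uint8')
```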
5 changes: 4 additions & 1 deletion digits/extensions/view/imageOutput/view_template.html
@@ -1,3 +1,6 @@
{# Copyright (c) 2016, NVIDIA CORPORATION. All rights reserved. #}

<img src="{{image}}" style="max-width:100%;" />
{% if image_input %}
<img src="{{image_input}}" style="max-width:100%;" />
{% endif %}
<img src="{{image_output}}" style="max-width:100%;" />
5 changes: 4 additions & 1 deletion digits/frameworks/__init__.py
@@ -3,7 +3,6 @@

from .caffe_framework import CaffeFramework
from .framework import Framework
from .tensorflow_framework import TensorflowFramework
from .torch_framework import TorchFramework
from digits.config import config_value

@@ -13,6 +12,10 @@
'TorchFramework',
]

if config_value('tensorflow')['enabled']:
from .tensorflow_framework import TensorflowFramework
__all__.append('TensorflowFramework')

#
# create framework instances
#
3 changes: 1 addition & 2 deletions digits/frameworks/caffe_framework.py
@@ -153,7 +153,6 @@ def can_accumulate_gradients(self):
if config_value('caffe')['flavor'] == 'BVLC':
return True
elif config_value('caffe')['flavor'] == 'NVIDIA':
return (parse_version(config_value('caffe')['version'])
> parse_version('0.14.0-alpha'))
return (parse_version(config_value('caffe')['version']) > parse_version('0.14.0-alpha'))
else:
raise ValueError('Unknown flavor. Support NVIDIA and BVLC flavors only.')
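The rewrapped comparison above is purely a line-length fix; assuming the usual pkg_resources.parse_version helper, the version gate behaves as follows:

```python
# Quick illustration of the version gate (assumes pkg_resources.parse_version).
from pkg_resources import parse_version

print(parse_version('0.15.1') > parse_version('0.14.0-alpha'))  # True
print(parse_version('0.13.2') > parse_version('0.14.0-alpha'))  # False
```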
30 changes: 13 additions & 17 deletions digits/frameworks/tensorflow_framework.py
@@ -1,23 +1,20 @@
# Copyright (c) 2016, NVIDIA CORPORATION. All rights reserved.
from __future__ import absolute_import

import numpy as np
import os
import re
import subprocess
import time
import tempfile

import flask

from .errors import Error, NetworkVisualizationError, BadNetworkError
from .errors import NetworkVisualizationError
from .framework import Framework
import digits
from digits import utils
from digits.config import config_value
from digits.model.tasks import TensorflowTrainTask
from digits.utils import subclass, override, constants


@subclass
class TensorflowFramework(Framework):
"""
@@ -35,7 +32,7 @@ class TensorflowFramework(Framework):
SUPPORTS_PYTHON_LAYERS_FILE = False
SUPPORTS_TIMELINE_TRACING = True

SUPPORTED_SOLVER_TYPES = ['SGD','ADADELTA','ADAGRAD','ADAGRADDA','MOMENTUM','ADAM','FTRL','RMSPROP']
SUPPORTED_SOLVER_TYPES = ['SGD', 'ADADELTA', 'ADAGRAD', 'ADAGRADDA', 'MOMENTUM', 'ADAM', 'FTRL', 'RMSPROP']

SUPPORTED_DATA_TRANSFORMATION_TYPES = ['MEAN_SUBTRACTION', 'CROPPING']
SUPPORTED_DATA_AUGMENTATION_TYPES = ['FLIPPING', 'NOISE', 'CONTRAST', 'WHITENING', 'HSV_SHIFTING']
@@ -50,7 +47,7 @@ def create_train_task(self, **kwargs):
"""
create train task
"""
return TensorflowTrainTask(framework_id = self.framework_id, **kwargs)
return TensorflowTrainTask(framework_id=self.framework_id, **kwargs)

@override
def get_standard_network_desc(self, network):
@@ -126,10 +123,10 @@ def get_network_visualization(self, **kwargs):
# Another for the HTML
_, temp_html_path = tempfile.mkstemp(suffix='.html')

try: # do this in a try..finally clause to make sure we delete the temp file
try: # do this in a try..finally clause to make sure we delete the temp file
# build command line
args = [config_value('tensorflow')['executable'],
os.path.join(os.path.dirname(digits.__file__),'tools','tensorflow','main.py'),
os.path.join(os.path.dirname(digits.__file__), 'tools', 'tensorflow', 'main.py'),
'--network=%s' % os.path.basename(temp_network_path),
'--networkDirectory=%s' % os.path.dirname(temp_network_path),
'--visualizeModelPath=%s' % temp_graphdef_path,
@@ -141,7 +138,7 @@ def get_network_visualization(self, **kwargs):

if use_mean and use_mean != 'none':
mean_file = dataset.get_mean_file()
assert mean_file != None, 'Failed to retrieve mean file.'
assert mean_file is not None, 'Failed to retrieve mean file.'
args.append('--subtractMean=%s' % use_mean)
args.append('--mean=%s' % dataset.path(mean_file))

@@ -163,15 +160,14 @@

env = os.environ.copy()
# make only a selected number of GPUs visible. The ID is not important for just the vis
env['CUDA_VISIBLE_DEVICES'] = ",".join([str(i) for i in range(0,int(num_gpus))])
env['CUDA_VISIBLE_DEVICES'] = ",".join([str(i) for i in range(0, int(num_gpus))])

# execute command
p = subprocess.Popen(args,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
close_fds=True,
env=env
)
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
close_fds=True,
env=env)

stdout_log = ''
while p.poll() is None:
@@ -181,7 +177,7 @@
stdout_log += line
if p.returncode:
raise NetworkVisualizationError(stdout_log)
else: # Success!
else: # Success!
return repr(str(open(temp_graphdef_path).read()))
finally:
os.remove(temp_network_path)
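The collapsed tail of this hunk follows the same pattern as the surrounding code: restrict the GPUs the child process may see, launch the visualization tool, and collect its combined output. A minimal, self-contained sketch of that pattern (the nvidia-smi call is a placeholder, not the DIGITS tool invocation):

```python
# Illustrative subprocess pattern, not the exact DIGITS code: limit visible
# GPUs through CUDA_VISIBLE_DEVICES, merge stderr into stdout, and return the
# log together with the exit code.
import os
import subprocess


def run_with_gpus(args, num_gpus=1):
    env = os.environ.copy()
    env['CUDA_VISIBLE_DEVICES'] = ",".join(str(i) for i in range(int(num_gpus)))
    p = subprocess.Popen(args,
                         stdout=subprocess.PIPE,
                         stderr=subprocess.STDOUT,
                         close_fds=True,
                         env=env)
    out, _ = p.communicate()
    return p.returncode, out


if __name__ == '__main__':
    code, log = run_with_gpus(['nvidia-smi', '-L'])
    print(code)
    print(log)
```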
24 changes: 13 additions & 11 deletions digits/model/forms.py
@@ -121,13 +121,14 @@ def validate_py_ext(form, field):
tooltip="How many epochs of training between running through one pass of the validation data?"
)

traces_interval = utils.forms.IntegerField('Tracing Interval (in steps)',
validators=[
validators.NumberRange(min=0)
],
default=0,
tooltip="Generation of a timeline trace every few steps"
)
traces_interval = utils.forms.IntegerField(
'Tracing Interval (in steps)',
validators=[
validators.NumberRange(min=0)
],
default=0,
tooltip="Generation of a timeline trace every few steps"
)

random_seed = utils.forms.IntegerField(
'Random seed',
@@ -311,10 +312,11 @@
)

def validate_custom_network_snapshot(form, field):
if form.method.data == 'custom':
for filename in field.data.strip().split(os.path.pathsep):
if filename and not os.path.exists(filename):
raise validators.ValidationError('File "%s" does not exist' % filename)
pass
#if form.method.data == 'custom':
# for filename in field.data.strip().split(os.path.pathsep):
# if filename and not os.path.exists(filename):
# raise validators.ValidationError('File "%s" does not exist' % filename)

# Select one of several GPUs
select_gpu = wtforms.RadioField(
2 changes: 1 addition & 1 deletion digits/model/images/classification/job.py
@@ -28,7 +28,7 @@ def job_type(self):
def download_files(self, epoch=-1):
task = self.train_task()

snapshot_filename = task.get_snapshot(epoch)
snapshot_filename = task.get_snapshot(epoch, download=True)

# get model files
model_files = task.get_model_files()
(Diffs for the remaining changed files are not shown here.)