Skip to content

Commit

Permalink
Fix #1 - pretrained model selection bug
Browse files Browse the repository at this point in the history
1) Removed the select field for previous networks snapshots from forms.py, due to wtforms validation issues. Then included code in views.py and new.html, to display the select fields for previous networks snapshots in html and for validating these fields. These changes fix the pretrained model selection bug. Also these changes ensure the display of error message, incase user selects the job that doesn't exists (this scenario happens when a user selects the job and at the same if another user deletes this job)
2) fix to the following issue: when validation fails, the selected radio buttons in "previous networks" tab remains selected but the select input field and customize button beside the selected previous network is in hidden status, which is wrong. same issue exists in standard networks tab. Modified new.html to ensure that select field and customize button beside selected radio button will be visible even when validation fails.
3) Modified the test case related to pretrained model selection.
4) Included the changes to display the selected "pretrained model" details in show.html file
  • Loading branch information
Sravan2j committed Mar 23, 2015
1 parent 4f40cd3 commit 314df80
Show file tree
Hide file tree
Showing 5 changed files with 80 additions and 39 deletions.
16 changes: 10 additions & 6 deletions digits/model/forms.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,13 @@ def __init__(self, csrf_enabled=False, *args, **kwargs):
super(ModelForm, self).__init__(csrf_enabled=csrf_enabled, *args, **kwargs)

### Methods
def selection_exists_in_choices(form, field):
found=False
for i, choice in enumerate(field.choices):
if choice[0] == field.data:
found=True
if found == False:
raise validators.ValidationError("Selected job doesn't exist. Maybe it was deleted by another user.")

def required_if_method(value):
def _required(form, field):
Expand Down Expand Up @@ -144,15 +151,10 @@ def validate_lr_multistep_values(form, field):
choices = [],
validators = [
required_if_method('previous'),
selection_exists_in_choices,
],
)

previous_network_snapshots = wtforms.FieldList(
wtforms.SelectField('Snapshots',
validators = [validators.Optional()]
),
)

custom_network = wtforms.TextAreaField('Custom Network',
validators = [
required_if_method('custom'),
Expand All @@ -175,3 +177,5 @@ def validate_custom_network_snapshot(form, field):
]
)



20 changes: 14 additions & 6 deletions digits/model/images/classification/test_views.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# Copyright (c) 2014-2015, NVIDIA CORPORATION. All rights reserved.

import re
import os
import tempfile

import unittest
import mock
Expand Down Expand Up @@ -52,11 +54,20 @@ def setupClass(cls):
mj = mock.Mock(spec=digits.model.ImageClassificationModelJob)
mj.id.return_value = 'model'
mj.name.return_value = ''
mj.train_task.return_value.snapshots = [('path', 1)]
_, cls.temp_snapshot_path = tempfile.mkstemp() #instead of using a dummy hardcoded value as snapshot path, temp file path is used to avoid the filen't exists exception in views.py.
mj.train_task.return_value.snapshots = [(cls.temp_snapshot_path, 1)]
mj.train_task.return_value.network = caffe_pb2.NetParameter()

digits.webapp.scheduler.jobs = [dj, mj]

@classmethod
def tearDownClass(cls):
super(TestCreate, cls).tearDownClass()
try:
os.remove(cls.temp_snapshot_path)
except OSError:
pass

def test_empty_request(self):
"""empty request"""
rv = self.app.post(self.url)
Expand All @@ -76,7 +87,6 @@ def test_crop_size(self):

assert scheduler.jobs[-1].train_task().crop_size == 12

@unittest.skip('expected failure')
def test_previous_network_pretrained_model(self):
"""previous network, pretrained model"""

Expand All @@ -85,9 +95,7 @@ def test_previous_network_pretrained_model(self):
'model_name': 'test',
'dataset': 'dataset',
'previous_networks': 'model',
# TODO: select snapshot 1
'model-snapshot' : 1
})

assert scheduler.jobs[-1].train_task().pretrained_model == 'path'
assert False

assert scheduler.jobs[-1].train_task().pretrained_model == self.temp_snapshot_path
32 changes: 19 additions & 13 deletions digits/model/images/classification/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,10 @@ def image_classification_model_new():
form.standard_networks.choices = get_standard_networks()
form.standard_networks.default = get_default_standard_network()
form.previous_networks.choices = get_previous_networks()
set_previous_network_snapshots(form)

return render_template('models/images/classification/new.html', form=form, has_datasets=(len(get_datasets())==0))
prev_network_snapshots = get_previous_network_snapshots()

return render_template('models/images/classification/new.html', form=form, previous_network_snapshots=prev_network_snapshots, has_datasets=(len(get_datasets())==0))

@app.route(NAMESPACE, methods=['POST'])
def image_classification_model_create():
Expand All @@ -42,10 +43,11 @@ def image_classification_model_create():
form.standard_networks.choices = get_standard_networks()
form.standard_networks.default = get_default_standard_network()
form.previous_networks.choices = get_previous_networks()
set_previous_network_snapshots(form)

prev_network_snapshots = get_previous_network_snapshots()

if not form.validate_on_submit():
return render_template('models/images/classification/new.html', form=form), 400
return render_template('models/images/classification/new.html', form=form, previous_network_snapshots=prev_network_snapshots), 400

datasetJob = scheduler.get_job(form.dataset.data)
if not datasetJob:
Expand Down Expand Up @@ -81,13 +83,19 @@ def image_classification_model_create():
network.CopyFrom(old_job.train_task().network)
for i, choice in enumerate(form.previous_networks.choices):
if choice[0] == form.previous_networks.data:
epoch = form.previous_network_snapshots[i].data
if epoch != 'none':
epoch = int(request.form['%s-snapshot' % form.previous_networks.data])
if epoch != 0:
for filename, e in old_job.train_task().snapshots:
if e == epoch:
pretrained_model = filename
break

if pretrained_model is None:
raise Exception("For the job %s, selected pretrained_model for epoch %d is invalid!" % (form.previous_networks.data, epoch))
if not (os.path.exists(pretrained_model)):
raise Exception("Pretrained_model for the selected epoch doesn't exists. May be deleted by another user/process. Please restart the server to load the correct pretrained_model details")
break

elif form.method.data == 'custom':
text_format.Merge(form.custom_network.data, network)
pretrained_model = form.custom_network_snapshot.data.strip()
Expand Down Expand Up @@ -282,13 +290,11 @@ def get_previous_networks():
)
]

def set_previous_network_snapshots(form):
while len(form.previous_network_snapshots):
form.previous_network_snapshots.pop_entry()

def get_previous_network_snapshots():
prev_network_snapshots = []
for job_id, _ in get_previous_networks():
job = scheduler.get_job(job_id)
e = form.previous_network_snapshots.append_entry()
e.choices = [('none', 'None')] + [(epoch, 'Epoch #%s' % epoch)
e = [(0, 'None')] + [(epoch, 'Epoch #%s' % epoch)
for _, epoch in reversed(job.train_task().snapshots)]

prev_network_snapshots.append(e)
return prev_network_snapshots
47 changes: 33 additions & 14 deletions digits/templates/models/images/classification/new.html
Original file line number Diff line number Diff line change
Expand Up @@ -421,12 +421,16 @@ <h4>Solver Options</h4>
<a href="{{url_for("models_show", job_id=network.data)}}" target="_blank">View</a>
</td>
<td>
{% set snapshot_list = form.previous_network_snapshots[loop.index0] %}
{% if snapshot_list.choices|length %}
{{snapshot_list(class='form-control')}}
{% set snapshot_list = previous_network_snapshots[loop.index0] %}
{% if snapshot_list|length %}
<select class="form-control" id="{{network.data}}-snapshot" name="{{network.data}}-snapshot">
{% for each_epoch in snapshot_list %}
<option value="{{each_epoch[0]}}">{{each_epoch[1]}}</option>
{% endfor %}
</select>
{% endif %}
</td>
<td><a class="btn btn-sm" onClick="customizeNetwork('{{network.data}}', '{{snapshot_list.id}}');">Customize</a></td>
<td><a class="btn btn-sm" onClick="customizeNetwork('{{network.data}}', '{{network.data}}-snapshot');">Customize</a></td>
</tr>
{% else %}
<tr>
Expand All @@ -436,26 +440,41 @@ <h4>Solver Options</h4>
</table>
</div>
<script>
$(".btn.btn-sm").hide(); // hide all the buttons
$("#network-tab-previous").find(".form-control").prop('disabled', 'disabled'); //hide all the select elements in "Previous Networks" tab
$(".btn.btn-sm").css('visibility', 'hidden'); // hide all the buttons
$("#network-tab-previous").find(".form-control").css('visibility', 'hidden'); //hide all the select elements in "Previous Networks" tab

var $stdtab_prev_clicked_tr = null;
$("input:radio[name=standard_networks]").click(function(){
if ($stdtab_prev_clicked_tr != null) {
$stdtab_prev_clicked_tr.find(".btn.btn-sm").hide();
}
$(this).parents('tr').find(".btn.btn-sm").show();
$stdtab_prev_clicked_tr = $(this).parents('tr');
$stdtab_prev_clicked_tr.find(".btn.btn-sm").css('visibility', 'hidden');
}
$(this).parents('tr').find(".btn.btn-sm").css('visibility', 'visible');
$stdtab_prev_clicked_tr = $(this).parents('tr');
});

var $prevtab_prev_clicked_tr = null;
$("input:radio[name=previous_networks]").click(function(){
if ($prevtab_prev_clicked_tr != null) {
$prevtab_prev_clicked_tr.find(".btn.btn-sm").hide();
$prevtab_prev_clicked_tr.find(".form-control").prop('disabled', 'disabled');
$prevtab_prev_clicked_tr.find(".btn.btn-sm").css('visibility', 'hidden');
$prevtab_prev_clicked_tr.find(".form-control").css('visibility', 'hidden');
}
$(this).parents('tr').find(".btn.btn-sm").show();
$(this).parents('tr').find(".form-control").prop('disabled', false);
$(this).parents('tr').find(".btn.btn-sm").css('visibility', 'visible');
$(this).parents('tr').find(".form-control").css('visibility', 'visible');
$prevtab_prev_clicked_tr = $(this).parents('tr');
});

//fix to the following issue: when validation fails and the screen was re-displayed, the selected radio buttons in "previous networks" tab remains selected but the select input field and customize button beside the selected previous network is in hidden status, which is wrong. same issue exists in standard networks tab. The below code ensures that select field and customize button beside selected radio button will be visible even when validation fails.

if($("input:radio[name=standard_networks]").is(":checked")) {
$stdtab_prev_clicked_tr = $("input:radio[name=standard_networks]:checked").parents('tr');
$stdtab_prev_clicked_tr.find(".btn.btn-sm").css('visibility', 'visible');
}
if($("input:radio[name=previous_networks]").is(":checked")) {
$prevtab_prev_clicked_tr = $("input:radio[name=previous_networks]:checked").parents('tr');
$prevtab_prev_clicked_tr.find(".btn.btn-sm").css('visibility', 'visible');
$prevtab_prev_clicked_tr.find(".form-control").css('visibility', 'visible');
}

</script>
<div id="network-tab-custom" class="tab-pane">
<script>
Expand Down
4 changes: 4 additions & 0 deletions digits/templates/models/images/classification/show.html
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@
<dd><a href="{{url_for('serve_file', path=task.path(task.deploy_file, relative=True))}}">{{task.deploy_file}}</a></dd>
<dt>Raw caffe output</dt>
<dd><a href="{{url_for('serve_file', path=task.path(task.caffe_log_file, relative=True))}}">{{task.caffe_log_file}}</a></dd>
{% if task.pretrained_model %}
<dt>Pretrained Model</dt>
<dd>{{task.pretrained_model}}</dd>
{% endif %}
</dl>
</div>
</div>
Expand Down

0 comments on commit 314df80

Please sign in to comment.