-
Notifications
You must be signed in to change notification settings - Fork 447
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Fix docker http server issues and reorganise entrypoints
- Loading branch information
spandan
committed
Aug 28, 2019
1 parent
865623d
commit 3e5f3b5
Showing
6 changed files
with
101 additions
and
71 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
#!/bin/bash | ||
set -e | ||
|
||
BASE_MODEL_NAME=$1 | ||
WEIGHTS_FILE=$2 | ||
|
||
# predict | ||
python -m evaluater.server \ | ||
--base-model-name $BASE_MODEL_NAME \ | ||
--weights-file $WEIGHTS_FILE |
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,86 @@ | ||
#!/usr/bin/python | ||
# -*- coding: utf-8 -*- | ||
|
||
import os | ||
from flask import Flask, request, jsonify | ||
from evaluater.predict import image_file_to_json, image_dir_to_json, predict | ||
import urllib | ||
import shutil | ||
import argparse | ||
from keras import backend as K | ||
from PIL import ImageFile, Image | ||
from handlers.model_builder import Nima | ||
from handlers.data_generator import TestDataGenerator | ||
|
||
app = Flask('server') | ||
|
||
def load_model(config): | ||
global model | ||
model = Nima(config.base_model_name) | ||
model.build() | ||
model.nima_model.load_weights(config.weights_file) | ||
model.nima_model._make_predict_function() # https://github.com/keras-team/keras/issues/6462 | ||
model.nima_model.summary() | ||
|
||
def main(image_source, predictions_file, img_format='jpg'): | ||
# load samples | ||
if os.path.isfile(image_source): | ||
image_dir, samples = image_file_to_json(image_source) | ||
else: | ||
image_dir = image_source | ||
samples = image_dir_to_json(image_dir, img_type='jpg') | ||
|
||
ImageFile.LOAD_TRUNCATED_IMAGES = True | ||
Image.MAX_IMAGE_PIXELS = None | ||
|
||
# initialize data generator | ||
data_generator = TestDataGenerator(samples, image_dir, 64, 10, model.preprocessing_function(), | ||
img_format=img_format) | ||
|
||
# get predictions | ||
predictions = predict(model.nima_model, data_generator) | ||
K.clear_session() | ||
|
||
# calc mean scores and add to samples | ||
for i, sample in enumerate(samples): | ||
sample['mean_score_prediction'] = calc_mean_score(predictions[i]) | ||
|
||
if predictions_file is not None: | ||
save_json(samples, predictions_file) | ||
|
||
return samples | ||
|
||
@app.route('/prediction', methods=['POST']) | ||
def prediction(): | ||
|
||
global images | ||
|
||
if request.method == 'POST': | ||
images = request.json | ||
|
||
if images: | ||
|
||
shutil.rmtree('temp') | ||
os.mkdir('temp') | ||
for image in images: | ||
filename_w_ext = os.path.basename(image) | ||
try: | ||
urllib.request.urlretrieve(image, 'temp/'+ filename_w_ext) | ||
except: | ||
print('An exception occurred :' + image) | ||
|
||
result = main('temp', None) | ||
|
||
return jsonify(result) | ||
|
||
return jsonify({'error': 'Image is not available'}) | ||
|
||
if __name__ == '__main__': | ||
|
||
parser = argparse.ArgumentParser() | ||
parser.add_argument('-b', '--base-model-name', help='CNN base model name', required=True) | ||
parser.add_argument('-w', '--weights-file', help='path of weights file', required=True) | ||
args = parser.parse_args() | ||
|
||
load_model(args) | ||
app.run(host='0.0.0.0', port=5005) |
This file was deleted.
Oops, something went wrong.