Skip to content

Commit

Permalink
Fix docker http server issues and reorganise entrypoints
Browse files Browse the repository at this point in the history
  • Loading branch information
spandan committed Aug 28, 2019
1 parent 865623d commit 3e5f3b5
Show file tree
Hide file tree
Showing 6 changed files with 101 additions and 71 deletions.
2 changes: 1 addition & 1 deletion Dockerfile.server
Original file line number Diff line number Diff line change
Expand Up @@ -22,4 +22,4 @@ RUN pip install -r requirements.txt

ENV PYTHONPATH='/src/:$PYTHONPATH'

ENTRYPOINT ["entrypoints/entrypoint.server.sh"]
ENTRYPOINT ["entrypoints/entrypoint.predict.server.sh"]
10 changes: 10 additions & 0 deletions entrypoints/entrypoint.predict.server.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
#!/bin/bash
set -e

BASE_MODEL_NAME=$1
WEIGHTS_FILE=$2

# predict
python -m evaluater.server \
--base-model-name $BASE_MODEL_NAME \
--weights-file $WEIGHTS_FILE
5 changes: 0 additions & 5 deletions entrypoints/entrypoint.server.sh

This file was deleted.

18 changes: 4 additions & 14 deletions src/evaluater/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,11 @@
from utils.utils import calc_mean_score, save_json
from handlers.model_builder import Nima
from handlers.data_generator import TestDataGenerator
from keras import backend as K
from PIL import ImageFile, Image


def image_file_to_json(img_path):
img_dir = os.path.dirname(img_path)
splits = os.path.basename(img_path).split('.')
splits.pop()
img_id = ".".join(splits)
img_id = os.path.basename(img_path).split('.')[0]

return img_dir, [{'image_id': img_id}]

Expand All @@ -23,9 +19,7 @@ def image_dir_to_json(img_dir, img_type='jpg'):

samples = []
for img_path in img_paths:
splits = os.path.basename(img_path).split('.')
splits.pop()
img_id = ".".join(splits)
img_id = os.path.basename(img_path).split('.')[0]
samples.append({'image_id': img_id})

return samples
Expand All @@ -43,9 +37,6 @@ def main(base_model_name, weights_file, image_source, predictions_file, img_form
image_dir = image_source
samples = image_dir_to_json(image_dir, img_type='jpg')

ImageFile.LOAD_TRUNCATED_IMAGES = True
Image.MAX_IMAGE_PIXELS = None

# build model and load weights
nima = Nima(base_model_name, weights=None)
nima.build()
Expand All @@ -57,17 +48,16 @@ def main(base_model_name, weights_file, image_source, predictions_file, img_form

# get predictions
predictions = predict(nima.nima_model, data_generator)
K.clear_session()

# calc mean scores and add to samples
for i, sample in enumerate(samples):
sample['mean_score_prediction'] = calc_mean_score(predictions[i])

print(json.dumps(samples, indent=2))

if predictions_file is not None:
save_json(samples, predictions_file)

return samples


if __name__ == '__main__':

Expand Down
86 changes: 86 additions & 0 deletions src/evaluater/server.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
#!/usr/bin/python
# -*- coding: utf-8 -*-

import os
from flask import Flask, request, jsonify
from evaluater.predict import image_file_to_json, image_dir_to_json, predict
import urllib
import shutil
import argparse
from keras import backend as K
from PIL import ImageFile, Image
from handlers.model_builder import Nima
from handlers.data_generator import TestDataGenerator

app = Flask('server')

def load_model(config):
global model
model = Nima(config.base_model_name)
model.build()
model.nima_model.load_weights(config.weights_file)
model.nima_model._make_predict_function() # https://github.com/keras-team/keras/issues/6462
model.nima_model.summary()

def main(image_source, predictions_file, img_format='jpg'):
# load samples
if os.path.isfile(image_source):
image_dir, samples = image_file_to_json(image_source)
else:
image_dir = image_source
samples = image_dir_to_json(image_dir, img_type='jpg')

ImageFile.LOAD_TRUNCATED_IMAGES = True
Image.MAX_IMAGE_PIXELS = None

# initialize data generator
data_generator = TestDataGenerator(samples, image_dir, 64, 10, model.preprocessing_function(),
img_format=img_format)

# get predictions
predictions = predict(model.nima_model, data_generator)
K.clear_session()

# calc mean scores and add to samples
for i, sample in enumerate(samples):
sample['mean_score_prediction'] = calc_mean_score(predictions[i])

if predictions_file is not None:
save_json(samples, predictions_file)

return samples

@app.route('/prediction', methods=['POST'])
def prediction():

global images

if request.method == 'POST':
images = request.json

if images:

shutil.rmtree('temp')
os.mkdir('temp')
for image in images:
filename_w_ext = os.path.basename(image)
try:
urllib.request.urlretrieve(image, 'temp/'+ filename_w_ext)
except:
print('An exception occurred :' + image)

result = main('temp', None)

return jsonify(result)

return jsonify({'error': 'Image is not available'})

if __name__ == '__main__':

parser = argparse.ArgumentParser()
parser.add_argument('-b', '--base-model-name', help='CNN base model name', required=True)
parser.add_argument('-w', '--weights-file', help='path of weights file', required=True)
args = parser.parse_args()

load_model(args)
app.run(host='0.0.0.0', port=5005)
51 changes: 0 additions & 51 deletions src/server.py

This file was deleted.

0 comments on commit 3e5f3b5

Please sign in to comment.