diff --git a/Dockerfile.server b/Dockerfile.server index bc576c74..e0b49481 100644 --- a/Dockerfile.server +++ b/Dockerfile.server @@ -22,4 +22,4 @@ RUN pip install -r requirements.txt ENV PYTHONPATH='/src/:$PYTHONPATH' -ENTRYPOINT ["entrypoints/entrypoint.server.sh"] +ENTRYPOINT ["entrypoints/entrypoint.predict.server.sh"] diff --git a/entrypoints/entrypoint.predict.server.sh b/entrypoints/entrypoint.predict.server.sh new file mode 100755 index 00000000..e054ec16 --- /dev/null +++ b/entrypoints/entrypoint.predict.server.sh @@ -0,0 +1,10 @@ +#!/bin/bash +set -e + +BASE_MODEL_NAME=$1 +WEIGHTS_FILE=$2 + +# predict +python -m evaluater.server \ +--base-model-name $BASE_MODEL_NAME \ +--weights-file $WEIGHTS_FILE diff --git a/entrypoints/entrypoint.server.sh b/entrypoints/entrypoint.server.sh deleted file mode 100755 index 313408ec..00000000 --- a/entrypoints/entrypoint.server.sh +++ /dev/null @@ -1,5 +0,0 @@ -#!/bin/bash -set -e - -# predict -python server.py diff --git a/src/evaluater/predict.py b/src/evaluater/predict.py index 28d8a781..375bdb1f 100644 --- a/src/evaluater/predict.py +++ b/src/evaluater/predict.py @@ -5,15 +5,11 @@ from utils.utils import calc_mean_score, save_json from handlers.model_builder import Nima from handlers.data_generator import TestDataGenerator -from keras import backend as K -from PIL import ImageFile, Image def image_file_to_json(img_path): img_dir = os.path.dirname(img_path) - splits = os.path.basename(img_path).split('.') - splits.pop() - img_id = ".".join(splits) + img_id = os.path.basename(img_path).split('.')[0] return img_dir, [{'image_id': img_id}] @@ -23,9 +19,7 @@ def image_dir_to_json(img_dir, img_type='jpg'): samples = [] for img_path in img_paths: - splits = os.path.basename(img_path).split('.') - splits.pop() - img_id = ".".join(splits) + img_id = os.path.basename(img_path).split('.')[0] samples.append({'image_id': img_id}) return samples @@ -43,9 +37,6 @@ def main(base_model_name, weights_file, image_source, predictions_file, img_form image_dir = image_source samples = image_dir_to_json(image_dir, img_type='jpg') - ImageFile.LOAD_TRUNCATED_IMAGES = True - Image.MAX_IMAGE_PIXELS = None - # build model and load weights nima = Nima(base_model_name, weights=None) nima.build() @@ -57,17 +48,16 @@ def main(base_model_name, weights_file, image_source, predictions_file, img_form # get predictions predictions = predict(nima.nima_model, data_generator) - K.clear_session() # calc mean scores and add to samples for i, sample in enumerate(samples): sample['mean_score_prediction'] = calc_mean_score(predictions[i]) + print(json.dumps(samples, indent=2)) + if predictions_file is not None: save_json(samples, predictions_file) - return samples - if __name__ == '__main__': diff --git a/src/evaluater/server.py b/src/evaluater/server.py new file mode 100644 index 00000000..fdbc6aa4 --- /dev/null +++ b/src/evaluater/server.py @@ -0,0 +1,86 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- + +import os +from flask import Flask, request, jsonify +from evaluater.predict import image_file_to_json, image_dir_to_json, predict +import urllib +import shutil +import argparse +from keras import backend as K +from PIL import ImageFile, Image +from handlers.model_builder import Nima +from handlers.data_generator import TestDataGenerator + +app = Flask('server') + +def load_model(config): + global model + model = Nima(config.base_model_name) + model.build() + model.nima_model.load_weights(config.weights_file) + model.nima_model._make_predict_function() # https://github.com/keras-team/keras/issues/6462 + model.nima_model.summary() + +def main(image_source, predictions_file, img_format='jpg'): + # load samples + if os.path.isfile(image_source): + image_dir, samples = image_file_to_json(image_source) + else: + image_dir = image_source + samples = image_dir_to_json(image_dir, img_type='jpg') + + ImageFile.LOAD_TRUNCATED_IMAGES = True + Image.MAX_IMAGE_PIXELS = None + + # initialize data generator + data_generator = TestDataGenerator(samples, image_dir, 64, 10, model.preprocessing_function(), + img_format=img_format) + + # get predictions + predictions = predict(model.nima_model, data_generator) + K.clear_session() + + # calc mean scores and add to samples + for i, sample in enumerate(samples): + sample['mean_score_prediction'] = calc_mean_score(predictions[i]) + + if predictions_file is not None: + save_json(samples, predictions_file) + + return samples + +@app.route('/prediction', methods=['POST']) +def prediction(): + + global images + + if request.method == 'POST': + images = request.json + + if images: + + shutil.rmtree('temp') + os.mkdir('temp') + for image in images: + filename_w_ext = os.path.basename(image) + try: + urllib.request.urlretrieve(image, 'temp/'+ filename_w_ext) + except: + print('An exception occurred :' + image) + + result = main('temp', None) + + return jsonify(result) + + return jsonify({'error': 'Image is not available'}) + +if __name__ == '__main__': + + parser = argparse.ArgumentParser() + parser.add_argument('-b', '--base-model-name', help='CNN base model name', required=True) + parser.add_argument('-w', '--weights-file', help='path of weights file', required=True) + args = parser.parse_args() + + load_model(args) + app.run(host='0.0.0.0', port=5005) diff --git a/src/server.py b/src/server.py deleted file mode 100755 index 4719dad9..00000000 --- a/src/server.py +++ /dev/null @@ -1,51 +0,0 @@ -#!/usr/bin/python -# -*- coding: utf-8 -*- - -import sys -import os -from flask import Flask, redirect, url_for, render_template, request, \ - abort, jsonify -from evaluater.predict import main -import json -import urllib.request -import shutil -import os - -app = Flask('server') - - -@app.route('/query/', methods=['POST']) -def query(model): - - global images - - if request.method == 'POST': - images = request.json - - if images: - - shutil.rmtree('temp') - os.mkdir('temp') - for image in images: - filename_w_ext = os.path.basename(image) - try: - urllib.request.urlretrieve(image, 'temp/'+ filename_w_ext) - except: - print('An exception occurred :' + image) - - if model == 'technical': - model = '/models/MobileNet/weights_mobilenet_technical_0.11.hdf5' - else: - model = '/models/MobileNet/weights_mobilenet_aesthetic_0.07.hdf5' - - result = main('MobileNet', - model, - 'temp', - None) - - return jsonify(result) - - return jsonify({'error': 'Image is not available'}) - -if __name__ == '__main__': - app.run(host='0.0.0.0', port=5005)