forked from taoxugit/AttnGAN
-
Notifications
You must be signed in to change notification settings - Fork 39
/
main.py
93 lines (77 loc) · 2.88 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import os
import time
import random
from eval import *
from flask import Flask, jsonify, request, abort
from applicationinsights import TelemetryClient
from applicationinsights.requests import WSGIApplication
from applicationinsights.exceptions import enable
from miscc.config import cfg
#from werkzeug.contrib.profiler import ProfilerMiddleware
# Module-level service setup. Both telemetry calls read the TELEMETRY
# environment variable; importing this module raises KeyError if it is unset.
# NOTE(review): enable() comes from applicationinsights.exceptions and
# presumably turns on unhandled-exception reporting -- confirm against the SDK.
enable(os.environ["TELEMETRY"])
app = Flask(__name__)
# Wrap Flask's WSGI callable in the Application Insights WSGI middleware so
# every request passes through it before reaching the Flask app.
app.wsgi_app = WSGIApplication(os.environ["TELEMETRY"], app.wsgi_app)
@app.route('/api/v1.0/bird', methods=['POST'])
def create_bird():
    """Generate one bird for the caption in the POSTed JSON body.

    The body must be a JSON object containing a 'caption' key; any other
    body (missing, malformed, or not an object) is rejected with HTTP 400.

    Returns:
        A (response, status) pair. The response is JSON of the form
        {'bird': {...}} holding the first five entries returned by
        generate() under 'small', 'medium', 'large', 'map1' and 'map2',
        plus the caption and the generation time in seconds. The status
        is 201.
    """
    # silent=True yields None for a missing or unparsable JSON body instead
    # of raising, so every bad request takes the same 400 path below.
    payload = request.get_json(silent=True)
    # Require a JSON *object*: a list such as ["caption"] would pass a bare
    # membership test and then fail on payload['caption'] with a 500.
    if not isinstance(payload, dict) or 'caption' not in payload:
        abort(400)

    caption = payload['caption']

    t0 = time.time()
    # wordtoix, ixtoword, text_encoder, netG and blob_service are module
    # globals created by the __main__ start-up code.
    urls = generate(caption, wordtoix, ixtoword, text_encoder, netG, blob_service)
    t1 = time.time()

    response = {
        'small': urls[0],
        'medium': urls[1],
        'large': urls[2],
        'map1': urls[3],
        'map2': urls[4],
        'caption': caption,
        'elapsed': t1 - t0,
    }
    return jsonify({'bird': response}), 201
@app.route('/api/v1.0/birds', methods=['POST'])
def create_birds():
    """Generate six birds for the caption in the POSTed JSON body.

    The body must be a JSON object containing a 'caption' key; any other
    body (missing, malformed, or not an object) is rejected with HTTP 400.

    Returns:
        A (response, status) pair. The response is JSON of the form
        {'bird': {...}} with keys 'bird1'..'bird6', each mapping 'small',
        'medium' and 'large' to three consecutive entries returned by
        generate(), plus the caption and the generation time in seconds.
        The status is 201.
    """
    # silent=True yields None for a missing or unparsable JSON body instead
    # of raising, so every bad request takes the same 400 path below.
    payload = request.get_json(silent=True)
    # Require a JSON *object*: a list such as ["caption"] would pass a bare
    # membership test and then fail on payload['caption'] with a 500.
    if not isinstance(payload, dict) or 'caption' not in payload:
        abort(400)

    caption = payload['caption']
    copies = 6
    sizes = ('small', 'medium', 'large')

    t0 = time.time()
    # wordtoix, ixtoword, text_encoder, netG and blob_service are module
    # globals created by the __main__ start-up code.
    urls = generate(caption, wordtoix, ixtoword, text_encoder, netG,
                    blob_service, copies=copies)
    t1 = time.time()

    # Entries are consumed three per bird, in (small, medium, large) order:
    # bird1 takes urls[0:3], bird2 takes urls[3:6], and so on.
    # NOTE(review): assumes generate() returns at least 3 * copies entries.
    response = {}
    for bird in range(copies):
        base = bird * len(sizes)
        response['bird%d' % (bird + 1)] = {
            size: urls[base + offset] for offset, size in enumerate(sizes)
        }
    response['caption'] = caption
    response['elapsed'] = t1 - t0

    return jsonify({'bird': response}), 201
@app.route('/', methods=['GET'])
def get_bird():
    """Serve the fixed version banner at the site root."""
    banner = 'Version 1'
    return banner
if __name__ == '__main__':
    # Start-up sequence: configure CUDA, load dictionaries, models and the
    # blob client into module globals used by the request handlers, seed the
    # RNGs, report start-up telemetry, then serve on port 8080.
    # NOTE(review): word_index, models, BlockBlobService, np and torch are
    # not imported by name here; presumably they arrive via
    # `from eval import *` -- confirm.
    t0 = time.time()
    tc = TelemetryClient(os.environ["TELEMETRY"])

    # gpu based: CUDA is enabled only when GPU is 'true' (case-insensitive)
    cfg.CUDA = os.environ["GPU"].lower() == 'true'
    tc.track_event('container initializing', {"CUDA": str(cfg.CUDA)})

    # load word dictionaries
    wordtoix, ixtoword = word_index()
    # load models
    text_encoder, netG = models(len(wordtoix))
    # load blob service
    blob_service = BlockBlobService(account_name='attgan', account_key=os.environ["BLOB_KEY"])

    # Fixed seed for every RNG in use.
    seed = 100
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if cfg.CUDA:
        torch.cuda.manual_seed_all(seed)

    # NOTE: for profiling, set app.config['PROFILE'] = True and wrap
    # app.wsgi_app in werkzeug's ProfilerMiddleware(restrictions=[30]);
    # for debugging, pass debug=True to app.run().

    t1 = time.time()
    tc.track_event('container start', {"starttime": str(t1 - t0)})
    # Tracked events are only queued on the client; app.run() blocks for the
    # life of the process, so send the start-up events now or they never go.
    tc.flush()

    app.run(host='0.0.0.0', port=8080)