Skip to content

Commit

Permalink
Add HTTPS support to APIs (openai and default) (#4270)
Browse files Browse the repository at this point in the history
---------

Co-authored-by: oobabooga <[email protected]>
  • Loading branch information
chuyqa and oobabooga authored Oct 13, 2023
1 parent 43be1be commit ed66ca3
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 11 deletions.
19 changes: 15 additions & 4 deletions extensions/api/blocking_api.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import json
import ssl
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
from threading import Thread

Expand All @@ -14,6 +15,7 @@
stop_everything_event
)
from modules.utils import get_available_models
from modules.logging_colors import logger


def get_model_info():
Expand Down Expand Up @@ -199,20 +201,29 @@ def end_headers(self):

def _run_server(port: int, share: bool = False, tunnel_id=str):
address = '0.0.0.0' if shared.args.listen else '127.0.0.1'

server = ThreadingHTTPServer((address, port), Handler)

ssl_certfile = shared.args.ssl_certfile
ssl_keyfile = shared.args.ssl_keyfile
ssl_verify = True if (ssl_keyfile and ssl_certfile) else False
if ssl_verify:
context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
context.load_cert_chain(ssl_certfile, ssl_keyfile)
server.socket = context.wrap_socket(server.socket, server_side=True)

def on_start(public_url: str):
print(f'Starting non-streaming server at public url {public_url}/api')
logger.info(f'Starting non-streaming server at public url {public_url}/api')

if share:
try:
try_start_cloudflared(port, tunnel_id, max_attempts=3, on_start=on_start)
except Exception:
pass
else:
print(
f'Starting API at http://{address}:{port}/api')
if ssl_verify:
logger.info(f'Starting API at https://{address}:{port}/api')
else:
logger.info(f'Starting API at http://{address}:{port}/api')

server.serve_forever()

Expand Down
28 changes: 23 additions & 5 deletions extensions/api/streaming_api.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
import asyncio
import json
import ssl
from threading import Thread

from websockets.server import serve

from extensions.api.util import (
build_parameters,
try_start_cloudflared,
Expand All @@ -10,7 +13,7 @@
from modules import shared
from modules.chat import generate_chat_reply
from modules.text_generation import generate_reply
from websockets.server import serve
from modules.logging_colors import logger

PATH = '/api/v1/stream'

Expand Down Expand Up @@ -98,24 +101,39 @@ async def _handle_connection(websocket, path):


async def _run(host: str, port: int):
async with serve(_handle_connection, host, port, ping_interval=None):
await asyncio.Future() # run forever
ssl_certfile = shared.args.ssl_certfile
ssl_keyfile = shared.args.ssl_keyfile
ssl_verify = True if (ssl_keyfile and ssl_certfile) else False
if ssl_verify:
context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
context.load_cert_chain(ssl_certfile, ssl_keyfile)
else:
context = None

async with serve(_handle_connection, host, port, ping_interval=None, ssl=context):
await asyncio.Future() # Run the server forever


def _run_server(port: int, share: bool = False, tunnel_id=str):
address = '0.0.0.0' if shared.args.listen else '127.0.0.1'
ssl_certfile = shared.args.ssl_certfile
ssl_keyfile = shared.args.ssl_keyfile
ssl_verify = True if (ssl_keyfile and ssl_certfile) else False

def on_start(public_url: str):
public_url = public_url.replace('https://', 'wss://')
print(f'Starting streaming server at public url {public_url}{PATH}')
logger.info(f'Starting streaming server at public url {public_url}{PATH}')

if share:
try:
try_start_cloudflared(port, tunnel_id, max_attempts=3, on_start=on_start)
except Exception as e:
print(e)
else:
print(f'Starting streaming server at ws://{address}:{port}{PATH}')
if ssl_verify:
logger.info(f'Starting streaming server at wss://{address}:{port}{PATH}')
else:
logger.info(f'Starting streaming server at ws://{address}:{port}{PATH}')

asyncio.run(_run(host=address, port=port))

Expand Down
17 changes: 15 additions & 2 deletions extensions/openai/script.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import json
import os
import ssl
import traceback
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
from threading import Thread
Expand Down Expand Up @@ -322,6 +323,15 @@ def run_server():
port = int(os.environ.get('OPENEDAI_PORT', params.get('port', 5001)))
server_addr = ('0.0.0.0' if shared.args.listen else '127.0.0.1', port)
server = ThreadingHTTPServer(server_addr, Handler)

ssl_certfile=os.environ.get('OPENEDAI_CERT_PATH', shared.args.ssl_certfile)
ssl_keyfile=os.environ.get('OPENEDAI_KEY_PATH', shared.args.ssl_keyfile)
ssl_verify=True if (ssl_keyfile and ssl_certfile) else False
if ssl_verify:
context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
context.load_cert_chain(ssl_certfile, ssl_keyfile)
server.socket = context.wrap_socket(server.socket, server_side=True)

if shared.args.share:
try:
from flask_cloudflared import _run_cloudflared
Expand All @@ -330,8 +340,11 @@ def run_server():
except ImportError:
print('You should install flask_cloudflared manually')
else:
print(f'OpenAI compatible API ready at: OPENAI_API_BASE=http://{server_addr[0]}:{server_addr[1]}/v1')

if ssl_verify:
print(f'OpenAI compatible API ready at: OPENAI_API_BASE=https://{server_addr[0]}:{server_addr[1]}/v1')
else:
print(f'OpenAI compatible API ready at: OPENAI_API_BASE=http://{server_addr[0]}:{server_addr[1]}/v1')

server.serve_forever()


Expand Down

0 comments on commit ed66ca3

Please sign in to comment.