Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use websockets for json communication #4490

Merged
merged 41 commits into from
Sep 10, 2023
Merged
Show file tree
Hide file tree
Changes from 39 commits
Commits
Show all changes
41 commits
Select commit Hold shift + click to select a range
3221c90
first commit
Jun 14, 2023
9c0ec8d
Fix bugs
AndrewJGaut Jun 14, 2023
a139498
Fix more bugsz
AndrewJGaut Jun 14, 2023
b226f40
slight change
AndrewJGaut Jun 14, 2023
93f3e62
minor bug fixes
Jun 14, 2023
0db9264
It's working quite well now
AndrewJGaut Jun 28, 2023
1a65356
merge conflict
AndrewJGaut Jun 28, 2023
ae3e3be
Update to solve that last bug. I think it should work now
AndrewJGaut Jun 28, 2023
85774a4
bug fix that was causing messages to be send to tall worker threads
AndrewJGaut Jun 29, 2023
81064a7
Fix unittests issue and fix formatting
AndrewJGaut Jul 3, 2023
5326420
Minor change to increase robustness of message sending
AndrewJGaut Jul 3, 2023
34550f0
Some more formatting changes
AndrewJGaut Jul 3, 2023
57aad3d
Add in worker auth and use wss rather than ws for websocket URLs to m…
AndrewJGaut Jul 12, 2023
42127e5
Added in server auth with secret. Aslo still need to test (still on p…
AndrewJGaut Jul 12, 2023
aa37b0e
Fixed issues and got auth working. Now, I'll work on returning error …
AndrewJGaut Jul 12, 2023
3e71d67
Add in tests for authentication functionality (for worker and server)
AndrewJGaut Jul 12, 2023
7a0bd70
Slight cleanup to data sending code
AndrewJGaut Jul 13, 2023
cbd1505
Adding in ssl certification
AndrewJGaut Jul 13, 2023
4eb3d7b
Add in SSL stuff for worker; still testing on dev
AndrewJGaut Jul 14, 2023
aa74981
Revert "Add in SSL stuff for worker; still testing on dev"
AndrewJGaut Jul 14, 2023
123fc7e
Revert "Adding in ssl certification"
AndrewJGaut Jul 14, 2023
6efc0cf
Fixed formatting
AndrewJGaut Jul 15, 2023
35bdab0
Very minor formatting change to ignore one line for MyPy
AndrewJGaut Jul 15, 2023
9f5417c
Another minor formatting change
AndrewJGaut Jul 15, 2023
f0ec0f1
Merge branch 'master' into use-websockets-for-json-communication
AndrewJGaut Jul 16, 2023
baf9121
add exponential backoff to see if that fixes dev issue
AndrewJGaut Jul 16, 2023
6222950
Merge branch 'use-websockets-for-json-communication' of github.com:co…
AndrewJGaut Jul 16, 2023
100d536
Added code to actually detect worker disconnections now so that some …
AndrewJGaut Jul 16, 2023
2a82155
Make sockets get looped over in random order to help distribute load
AndrewJGaut Jul 16, 2023
5dabf6d
Clean up ws-server and delete a Dataclass I was using previously
AndrewJGaut Jul 17, 2023
0c57229
a few more minor changes
AndrewJGaut Jul 17, 2023
e34fb72
Make websocket locks more robust and improve error messaging
AndrewJGaut Jul 17, 2023
42819a6
More permissible retries in case of other errors (e.g. like 1013). Wi…
AndrewJGaut Jul 17, 2023
8e868a4
minor change to have a different error message if worker doesn't yet …
AndrewJGaut Jul 17, 2023
f36ed3a
Rename send_json and send_json_message_with_sock
AndrewJGaut Aug 16, 2023
4ed30a2
Rearrange worker_model to minimize diff
AndrewJGaut Aug 16, 2023
d115a96
Final changes
AndrewJGaut Aug 16, 2023
5ac4853
Fix formatting
AndrewJGaut Aug 16, 2023
8acedff
Minor changes to get auth working again and to robustly return an err…
AndrewJGaut Aug 23, 2023
3d25b77
Merge branch 'master' of github.com:codalab/codalab-worksheets into u…
AndrewJGaut Sep 10, 2023
bbe9402
Merge in master and make some minor changes
AndrewJGaut Sep 10, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
153 changes: 117 additions & 36 deletions codalab/bin/ws_server.py
Original file line number Diff line number Diff line change
@@ -1,75 +1,156 @@
# Main entry point for CodaLab cl-ws-server.
import argparse
import asyncio
from collections import defaultdict
import logging
import os
import random
import re
import time
from typing import Any, Dict
import websockets
import threading

logger = logging.getLogger(__name__)
logger.setLevel(logging.WARNING)
logging.basicConfig(format='%(asctime)s %(message)s %(pathname)s %(lineno)d')

worker_to_ws: Dict[str, Any] = {}
from codalab.lib.codalab_manager import CodaLabManager


async def rest_server_handler(websocket):
"""Handles routes of the form: /main. This route is called by the rest-server
whenever a worker needs to be pinged (to ask it to check in). The body of the
message is the worker id to ping. This function sends a message to the worker
with that worker id through an appropriate websocket.
class TimedLock:
"""A lock that gets automatically released after timeout_seconds.
"""
# Got a message from the rest server.
worker_id = await websocket.recv()
logger.warning(f"Got a message from the rest server, to ping worker: {worker_id}.")

try:
worker_ws = worker_to_ws[worker_id]
await worker_ws.send(worker_id)
except KeyError:
logger.error(f"Websocket not found for worker: {worker_id}")
def __init__(self, timeout_seconds: float = 60):
    """Create an unlocked lock.

    :param timeout_seconds: how long the lock may be held before
        release_if_timeout() is allowed to force-release it.
    """
    self._lock = threading.Lock()
    # Monotonic timestamp of the most recent successful acquire. Initialised
    # here so timeout() can never hit an unset attribute before first acquire.
    self._time_since_locked: float = time.monotonic()
    self._timeout: float = timeout_seconds

def acquire(self, blocking=True, timeout=-1):
    """Acquire the underlying lock (same arguments as threading.Lock.acquire).

    Records the acquisition time on success.
    :return: True if the lock was acquired, False otherwise.
    """
    acquired = self._lock.acquire(blocking, timeout)
    if acquired:
        # A monotonic clock is used so wall-clock adjustments cannot make the
        # lock look expired (or never expire).
        self._time_since_locked = time.monotonic()
    return acquired

def locked(self):
    """Return True if the underlying lock is currently held."""
    return self._lock.locked()

def release(self):
    """Release the underlying lock.

    Raises RuntimeError (from threading.Lock) if the lock is not held.
    """
    self._lock.release()

def timeout(self):
    """Return True if more than the configured timeout has elapsed since the
    last successful acquire (or since construction, if never acquired)."""
    return time.monotonic() - self._time_since_locked > self._timeout

def release_if_timeout(self):
    """Failsafe: release the lock only if it is held and has timed out."""
    if self.locked() and self.timeout():
        self.release()

async def worker_handler(websocket, worker_id):
"""Handles routes of the form: /worker/{id}. This route is called when
a worker first connects to the ws-server, creating a connection that can
be used to ask the worker to check-in later.
"""
# runs on worker connect
worker_to_ws[worker_id] = websocket
logger.warning(f"Connected to worker {worker_id}!")

# Registry of live worker connections. Entries are added by
# worker_connection_handler once a worker socket has authenticated and are
# deleted again when that socket's connection closes.
worker_to_ws: Dict[str, Dict[str, Any]] = defaultdict(
dict
) # Maps worker ID to socket ID to websocket
# One TimedLock per worker socket. send_to_worker_handler takes it with a
# non-blocking acquire so that a socket forwards only one server message at a
# time; busy sockets are skipped rather than waited on.
worker_to_lock: Dict[str, Dict[str, TimedLock]] = defaultdict(
dict
) # Maps worker ID to socket ID to lock
# Acknowledgement sent back to the calling server after its message has been
# forwarded to a worker socket.
ACK = b'a'
logger = logging.getLogger(__name__)
# The bundle and worker models used below are obtained from this manager.
manager = CodaLabManager()
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

CodaLabManager is used to give the websocket server access to the SQL database.

# Used by worker_connection_handler to check a worker's access token.
bundle_model = manager.model()
# Used by worker_connection_handler to look up the user ID that owns a worker.
worker_model = manager.worker_model()
# Shared secret that send_to_worker_handler compares against the first message
# it receives. os.getenv returns None when the variable is unset, in which case
# that comparison can never succeed and every server connection is rejected.
server_secret = os.getenv("CODALAB_SERVER_SECRET")


async def send_to_worker_handler(server_websocket, worker_id):
    """Handles routes of the form: /send_to_worker/{worker_id}. This route is called by
    the rest-server or bundle-manager when either wants to send a message/stream to the worker.

    Protocol, as seen from the calling server:
      1. The first message must be the shared server secret.
      2. The second message is the payload, forwarded unchanged to one free
         socket of the target worker.
      3. ACK is sent back once the payload has been written to that socket.

    The connection is closed with code 1008 if authentication fails, and with
    code 1011 if the worker has no registered sockets or all of them are busy.

    :param server_websocket: websocket connection from the calling server.
    :param worker_id: ID of the worker the payload is addressed to.
    """
    import hmac  # Function-scope import; only the secret comparison needs it.

    # Authenticate server. The comparison is constant-time so the secret cannot
    # be probed through response timing. As before, a missing server secret or
    # a non-text frame never authenticates.
    received_secret = await server_websocket.recv()
    secret_ok = (
        server_secret is not None
        and isinstance(received_secret, str)
        and hmac.compare_digest(received_secret.encode(), server_secret.encode())
    )
    if not secret_ok:
        logger.warning("Server unable to authenticate.")
        await server_websocket.close(1008, "Server unable to authenticate.")
        return

    # Check if any websockets available
    if not worker_to_ws.get(worker_id):
        logger.warning(f"No websockets currently available for worker {worker_id}")
        await server_websocket.close(
            1011, f"No websockets currently available for worker {worker_id}"
        )
        return

    # Visit the sockets in random order to spread load. random.sample() no
    # longer accepts a dict view (TypeError on Python 3.11+), so snapshot the
    # items into a list and shuffle that instead.
    candidates = list(worker_to_ws[worker_id].items())
    random.shuffle(candidates)

    # Send message from server to worker.
    for socket_id, worker_websocket in candidates:
        # Keep a local reference: the registry entry can be deleted by a
        # worker disconnect while this coroutine is awaiting.
        lock = worker_to_lock[worker_id].get(socket_id)
        if lock is None or not lock.acquire(blocking=False):
            continue
        try:
            data = await server_websocket.recv()
            await worker_websocket.send(data)
            await server_websocket.send(ACK)
        finally:
            # Always free the socket, even when recv/send raised; otherwise it
            # would stay marked busy. The timeout failsafe may already have
            # released the lock, hence the guard.
            if lock.locked():
                lock.release()
        return

    logger.warning(f"All websockets for worker {worker_id} are currently busy.")
    await server_websocket.close(1011, f"All websockets for worker {worker_id} are currently busy.")


async def worker_connection_handler(websocket: Any, worker_id: str, socket_id: str) -> None:
"""Handles routes of the form: /worker_connect/{worker_id}/{socket_id}.
This route is called when a worker first connects to the ws-server, creating
a connection that can be used to ask the worker to check-in later.
"""
# Authenticate worker.
access_token = await websocket.recv()
user_id = worker_model.get_user_id_for_worker(worker_id=worker_id)
authenticated = bundle_model.access_token_exists_for_user(
'codalab_worker_client', user_id, access_token # TODO: Avoid hard-coding this if possible.
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe I should avoid hard-coding the client-id to be codalab_worker_client?

)
logger.error(f"AUTHENTICATED: {authenticated}")
if not authenticated:
logger.warning(f"Thread {socket_id} for worker {worker_id} unable to authenticate.")
await websocket.close(
1008, f"Thread {socket_id} for worker {worker_id} unable to authenticate."
)
return

# Establish a connection with worker and keep it alive.
worker_to_ws[worker_id][socket_id] = websocket
worker_to_lock[worker_id][socket_id] = TimedLock()
logger.warning(f"Worker {worker_id} connected; has {len(worker_to_ws[worker_id])} connections")
while True:
try:
await asyncio.wait_for(websocket.recv(), timeout=60)
worker_to_lock[worker_id][
socket_id
].release_if_timeout() # Failsafe in case not released
except asyncio.futures.TimeoutError:
pass
except websockets.exceptions.ConnectionClosed:
logger.error(f"Socket connection closed with worker {worker_id}.")
logger.warning(f"Socket connection closed with worker {worker_id}.")
break


ROUTES = (
(r'^.*/main$', rest_server_handler),
(r'^.*/worker/(.+)$', worker_handler),
)
del worker_to_ws[worker_id][socket_id]
del worker_to_lock[worker_id][socket_id]
logger.warning(f"Worker {worker_id} now has {len(worker_to_ws[worker_id])} connections")


async def ws_handler(websocket, *args):
"""Handler for websocket connections. Routes websockets to the appropriate
route handler defined in ROUTES."""
logger.warning(f"websocket handler, path: {websocket.path}.")
ROUTES = (
(r'^.*/send_to_worker/(.+)$', send_to_worker_handler),
(r'^.*/worker_connect/(.+)/(.+)$', worker_connection_handler),
)
logger.info(f"websocket handler, path: {websocket.path}.")
for (pattern, handler) in ROUTES:
match = re.match(pattern, websocket.path)
if match:
return await handler(websocket, *match.groups())
assert False
return await websocket.close(1011, f"Path {websocket.path} is not valid.")


async def async_main():
"""Main function that runs the websocket server."""
parser = argparse.ArgumentParser()
parser.add_argument('--port', help='Port to run the server on.', type=int, required=True)
parser.add_argument(
'--port', help='Port to run the server on.', type=int, required=False, default=2901
)
args = parser.parse_args()
logging.debug(f"Running ws-server on 0.0.0.0:{args.port}")
async with websockets.serve(ws_handler, "0.0.0.0", args.port):
Expand Down
9 changes: 8 additions & 1 deletion codalab/lib/codalab_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,11 @@ def ws_server(self):
ws_port = self.config['ws-server']['ws_port']
return f"ws://ws-server:{ws_port}"

@property  # type: ignore
@cached
def server_secret(self):
    """Shared server secret, read from the CODALAB_SERVER_SECRET environment
    variable; None when the variable is not set."""
    return os.environ.get("CODALAB_SERVER_SECRET")
epicfaace marked this conversation as resolved.
Show resolved Hide resolved

@property # type: ignore
@cached
def worker_socket_dir(self):
Expand Down Expand Up @@ -380,7 +385,9 @@ def model(self):

@cached
def worker_model(self):
return WorkerModel(self.model().engine, self.worker_socket_dir, self.ws_server)
return WorkerModel(
self.model().engine, self.worker_socket_dir, self.ws_server, self.server_secret
)

@cached
def upload_manager(self):
Expand Down
14 changes: 5 additions & 9 deletions codalab/lib/download_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ def _get_target_info_within_bundle(self, target, depth):
read_args = {'type': 'get_target_info', 'depth': depth}
self._send_read_message(worker, response_socket_id, target, read_args)
with closing(self._worker_model.start_listening(response_socket_id)) as sock:
result = self._worker_model.get_json_message(sock, 60)
result = self._worker_model.recv_json_message_with_unix_socket(sock, 60)
if result is None: # dead workers are a fact of life now
logging.info('Unable to reach worker, bundle state {}'.format(bundle_state))
raise NotFoundError(
Expand Down Expand Up @@ -365,9 +365,7 @@ def _send_read_message(self, worker, response_socket_id, target, read_args):
'path': target.subpath,
'read_args': read_args,
}
if not self._worker_model.send_json_message(
worker['socket_id'], worker['worker_id'], message, 60
): # dead workers are a fact of life now
if not self._worker_model.send_json_message(message, worker['worker_id']):
logging.info('Unable to reach worker')

def _send_netcat_message(self, worker, response_socket_id, uuid, port, message):
Expand All @@ -378,21 +376,19 @@ def _send_netcat_message(self, worker, response_socket_id, uuid, port, message):
'port': port,
'message': message,
}
if not self._worker_model.send_json_message(
worker['socket_id'], worker['worker_id'], message, 60
): # dead workers are a fact of life now
if not self._worker_model.send_json_message(message, worker['worker_id']):
logging.info('Unable to reach worker')

def _get_read_response_stream(self, response_socket_id):
with closing(self._worker_model.start_listening(response_socket_id)) as sock:
header_message = self._worker_model.get_json_message(sock, 60)
header_message = self._worker_model.recv_json_message_with_unix_socket(sock, 60)
precondition(header_message is not None, 'Unable to reach worker')
if 'error_code' in header_message:
raise http_error_to_exception(
header_message['error_code'], header_message['error_message']
)

fileobj = self._worker_model.get_stream(sock, 60)
fileobj = self._worker_model.recv_stream(sock, 60)
precondition(fileobj is not None, 'Unable to reach worker')
return fileobj

Expand Down
21 changes: 20 additions & 1 deletion codalab/model/bundle_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2882,7 +2882,7 @@ def get_oauth2_token(self, access_token=None, refresh_token=None):

return OAuth2Token(self, **row)

def find_oauth2_token(self, client_id, user_id, expires_after):
def find_oauth2_token(self, client_id, user_id, expires_after=datetime.datetime.utcnow()):
with self.engine.begin() as connection:
row = connection.execute(
select([oauth2_token])
Expand All @@ -2901,6 +2901,25 @@ def find_oauth2_token(self, client_id, user_id, expires_after):

return OAuth2Token(self, **row)

def access_token_exists_for_user(self, client_id: str, user_id: str, access_token: str) -> bool:
    """Tell whether an unexpired OAuth2 token row matches the given client,
    user and access token.

    :param client_id: OAuth2 client ID the token must belong to.
    :param user_id: ID of the user the token must belong to.
    :param access_token: the access token value to look for.
    :return: True if at least one matching row whose expiry lies in the future
        (compared against the current UTC time) exists, False otherwise.
    """
    with self.engine.begin() as connection:
        token_matches = and_(
            oauth2_token.c.client_id == client_id,
            oauth2_token.c.user_id == user_id,
            oauth2_token.c.access_token == access_token,
            oauth2_token.c.expires > datetime.datetime.utcnow(),
        )
        # One row is enough to establish existence.
        query = select([oauth2_token]).where(token_matches).limit(1)
        found = connection.execute(query).fetchone()

    return found is not None

def save_oauth2_token(self, token):
with self.engine.begin() as connection:
result = connection.execute(oauth2_token.insert().values(token.columns))
Expand Down
Loading