Use websockets for json communication #4490

Merged
merged 41 commits into from
Sep 10, 2023
3221c90
first commit
Jun 14, 2023
9c0ec8d
Fix bugs
AndrewJGaut Jun 14, 2023
a139498
Fix more bugsz
AndrewJGaut Jun 14, 2023
b226f40
slight change
AndrewJGaut Jun 14, 2023
93f3e62
minor bug fixes
Jun 14, 2023
0db9264
It's working quite well now
AndrewJGaut Jun 28, 2023
1a65356
merge conflict
AndrewJGaut Jun 28, 2023
ae3e3be
Update to solve that last bug. I think it should work now
AndrewJGaut Jun 28, 2023
85774a4
bug fix that was causing messages to be sent to all worker threads
AndrewJGaut Jun 29, 2023
81064a7
Fix unittests issue and fix formatting
AndrewJGaut Jul 3, 2023
5326420
Minor change to increase robustness of message sending
AndrewJGaut Jul 3, 2023
34550f0
Some more formatting changes
AndrewJGaut Jul 3, 2023
57aad3d
Add in worker auth and use wss rather than ws for websocket URLs to m…
AndrewJGaut Jul 12, 2023
42127e5
Added in server auth with secret. Also still need to test (still on p…
AndrewJGaut Jul 12, 2023
aa37b0e
Fixed issues and got auth working. Now, I'll work on returning error …
AndrewJGaut Jul 12, 2023
3e71d67
Add in tests for authentication functionality (for worker and server)
AndrewJGaut Jul 12, 2023
7a0bd70
Slight cleanup to data sending code
AndrewJGaut Jul 13, 2023
cbd1505
Adding in ssl certification
AndrewJGaut Jul 13, 2023
4eb3d7b
Add in SSL stuff for worker; still testing on dev
AndrewJGaut Jul 14, 2023
aa74981
Revert "Add in SSL stuff for worker; still testing on dev"
AndrewJGaut Jul 14, 2023
123fc7e
Revert "Adding in ssl certification"
AndrewJGaut Jul 14, 2023
6efc0cf
Fixed formatting
AndrewJGaut Jul 15, 2023
35bdab0
Very minor formatting change to ignore one line for MyPy
AndrewJGaut Jul 15, 2023
9f5417c
Another minor formatting change
AndrewJGaut Jul 15, 2023
f0ec0f1
Merge branch 'master' into use-websockets-for-json-communication
AndrewJGaut Jul 16, 2023
baf9121
add exponential backoff to see if that fixes dev issue
AndrewJGaut Jul 16, 2023
6222950
Merge branch 'use-websockets-for-json-communication' of github.com:co…
AndrewJGaut Jul 16, 2023
100d536
Added code to actually detect worker disconnections now so that some …
AndrewJGaut Jul 16, 2023
2a82155
Make sockets get looped over in random order to help distribute load
AndrewJGaut Jul 16, 2023
5dabf6d
Clean up ws-server and delete a Dataclass I was using previously
AndrewJGaut Jul 17, 2023
0c57229
a few more minor changes
AndrewJGaut Jul 17, 2023
e34fb72
Make websocket locks more robust and improve error messaging
AndrewJGaut Jul 17, 2023
42819a6
More permissible retries in case of other errors (e.g. like 1013). Wi…
AndrewJGaut Jul 17, 2023
8e868a4
minor change to have a different error message if worker doesn't yet …
AndrewJGaut Jul 17, 2023
f36ed3a
Rename send_json and send_json_message_with_sock
AndrewJGaut Aug 16, 2023
4ed30a2
Rearrange worker_model to minimize diff
AndrewJGaut Aug 16, 2023
d115a96
Final changes
AndrewJGaut Aug 16, 2023
5ac4853
Fix formatting
AndrewJGaut Aug 16, 2023
8acedff
Minor changes to get auth working again and to robustly return an err…
AndrewJGaut Aug 23, 2023
3d25b77
Merge branch 'master' of github.com:codalab/codalab-worksheets into u…
AndrewJGaut Sep 10, 2023
bbe9402
Merge in master and make some minor changes
AndrewJGaut Sep 10, 2023
193 changes: 161 additions & 32 deletions codalab/bin/ws_server.py
@@ -1,75 +1,204 @@
# This is the real ws-server, basically.
# Main entry point for CodaLab cl-ws-server.
import argparse
import json
import asyncio
from collections import defaultdict
import logging
import re
from typing import Any, Dict
from typing import Any, Dict, List
import websockets
from dataclasses import dataclass
import threading
import time

ACK=b'a'

logger = logging.getLogger(__name__)
logger.setLevel(logging.WARNING)
logger.setLevel(logging.INFO)
logging.basicConfig(format='%(asctime)s %(message)s %(pathname)s %(lineno)d')

worker_to_ws: Dict[str, Any] = {}

"""
TODO!!! WE NEED TO ADD A SECRET OR SOME SORT OF AUTH FOR THE WS SERVER ENDPOINTS
Otherwise, people could create a custom local worker build and wreak havoc on the ws server...
Note that this was already an issue, though; people could've just hit the checkpoint endpoint...
"""

async def rest_server_handler(websocket):
"""Handles routes of the form: /main. This route is called by the rest-server
whenever a worker needs to be pinged (to ask it to check in). The body of the
message is the worker id to ping. This function sends a message to the worker
with that worker id through an appropriate websocket.
@dataclass
class WS:
"""
# Got a message from the rest server.
worker_id = await websocket.recv()
logger.warning(f"Got a message from the rest server, to ping worker: {worker_id}.")

try:
worker_ws = worker_to_ws[worker_id]
await worker_ws.send(worker_id)
except KeyError:
logger.error(f"Websocket not found for worker: {worker_id}")
Stores websocket object and whether or not the websocket is available.
TODO: give this a better, less confusing name.
"""
_ws: Any = None
_is_available: bool = True
_lock: threading.Lock = threading.Lock()
_timeout: float = 86400
_last_use: float = None

@property
def ws(self):
return self._ws

@property
def lock(self):
return self._lock

@property
def is_available(self):
return self._is_available

@is_available.setter
def is_available(self, value):
self._is_available = value

@property
def timeout(self):
return self._timeout

@property
def last_use(self):
return self._last_use

@last_use.setter
def last_use(self, value):
self._last_use = value
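
One thing worth checking in the `WS` dataclass above: a dataclass default written as `threading.Lock()` is evaluated once, when the class is defined, so every instance ends up sharing the same lock object. A minimal sketch of the difference follows. The class names `SharedLockSlot` and `SocketSlot` are hypothetical and are not part of this PR; `field(default_factory=...)` is the standard `dataclasses` mechanism for per-instance defaults.

```python
import threading
from dataclasses import dataclass, field
from typing import Any, Optional


@dataclass
class SharedLockSlot:
    # The default is created once at class-definition time,
    # so all instances share this single lock.
    lock: threading.Lock = threading.Lock()


@dataclass
class SocketSlot:
    # Hypothetical stand-in for the WS dataclass, for illustration only.
    ws: Any = None
    is_available: bool = True
    # default_factory runs once per instance, so each slot gets its own lock.
    lock: threading.Lock = field(default_factory=threading.Lock)
    timeout: float = 86400
    last_use: Optional[float] = None


a, b = SharedLockSlot(), SharedLockSlot()
c, d = SocketSlot(), SocketSlot()
locks_are_shared = a.lock is b.lock
locks_are_separate = c.lock is not d.lock
```

With a shared lock, locking one worker socket would also block every other socket in the pool, which may or may not be the intent here.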

worker_to_ws: Dict[str, Dict[str, WS]] = defaultdict(dict) # Maps each worker ID to a dict of socket ID -> WS (since each worker has a pool of connections)
server_worker_to_ws: Dict[str, Dict[str, WS]] = defaultdict(dict) # Map the rest-server websocket connection to the corresponding worker socket connection.

async def connection_handler(websocket, worker_id):
"""
Handles routes of the form: /server/connect/worker_id. This route is called by the rest-server
(or bundle-manager) when they want a connection to a worker.
Returns the id of the socket to connect to, which will be used in later requests.
"""
logger.debug(f"Got a message from the rest server, to connect to worker: {worker_id}.")
socket_id = None
#logger.error([ws.is_available for _,ws in worker_to_ws[worker_id].items()])
for s_id, ws in worker_to_ws[worker_id].items():
logger.debug("about to lock")
with ws.lock:
logger.debug("locked")
if ws.is_available or time.time() - ws.last_use >= ws.timeout:
logger.debug("available")
ws.last_use = time.time()
socket_id = s_id
worker_to_ws[worker_id][socket_id].is_available = False
logger.debug("breaking")
break

logger.info(f"For worker {worker_id}, sending server socket ID {socket_id}")
if not socket_id:
logger.error(f"No socket ids available for worker {worker_id}")
#import pdb; pdb.set_trace()
await websocket.send(json.dumps({'socket_id': socket_id}))
logger.debug("Sent.")
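
The allocation rule in `connection_handler` (take the first socket that is either free or has been held longer than its timeout) can be sketched as a pure function, which makes it easier to unit test. This is only an illustration under assumptions: `Slot` and `allocate` are hypothetical names, the lock is omitted, and the clock is passed in explicitly so the behavior is deterministic.

```python
import time
from dataclasses import dataclass
from typing import Dict, Optional


@dataclass
class Slot:
    # Hypothetical, lock-free stand-in for the WS dataclass.
    is_available: bool = True
    timeout: float = 86400.0
    last_use: Optional[float] = None


def allocate(slots: Dict[str, Slot], now: Optional[float] = None) -> Optional[str]:
    """Return the ID of the first free or timed-out slot and mark it busy."""
    now = time.time() if now is None else now
    for socket_id, slot in slots.items():
        # Guard against last_use being None before comparing timestamps.
        expired = slot.last_use is not None and now - slot.last_use >= slot.timeout
        if slot.is_available or expired:
            slot.is_available = False
            slot.last_use = now
            return socket_id
    return None  # every slot is busy and none has timed out


slots = {'a': Slot(is_available=False, timeout=10.0, last_use=0.0), 'b': Slot()}
first = allocate(slots, now=5.0)    # 'a' is busy and not expired, so 'b' is taken
second = allocate(slots, now=6.0)   # nothing is free
third = allocate(slots, now=20.0)   # 'a' has been held past its timeout
```

Returning `None` here mirrors the handler sending `{'socket_id': None}` when no socket is available.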


async def disconnection_handler(websocket, worker_id, socket_id):
"""
Handles routes of the form: /server/disconnect/worker_id/socket_id. This route is called by the rest-server
(or bundle-manager) when they are done with a connection to a worker.
Marks the socket as available again so that later connection requests can reuse it.
"""
with worker_to_ws[worker_id][socket_id].lock:
logger.info(f"For worker {worker_id}, disconnecting socket ID {socket_id}")
if worker_to_ws[worker_id][socket_id].is_available:
# For now, just log the error.
logging.error("Available socket set for disconnection")
worker_to_ws[worker_id][socket_id].is_available = True


async def exchange(from_ws, to_ws, worker_id, socket_id):
# Send and receive all data until connection closes.
# (We may be streaming a file, in which case we need to receive and then send
# lots of chunks)
while True:
# Don't use `async for`, so that we avoid two coroutines waiting on the same socket.
try:
data = await from_ws.recv()
await to_ws.send(data)
except websockets.exceptions.ConnectionClosed:
break
# This won't work as-is. When a message is sent, it's actually buffered at the client, so a completed send doesn't mean the message was received.
# I wonder if there's a way around this...
# See: https://stackoverflow.com/questions/46549892/does-websocket-send-guarantee-consumption
# We need to receive an ACK from to_ws and then forward it to from_ws.
# That's not too hard to do.
# This will create a lot of extra traffic, but that's OK.

# Now, we're getting the "recv() called by two coroutines for same websocket"
# It makes sense: the thread sending to ws=websocket(worker_id, socket_id) is calling data = await ws.recv()
# and the recv caller is calling async for data in ws: and so both are calling recv().
# That's an issue that's not solvable, I don't think... shoot.

# The answer in this case might be to keep two versions of websockets per worker.
# So, one for the worker and one for the server (for each worker_id, socket_id combination).
# Why do this? So that we can wait and send properly... It'd be very annoying, though, for sure
# I don't like it. Is there any better way to do this?
# I don't think so... I think this works better, unfortunately. It's kind of gross, but that's alright.
# We'll need a separate send and recv handler now for server and worker... Kind of annoying.
# Might be able to get by it with some clever instantiation... oh well

# No, I think we can avoid that.
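
The ACK idea discussed in the comments above can be sketched without a real network. The example below is an assumption-laden illustration, not the PR's implementation: `FakeSocket`, `make_pair`, and `relay_with_ack` are hypothetical names, and the in-memory queues only imitate the `send`/`recv` shape of a websocket. The relay forwards one message, waits for the receiver's ACK, and passes that ACK back to the sender, so the sender learns the message was actually consumed.

```python
import asyncio

ACK = b'a'


class FakeSocket:
    """In-memory stand-in for a websocket (hypothetical, for illustration)."""

    def __init__(self):
        self.inbox = asyncio.Queue()
        self.peer = None

    async def send(self, data):
        await self.peer.inbox.put(data)

    async def recv(self):
        return await self.inbox.get()


def make_pair():
    a, b = FakeSocket(), FakeSocket()
    a.peer, b.peer = b, a
    return a, b


async def relay_with_ack(from_ws, to_ws):
    """Forward one message, then relay the receiver's ACK back to the sender."""
    data = await from_ws.recv()
    await to_ws.send(data)
    ack = await to_ws.recv()
    await from_ws.send(ack)


async def demo():
    server, relay_server_side = make_pair()
    relay_worker_side, worker = make_pair()

    async def worker_task():
        msg = await worker.recv()
        await worker.send(ACK)  # confirm consumption
        return msg

    async def server_task():
        await server.send(b'hello')
        return await server.recv()  # blocks until the ACK arrives

    results = await asyncio.gather(
        server_task(), worker_task(), relay_with_ack(relay_server_side, relay_worker_side)
    )
    return results[0], results[1]


ack, received = asyncio.run(demo())
```

Each socket here has exactly one coroutine calling `recv()` at a time, which is the constraint the comments above are trying to satisfy.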

async def send_handler(websocket, worker_id, socket_id):
"""Handles routes of the form: /send/{worker_id}/{socket_id}. This route is called by
the rest-server or bundle-manager when either wants to send a message/stream to the worker.
"""
with worker_to_ws[worker_id][socket_id].lock:
worker_to_ws[worker_id][socket_id].last_use = time.time()
await exchange(websocket, worker_to_ws[worker_id][socket_id].ws, worker_id, socket_id)

async def recv_handler(websocket, worker_id, socket_id):
"""Handles routes of the form: /recv/{worker_id}/{socket_id}. This route is called by
the rest-server or bundle-manager when either wants to receive a message/stream from the worker.
"""
with worker_to_ws[worker_id][socket_id].lock:
worker_to_ws[worker_id][socket_id].last_use = time.time()
await exchange(worker_to_ws[worker_id][socket_id].ws, websocket, worker_id, socket_id)

async def worker_handler(websocket, worker_id):
"""Handles routes of the form: /worker/{id}. This route is called when
async def worker_handler(websocket, worker_id, socket_id):
"""Handles routes of the form: /worker/{worker_id}/{socket_id}. This route is called when
a worker first connects to the ws-server, creating a connection that can
be used to ask the worker to check-in later.
"""
# runs on worker connect
worker_to_ws[worker_id] = websocket
logger.warning(f"Connected to worker {worker_id}!")
worker_to_ws[worker_id][socket_id] = WS(websocket)
logger.warning(f"Worker {worker_id} connected; has {len(worker_to_ws[worker_id])} connections")

# keep connection alive.
while True:
try:
await asyncio.wait_for(websocket.recv(), timeout=60)
await asyncio.sleep(60)
except asyncio.futures.TimeoutError:
pass
except websockets.exceptions.ConnectionClosed:
logger.error(f"Socket connection closed with worker {worker_id}.")
break
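
The keep-alive loop above depends on a receive that gives up after a timeout. A minimal sketch of that pattern, assuming a hypothetical helper named `recv_or_none` (not part of this PR) and using only `asyncio.wait_for` and `asyncio.TimeoutError` from the standard library:

```python
import asyncio
from typing import Any, Awaitable, Callable, Optional


async def recv_or_none(recv: Callable[[], Awaitable[Any]], timeout: float) -> Optional[Any]:
    """Wait up to `timeout` seconds for one message; return None on timeout."""
    try:
        return await asyncio.wait_for(recv(), timeout=timeout)
    except asyncio.TimeoutError:
        return None


async def demo():
    async def silent_peer():
        await asyncio.sleep(3600)  # never produces a message in time

    async def chatty_peer():
        return b'ping'

    return await recv_or_none(silent_peer, 0.01), await recv_or_none(chatty_peer, 0.01)


timed_out, message = asyncio.run(demo())
```

A plain `asyncio.sleep` keeps the handler alive but does not touch the socket, so it cannot notice a closed connection on its own; a timed receive like this one can.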


ROUTES = (
(r'^.*/main$', rest_server_handler),
(r'^.*/worker/(.+)$', worker_handler),
(r'^.*/send/(.+)/(.+)$', send_handler),
(r'^.*/recv/(.+)/(.+)$', recv_handler),
(r'^.*/server/connect/(.+)$', connection_handler),
(r'^.*/server/disconnect/(.+)/(.+)$', disconnection_handler),
(r'^.*/worker/(.+)/(.+)$', worker_handler),
)


async def ws_handler(websocket, *args):
"""Handler for websocket connections. Routes websockets to the appropriate
route handler defined in ROUTES."""
logger.warning(f"websocket handler, path: {websocket.path}.")
logger.debug(f"websocket handler, path: {websocket.path}.")
for (pattern, handler) in ROUTES:
match = re.match(pattern, websocket.path)
if match:
return await handler(websocket, *match.groups())
assert False
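
The regex-based dispatch in `ws_handler` can be exercised on its own. The sketch below copies the new route patterns from this diff but maps them to plain strings instead of handler coroutines; `resolve` is a hypothetical helper that returns `None` for an unknown path rather than hitting `assert False`.

```python
import re
from typing import Optional, Tuple

# Patterns copied from ROUTES above; handler names replaced by labels.
ROUTES = (
    (r'^.*/send/(.+)/(.+)$', 'send'),
    (r'^.*/recv/(.+)/(.+)$', 'recv'),
    (r'^.*/server/connect/(.+)$', 'connect'),
    (r'^.*/server/disconnect/(.+)/(.+)$', 'disconnect'),
    (r'^.*/worker/(.+)/(.+)$', 'worker'),
)


def resolve(path: str) -> Optional[Tuple[str, Tuple[str, ...]]]:
    """Return (route label, captured groups) for the first matching pattern."""
    for pattern, label in ROUTES:
        match = re.match(pattern, path)
        if match:
            return label, match.groups()
    return None


send_route = resolve('/send/w1/s1')
connect_route = resolve('/server/connect/w1')
disconnect_route = resolve('/server/disconnect/w1/s1')
unknown_route = resolve('/unknown')
```

Because `(.+)` is greedy, an ID containing `/` would be split at the last slash, so this scheme assumes worker and socket IDs contain no slashes.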


async def async_main():
"""Main function that runs the websocket server."""
parser = argparse.ArgumentParser()
parser.add_argument('--port', help='Port to run the server on.', type=int, required=True)
parser.add_argument('--port', help='Port to run the server on.', type=int, required=False, default=2901)
args = parser.parse_args()
logging.debug(f"Running ws-server on 0.0.0.0:{args.port}")
async with websockets.serve(ws_handler, "0.0.0.0", args.port):
@@ -84,4 +213,4 @@ def main():


if __name__ == '__main__':
main()
main()
16 changes: 8 additions & 8 deletions codalab/lib/download_manager.py
@@ -137,7 +137,7 @@ def _get_target_info_within_bundle(self, target, depth):
read_args = {'type': 'get_target_info', 'depth': depth}
self._send_read_message(worker, response_socket_id, target, read_args)
with closing(self._worker_model.start_listening(response_socket_id)) as sock:
result = self._worker_model.get_json_message(sock, 60)
result = self._worker_model.recv_json_message_with_sock(sock, 60)
if result is None: # dead workers are a fact of life now
logging.info('Unable to reach worker, bundle state {}'.format(bundle_state))
raise NotFoundError(
@@ -365,8 +365,8 @@ def _send_read_message(self, worker, response_socket_id, target, read_args):
'path': target.subpath,
'read_args': read_args,
}
if not self._worker_model.send_json_message(
worker['socket_id'], worker['worker_id'], message, 60
if not self._worker_model.connect_and_send_json(
message, worker['worker_id']
): # dead workers are a fact of life now
logging.info('Unable to reach worker')

@@ -378,21 +378,21 @@ def _send_netcat_message(self, worker, response_socket_id, uuid, port, message):
'port': port,
'message': message,
}
if not self._worker_model.send_json_message(
worker['socket_id'], worker['worker_id'], message, 60
if not self._worker_model.connect_and_send_json(
message, worker['worker_id']
): # dead workers are a fact of life now
logging.info('Unable to reach worker')

def _get_read_response_stream(self, response_socket_id):
with closing(self._worker_model.start_listening(response_socket_id)) as sock:
header_message = self._worker_model.get_json_message(sock, 60)
header_message = self._worker_model.recv_json_message_with_sock(sock, 60)
Member
rename to get_json_message?

Contributor Author

I do want to keep the with_sock portion (I changed it to with_unix_socket in this newest version) since it helps make clear that it's using AF_UNIX sockets. However, I could change recv back to get. I thought recv was clearer since that's typically the API name for getting data from sockets.

Let me know. I can definitely change this.

precondition(header_message is not None, 'Unable to reach worker')
if 'error_code' in header_message:
raise http_error_to_exception(
header_message['error_code'], header_message['error_message']
)

fileobj = self._worker_model.get_stream(sock, 60)
fileobj = self._worker_model.recv_stream(sock, 60)
precondition(fileobj is not None, 'Unable to reach worker')
return fileobj

@@ -416,4 +416,4 @@ def __getattr__(self, attr):

def close(self):
self._fileobj.close()
self._worker_model.deallocate_socket(self._socket_id)
self._worker_model.deallocate_socket(self._socket_id)