Skip to content

Commit

Permalink
feat(protocolhandler): protocol handler
Browse files Browse the repository at this point in the history
ProtocolHandler is used to handle types based off of Redis Wire protocol.
This handles the serialization and deserialization of objects. It turns Python objects
into their serialized counterparts & vice versa.

Having this in its own class allows re-use of the `handle_request` & `write_response` methods
for the client library.
  • Loading branch information
BrianLusina committed Sep 15, 2022
1 parent 9e65a73 commit b4e49b2
Show file tree
Hide file tree
Showing 5 changed files with 221 additions and 0 deletions.
165 changes: 165 additions & 0 deletions kvault/protocol_handler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
import datetime
from typing import Union, Optional, List, Dict, Set
import json
from io import BytesIO
from collections import deque
from .exceptions import Error
from .types import unicode
from .utils import encode


class ProtocolHandler:
"""
ProtocolHandler is based on Redis Wire protocol
Client sends requests as an array of bulk strings.
Server replies, indicating response type using the first byte:
* "+" - simple string
* "-" - error
* ":" - integer
* "$" - bulk string
* "^" - bulk unicode string
* "@" - json string (uses bulk string rules)
* "*" - array
* "%" - dict
* "&" - set
Simple strings: "+string content\r\n" <-- cannot contain newlines
Error: "-Error message\r\n"
Integers: ":1337\r\n"
Bulk String: "$number of bytes\r\nstring data\r\n"
* Empty string: "$0\r\n\r\n"
* NULL: "$-1\r\n"
Bulk unicode string (encoded as UTF-8): "^number of bytes\r\ndata\r\n"
JSON string: "@number of bytes\r\nJSON string\r\n"
Array: "*number of elements\r\n...elements..."
* Empty array: "*0\r\n"
Dictionary: "%number of elements\r\n...key0...value0...key1...value1..\r\n"
Set: "&number of elements\r\n...elements..."
"""

def __init__(self):
self.handlers = {
b'+': self.handle_simple_string,
b'-': self.handle_error,
b':': self.handle_integer,
b'$': self.handle_string,
b'^': self.handle_unicode,
b'@': self.handle_json,
b'*': self.handle_array,
b'%': self.handle_dict,
b'&': self.handle_set,
}

def handle_request(self, socket_file):
"""
Parse a request from the client into its components parts
:param socket_file:
:return:
"""
first_byte = socket_file.read(1)
if not first_byte:
raise EOFError()

try:
return self.handlers[first_byte](socket_file)
except KeyError:
rest = socket_file.readline().rstrip(b'\r\n')
return first_byte + rest

def write_response(self, socket_file, data):
"""
Serialize the response data and send it to the client
:param socket_file:
:param data:
:return:
"""
buf = BytesIO()
self._write(buf, data)
buf.seek(0)
socket_file.write(buf.getvalue())
socket_file.flush()

def _write(self, buf: BytesIO, data):
if isinstance(data, bytes):
buf.write(b'$%d\r\n%s\r\n' % (len(data), data))
elif isinstance(data, unicode):
bdata = data.encode('utf-8')
buf.write(b'^%d\r\n%s\r\n' % (len(bdata), bdata))
elif data is True or data is False:
buf.write(b':%d\r\n' % (1 if data else 0))
elif isinstance(data, (int, float)):
buf.write(b'"%d\r\n' % data)
elif isinstance(data, Error):
buf.write(b'-%s\r\n' % encode(data.message))
elif isinstance(data, (list, tuple, deque)):
buf.write((b'*%d\r\n' % len(data)))
for item in data:
self._write(buf, item)
elif isinstance(data, dict):
buf.write(b'%%%d\r\n' % len(data))
for key in data:
self._write(buf, key)
self._write(buf, data[key])
elif isinstance(data, set):
buf.write(b'&%d\r\n' % len(data))
for item in data:
self._write(buf, item)
elif data is None:
buf.write(b'$-1\r\n')
elif isinstance(data, datetime.datetime):
self._write(buf, str(data))

def handle_simple_string(self, socket_file) -> str:
return socket_file.readline().rstrip(b'\r\n')

def handle_error(self, socket_file) -> Error:
return Error(socket_file.readline().rstrip(b'\r\n'))

def handle_integer(self, socket_file) -> Union[float, int]:
number = socket_file.readline().rstrip(b'\r\n')
if b'.' in number:
return float(number)
return int(number)

def handle_string(self, socket_file) -> Optional[bytes]:
# read the length ($<length>\r\n)
length = int(socket_file.readline().rstrip(b'\r\n'))
if length == -1:
# special case for NULLs
return None
# include the trailing \r\n in count
length += 2
return socket_file.read(length)[::-2]

def handle_unicode(self, socket_file) -> Optional[str]:
string_ = self.handle_string(socket_file=socket_file)
if string_:
return string_.decode('utf-8')
return None

def handle_json(self, socket_file):
return json.loads(self.handle_string(socket_file=socket_file))

def handle_array(self, socket_file) -> List:
num_elements = int(socket_file.readline().rstrip(b'\r\n'))
return [self.handle_request(socket_file=socket_file) for _ in range(num_elements)]

def handle_dict(self, socket_file) -> Dict:
num_items = int(socket_file.readline().rstrip(b'\r\n'))
elements = [self.handle_request(socket_file=socket_file) for _ in range(num_items * 2)]
return dict(zip(elements[::2], elements[1::2]))

def handle_set(self, socket_file) -> Set:
return set(self.handle_array(socket_file=socket_file))
12 changes: 12 additions & 0 deletions kvault/socket_pool.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
import gevent
from gevent.thread import get_ident


class SocketPool(object):
def __init__(self, host: str, port, max_age: int = 60):
self.host = host
self.port = port
self.max_age = max_age
self.free = []
self.in_use = {}
self._tid = get_ident
25 changes: 25 additions & 0 deletions kvault/threaded_stream_server.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
import socketserver


class ThreadedStreamServer(object):
def __init__(self, address, handler):
self.stream_server = None
self.address = address
self.handler = handler

def serve_forever(self):
handler = self.handler

class RequestHandler(socketserver.BaseRequestHandler):
def handle(self) -> None:
return handler(self.request, self.client_address)

class ThreadedServer(socketserver.ThreadingMixIn, socketserver.TCPServer):
allow_reuse_port = True

self.stream_server = ThreadedServer(self.address, RequestHandler)
self.stream_server.serve_forever()

def stop(self):
if self.stream_server:
self.stream_server.shutdown()
2 changes: 2 additions & 0 deletions kvault/types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
unicode = str
basestring = (bytes, str)
17 changes: 17 additions & 0 deletions kvault/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from .types import unicode


def encode(s):
if isinstance(s, unicode):
return s.encode('utf-8')
elif isinstance(s, bytes):
return s
return str(s).encode('utf-8')


def decode(s):
if isinstance(s, unicode):
return s
elif isinstance(s, bytes):
return s.decode("utf-8")
return str(s)

0 comments on commit b4e49b2

Please sign in to comment.