Skip to content

Commit

Permalink
#2121 / #1252 / #1568: SSL peek support for python3
Browse files Browse the repository at this point in the history
git-svn-id: https://xpra.org/svn/Xpra/trunk@21877 3bb7dfac-3a0b-4e04-842a-767bc560f471
  • Loading branch information
totaam committed Feb 25, 2019
1 parent 66359fc commit 89b2580
Show file tree
Hide file tree
Showing 2 changed files with 85 additions and 60 deletions.
139 changes: 81 additions & 58 deletions src/xpra/net/bytestreams.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,14 @@
if SOCKET_CORK:
try:
assert socket.TCP_CORK>0
except (AttributeError, AssertionError) as e:
except (AttributeError, AssertionError) as cork_e:
log.warn("Warning: unable to use TCP_CORK on %s", sys.platform)
log.warn(" %s", e)
log.warn(" %s", cork_e)
SOCKET_CORK = False
SOCKET_NODELAY = envbool("XPRA_SOCKET_NODELAY", None)
VSOCK_TIMEOUT = envint("XPRA_VSOCK_TIMEOUT", 5)
SOCKET_TIMEOUT = envint("XPRA_SOCKET_TIMEOUT", 20)
SSL_PEEK = PYTHON2 and envbool("XPRA_SSL_PEEK", True)
SSL_PEEK = envbool("XPRA_SSL_PEEK", True)
#this is more proper but would break the proxy server:
SOCKET_SHUTDOWN = envbool("XPRA_SOCKET_SHUTDOWN", False)

Expand Down Expand Up @@ -108,8 +108,8 @@ def can_retry(e):

abort = ABORT.get(code, code)
if abort is not None:
errno = getattr(e, "errno", None)
log("can_retry: %s, args=%s, errno=%s, code=%s, abort=%s", type(e), e.args, errno, code, abort)
err = getattr(e, "errno", None)
log("can_retry: %s, args=%s, errno=%s, code=%s, abort=%s", type(e), e.args, err, code, abort)
raise ConnectionClosedException(e)
if isinstance(e, CLOSED_EXCEPTIONS):
raise ConnectionClosedException(e)
Expand Down Expand Up @@ -461,58 +461,81 @@ def get_socket_options(sock, level, options):
return opts


SSLSocket = None
if SSL_PEEK:
try:
#this wrapper class allows us to override the normal ssl.Socket
#class so that we can fake peek() support by actually reading from the socket
#and caching the result.
class SSLSocket(socket._socketobject):

def __init__(self, sock):
socket._socketobject.__init__(self, _sock=sock)
#patch recv:
self.saved_recv = getattr(self, "recv")
setattr(self, "recv", self._recv)
self.saved_makefile = getattr(self, "makefile")
setattr(self, "makefile", self._makefile)
self.peeked = b""

def _recv(self, bufsize, flags=0):
#log("_recv(%s, %#x) peeked=%i bytes", bufsize, flags, len(self.peeked))
peek = flags & socket.MSG_PEEK
if self.peeked:
#we have peek data aleady
if bufsize<len(self.peeked):
r = self.peeked[:bufsize]
class SSLPeekFile(object):
def __init__(self, fileobj, peeked, update_peek):
self.fileobj = fileobj
self.peeked = peeked
self.update_peek = update_peek

def __getattr__(self, attr):
if attr=="readline" and self.peeked:
return self.readline
return getattr(self.fileobj, attr)

def readline(self, limit=-1):
if self.peeked:
newline = self.peeked.find(b"\n")
peeked = self.peeked
l = len(peeked)
if newline==-1:
if limit==-1 or limit>l:
#we need to read more until we hit a newline:
if limit==-1:
more = self.fileobj.readline(limit)
else:
r = self.peeked
if not peek:
#remove what we return from peek buffer:
if bufsize<len(self.peeked):
self.peeked = self.peeked[bufsize:]
else:
self.peeked = b""
return r
r = self.saved_recv(bufsize, flags & (0xffffffff ^ socket.MSG_PEEK))
if peek:
self.peeked = r
return r

def _makefile(self, mode='r', bufsize=-1):
from socket import _fileobject
fo = _fileobject(self, mode, bufsize)
return fo

def __repr__(self):
return "SSLSocket(%s)" % self._sock

except Exception as e:
ssllog = Logger("ssl")
ssllog("ssl peek", exc_info=True)
ssllog.warn("Warning: unable to override socket object")
ssllog.warn(" SSL peek support will not be available")
ssllog.warn(" %s", e)
more = self.fileobj.readline(limit-len(self.peeked))
self.peeked = b""
self.update_peek(self.peeked)
return peeked+more
read = limit
else:
if limit<0 or limit>=newline:
read = newline+1
else:
read = limit
self.peeked = peeked[read:]
self.update_peek(self.peeked)
return peeked[:read]
return self.fileobj.readline(limit)

class SSLSocketWrapper(object):
def __init__(self, sock):
self.socket = sock
self.peeked = b""

def __getattr__(self, attr):
if attr=="makefile":
return self.makefile
if attr=="recv":
return self.recv
return getattr(self.socket, attr)

def makefile(self, mode, bufsize=None):
fileobj = self.socket.makefile(mode, bufsize)
if self.peeked and mode and mode.startswith("r"):
return SSLPeekFile(fileobj, self.peeked, self._update_peek)
return fileobj

def _update_peek(self, peeked):
self.peeked = peeked

def recv(self, bufsize, flags=0):
if flags & socket.MSG_PEEK:
l = len(self.peeked)
if l>=bufsize:
log("patched_recv() peeking using existing data: %i bytes", bufsize)
return self.peeked[:bufsize]
v = self.socket.recv(bufsize-l)
if v:
log("patched_recv() peeked more: %i bytes", len(v))
self.peeked += v
return self.peeked
if self.peeked:
peeked = self.peeked[:bufsize]
self.peeked = self.peeked[bufsize:]
log("patched_recv() non peek, returned already read data")
return peeked
return self.socket.recv(bufsize, flags)


class SSLSocketConnection(SocketConnection):
Expand All @@ -532,8 +555,8 @@ def can_retry(self, e):
return SocketConnection.can_retry(self, e)

def enable_peek(self):
if SSLSocket:
self._socket = SSLSocket(self._socket)
assert not isinstance(self._socket, SSLSocketWrapper)
self._socket = SSLSocketWrapper(self._socket)

def get_info(self):
i = SocketConnection.get_info(self)
Expand Down
6 changes: 4 additions & 2 deletions src/xpra/server/server_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1177,7 +1177,8 @@ def start_http_socket(self, socktype, conn, is_ssl=False, peek_data=""):
if peek_data:
line1 = peek_data.splitlines()[0]
http_proto = "http"+["","s"][int(is_ssl)]
netlog("start_http_socket(%s, %s, %s, ..) http proto=%s, line1=%r", socktype, conn, is_ssl, http_proto, line1)
netlog("start_http_socket(%s, %s, %s, ..) http proto=%s, line1=%r",
socktype, conn, is_ssl, http_proto, bytestostr(line1))
if line1.startswith(b"GET ") or line1.startswith(b"POST "):
parts = bytestostr(line1).split(" ")
httplog("New %s %s request received from %s for '%s'", http_proto, parts[0], frominfo, parts[1])
Expand Down Expand Up @@ -1206,7 +1207,8 @@ def new_websocket_client(wsh):
self.make_protocol(newsocktype, conn, WebSocketProtocol)
scripts = self.get_http_scripts()
conn.socktype = "wss" if is_ssl else "ws"
WebSocketRequestHandler(sock, frominfo, new_websocket_client, self._www_dir, self._http_headers_dir, scripts)
WebSocketRequestHandler(sock, frominfo, new_websocket_client,
self._www_dir, self._http_headers_dir, scripts)
return
except (IOError, ValueError) as e:
httplog("start_http%s", (socktype, conn, is_ssl, req_info, frominfo), exc_info=True)
Expand Down

0 comments on commit 89b2580

Please sign in to comment.