diff --git a/keepassc/conn.py b/keepassc/conn.py index cdc7807..010df64 100644 --- a/keepassc/conn.py +++ b/keepassc/conn.py @@ -45,6 +45,9 @@ def receive(conn): while True: try: received = conn.recv(16) + if not received: + logging.error("No data received") + break except: raise if b'\xDE\xAD\xE1\x1D' in received: diff --git a/keepassc/server.py b/keepassc/server.py index fc5ae4e..7b5aec7 100644 --- a/keepassc/server.py +++ b/keepassc/server.py @@ -43,13 +43,13 @@ def __call__(self, *args): self.func(args[0], args[1]) self.lock = False break - + class Server(Daemon): """The KeePassC server daemon""" def __init__(self, pidfile, loglevel, logfile, address = None, port = 50002, db = None, password = None, keyfile = None, - tls = False, tls_dir = None, tls_port = 50003, + tls = False, tls_dir = None, tls_port = 50003, tls_req = False): Daemon.__init__(self, pidfile) @@ -68,7 +68,7 @@ def __init__(self, pidfile, loglevel, logfile, address = None, if db is None: print('Need a database path') sys.exit(1) - + self.db_path = realpath(expanduser(db)) # To use this idiom only once, I store the keyfile path @@ -110,7 +110,7 @@ def __init__(self, pidfile, loglevel, logfile, address = None, self.net_sock = None self.tls_sock = None self.tls_req = tls_req - + if tls is True or tls_req is True: self.context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) cert = join(tls_dir, "servercert.pem") @@ -165,20 +165,20 @@ def __init__(self, pidfile, loglevel, logfile, address = None, def check_password(self, password, keyfile): """Check received password""" - + master = get_key(password, keyfile, True) remote_final = transform_key(master, self.db._transf_randomseed, - self.db._final_randomseed, + self.db._final_randomseed, self.db._key_transf_rounds) master = get_key(self.db.password, self.db.keyfile) final = transform_key(master, self.db._transf_randomseed, - self.db._final_randomseed, + self.db._final_randomseed, self.db._key_transf_rounds) return (remote_final == final) def run(self): """Overide Daemon.run() and provide socets""" - + try: local_thread = threading.Thread(target=self.handle_non_tls, args=(self.sock,)) @@ -205,7 +205,7 @@ def handle_non_tls(self, sock): logging.error(err.__str__()) else: logging.info('Connection from '+client[0]+':'+str(client[1])) - client_thread = threading.Thread(target=self.handle_client, + client_thread = threading.Thread(target=self.handle_client, args=(conn,client,)) client_thread.daemon = True client_thread.start() @@ -222,7 +222,7 @@ def handle_tls(self): logging.error(err.__str__()) else: logging.info('Connection from '+client[0]+':'+str(client[1])) - client_thread = threading.Thread(target=self.handle_client, + client_thread = threading.Thread(target=self.handle_client, args=(conn, client,)) client_thread.daemon = True client_thread.start() @@ -357,7 +357,7 @@ def create_entry(self, conn, parts): self.db.save() self.send_db(conn, []) - + @waitDecorator def delete_group(self, conn, parts): group_id = int(parts.pop(0)) @@ -388,7 +388,7 @@ def delete_entry(self, conn, parts): time = datetime(int(parts[0]), int(parts[1]), int(parts[2]), int(parts[3]), int(parts[4]), int(parts[5])) time = time.timetuple() - + for i in self.db.entries: if i.uuid == uuid: if self.check_last_mod(i, time) is True: @@ -457,7 +457,7 @@ def move_entry(self, conn, parts): self.db.save() self.send_db(conn, []) - + @waitDecorator def set_g_title(self, conn, parts): title = parts.pop(0).decode() @@ -636,7 +636,7 @@ def set_e_exp(self, conn, parts): self.send_db(conn, []) def check_last_mod(self, obj, time): - return obj.last_mod.timetuple() > time + return obj.last_mod.timetuple() > time def handle_sigterm(self, signum, frame): self.db.lock()