Skip to content

Commit

Permalink
Merge branch 'main' into optional_spawn
Browse files Browse the repository at this point in the history
  • Loading branch information
altendky committed Mar 3, 2022
2 parents ffcd25e + fc618ec commit 59c07ea
Show file tree
Hide file tree
Showing 5 changed files with 207 additions and 24 deletions.
8 changes: 5 additions & 3 deletions chia/daemon/keychain_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,9 +108,11 @@ def handle_error(self, response: WsRpcMessage):
message = error_details.get("message", "")
raise MalformedKeychainRequest(message)
else:
err = f"{response['data'].get('command')} failed with error: {error}"
self.log.error(f"{err}")
raise Exception(f"{err}")
# Try to construct a more informative error message including the call that failed
if "command" in response["data"]:
err = f"{response['data'].get('command')} failed with error: {error}"
raise Exception(f"{err}")
raise Exception(f"{error}")

async def add_private_key(self, mnemonic: str, passphrase: str) -> PrivateKey:
"""
Expand Down
40 changes: 25 additions & 15 deletions chia/daemon/keychain_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,21 +56,25 @@ def get_keychain_for_request(self, request: Dict[str, Any]):
return keychain

async def handle_command(self, command, data) -> Dict[str, Any]:
if command == "add_private_key":
return await self.add_private_key(cast(Dict[str, Any], data))
elif command == "check_keys":
return await self.check_keys(cast(Dict[str, Any], data))
elif command == "delete_all_keys":
return await self.delete_all_keys(cast(Dict[str, Any], data))
elif command == "delete_key_by_fingerprint":
return await self.delete_key_by_fingerprint(cast(Dict[str, Any], data))
elif command == "get_all_private_keys":
return await self.get_all_private_keys(cast(Dict[str, Any], data))
elif command == "get_first_private_key":
return await self.get_first_private_key(cast(Dict[str, Any], data))
elif command == "get_key_for_fingerprint":
return await self.get_key_for_fingerprint(cast(Dict[str, Any], data))
return {}
try:
if command == "add_private_key":
return await self.add_private_key(cast(Dict[str, Any], data))
elif command == "check_keys":
return await self.check_keys(cast(Dict[str, Any], data))
elif command == "delete_all_keys":
return await self.delete_all_keys(cast(Dict[str, Any], data))
elif command == "delete_key_by_fingerprint":
return await self.delete_key_by_fingerprint(cast(Dict[str, Any], data))
elif command == "get_all_private_keys":
return await self.get_all_private_keys(cast(Dict[str, Any], data))
elif command == "get_first_private_key":
return await self.get_first_private_key(cast(Dict[str, Any], data))
elif command == "get_key_for_fingerprint":
return await self.get_key_for_fingerprint(cast(Dict[str, Any], data))
return {}
except Exception as e:
log.exception(e)
return {"success": False, "error": str(e), "command": command}

async def add_private_key(self, request: Dict[str, Any]) -> Dict[str, Any]:
if self.get_keychain_for_request(request).is_keyring_locked():
Expand All @@ -93,6 +97,12 @@ async def add_private_key(self, request: Dict[str, Any]) -> Dict[str, Any]:
"error": KEYCHAIN_ERR_KEYERROR,
"error_details": {"message": f"The word '{e.args[0]}' is incorrect.'", "word": e.args[0]},
}
except ValueError as e:
log.exception(e)
return {
"success": False,
"error": str(e),
}

return {"success": True}

Expand Down
5 changes: 5 additions & 0 deletions chia/full_node/block_height_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,11 @@ async def _load_blocks_from(self, height: uint32, prev_hash: bytes32):
):
return
self.__sub_epoch_summaries[height] = entry[2]
elif height in self.__sub_epoch_summaries:
# if the database file was swapped out and the existing
# cache doesn't represent any of it at all, a missing sub
# epoch summary needs to be removed from the cache too
del self.__sub_epoch_summaries[height]
self.__set_hash(height, prev_hash)
prev_hash = entry[1]

Expand Down
138 changes: 135 additions & 3 deletions tests/core/daemon/test_daemon.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from chia.daemon.server import WebSocketServer
from chia.server.outbound_message import NodeType
from chia.types.peer_info import PeerInfo
from tests.block_tools import BlockTools, create_block_tools_async
Expand Down Expand Up @@ -49,8 +50,8 @@ async def get_b_tools(self, get_temp_keyring):

@pytest_asyncio.fixture(scope="function")
async def get_daemon_with_temp_keyring(self, get_b_tools):
async for _ in setup_daemon(btools=get_b_tools):
yield get_b_tools
async for daemon in setup_daemon(btools=get_b_tools):
yield get_b_tools, daemon

@pytest.mark.asyncio
async def test_daemon_simulation(self, simulation, get_b_tools):
Expand Down Expand Up @@ -124,7 +125,7 @@ async def reader(ws, queue):
@pytest.mark.filterwarnings("ignore::DeprecationWarning:websockets.*")
@pytest.mark.asyncio
async def test_validate_keyring_passphrase_rpc(self, get_daemon_with_temp_keyring):
local_b_tools: BlockTools = get_daemon_with_temp_keyring
local_b_tools: BlockTools = get_daemon_with_temp_keyring[0]
keychain = local_b_tools.local_keychain

# When: the keychain has a master passphrase set
Expand Down Expand Up @@ -202,3 +203,134 @@ async def check_empty_passphrase_case(response: aiohttp.http_websocket.WSMessage
await ws.send_str(create_payload("validate_keyring_passphrase", {"key": ""}, "test", "daemon"))
# Expect: validation failure
await check_empty_passphrase_case(await ws.receive())

# Suppress warning: "The explicit passing of coroutine objects to asyncio.wait() is deprecated since Python 3.8..."
# Can be removed when we upgrade to a newer version of websockets (9.1 works)
@pytest.mark.filterwarnings("ignore::DeprecationWarning:websockets.*")
@pytest.mark.asyncio
async def test_add_private_key(self, get_daemon_with_temp_keyring):
local_b_tools: BlockTools = get_daemon_with_temp_keyring[0]
daemon: WebSocketServer = get_daemon_with_temp_keyring[1]
keychain = daemon.keychain_server._default_keychain # Keys will be added here
test_mnemonic = (
"grief lock ketchup video day owner torch young work "
"another venue evidence spread season bright private "
"tomato remind jaguar original blur embody project can"
)
test_fingerprint = 2877570395
mnemonic_with_typo = f"{test_mnemonic}xyz" # intentional typo: can -> canxyz
mnemonic_with_missing_word = " ".join(test_mnemonic.split(" ")[:-1]) # missing last word

async def check_success_case(response: aiohttp.http_websocket.WSMessage):
nonlocal keychain

# Expect: JSON response
assert response.type == aiohttp.WSMsgType.TEXT
message = json.loads(response.data.strip())
# Expect: daemon handled the request
assert message["ack"] is True
# Expect: success flag is set to True
assert message["data"]["success"] is True
# Expect: the keychain has the new key
assert keychain.get_private_key_by_fingerprint(test_fingerprint) is not None

async def check_missing_param_case(response: aiohttp.http_websocket.WSMessage):
# Expect: JSON response
assert response.type == aiohttp.WSMsgType.TEXT
message = json.loads(response.data.strip())
# Expect: daemon handled the request
assert message["ack"] is True
# Expect: success flag is set to False
assert message["data"]["success"] is False
# Expect: error field is set to "malformed request"
assert message["data"]["error"] == "malformed request"
# Expect: error_details message is set to "missing mnemonic and/or passphrase"
assert message["data"]["error_details"]["message"] == "missing mnemonic and/or passphrase"

async def check_mnemonic_with_typo_case(response: aiohttp.http_websocket.WSMessage):
# Expect: JSON response
assert response.type == aiohttp.WSMsgType.TEXT
message = json.loads(response.data.strip())
# Expect: daemon handled the request
assert message["ack"] is True
# Expect: success flag is set to False
assert message["data"]["success"] is False
# Expect: error field is set to "'canxyz' is not in the mnemonic dictionary; may be misspelled"
assert message["data"]["error"] == "'canxyz' is not in the mnemonic dictionary; may be misspelled"

async def check_invalid_mnemonic_length_case(response: aiohttp.http_websocket.WSMessage):
# Expect: JSON response
assert response.type == aiohttp.WSMsgType.TEXT
message = json.loads(response.data.strip())
# Expect: daemon handled the request
assert message["ack"] is True
# Expect: success flag is set to False
assert message["data"]["success"] is False
# Expect: error field is set to "Invalid mnemonic length"
assert message["data"]["error"] == "Invalid mnemonic length"

async def check_invalid_mnemonic_case(response: aiohttp.http_websocket.WSMessage):
# Expect: JSON response
assert response.type == aiohttp.WSMsgType.TEXT
message = json.loads(response.data.strip())
# Expect: daemon handled the request
assert message["ack"] is True
# Expect: success flag is set to False
assert message["data"]["success"] is False
# Expect: error field is set to "Invalid order of mnemonic words"
assert message["data"]["error"] == "Invalid order of mnemonic words"

async with aiohttp.ClientSession() as session:
async with session.ws_connect(
f"wss://127.0.0.1:{local_b_tools._config['daemon_port']}",
autoclose=True,
autoping=True,
heartbeat=60,
ssl=local_b_tools.get_daemon_ssl_context(),
max_msg_size=52428800,
) as ws:
# Expect the key hasn't been added yet
assert keychain.get_private_key_by_fingerprint(test_fingerprint) is None

await ws.send_str(
create_payload("add_private_key", {"mnemonic": test_mnemonic, "passphrase": ""}, "test", "daemon")
)
# Expect: key was added successfully
await check_success_case(await ws.receive())

# When: missing mnemonic
await ws.send_str(create_payload("add_private_key", {"passphrase": ""}, "test", "daemon"))
# Expect: Failure due to missing mnemonic
await check_missing_param_case(await ws.receive())

# When: missing passphrase
await ws.send_str(create_payload("add_private_key", {"mnemonic": test_mnemonic}, "test", "daemon"))
# Expect: Failure due to missing passphrase
await check_missing_param_case(await ws.receive())

# When: using a mmnemonic with an incorrect word (typo)
await ws.send_str(
create_payload(
"add_private_key", {"mnemonic": mnemonic_with_typo, "passphrase": ""}, "test", "daemon"
)
)
# Expect: Failure due to misspelled mnemonic
await check_mnemonic_with_typo_case(await ws.receive())

# When: using a mnemonic with an incorrect word count
await ws.send_str(
create_payload(
"add_private_key", {"mnemonic": mnemonic_with_missing_word, "passphrase": ""}, "test", "daemon"
)
)
# Expect: Failure due to invalid mnemonic
await check_invalid_mnemonic_length_case(await ws.receive())

# When: using using an incorrect mnemnonic
await ws.send_str(
create_payload(
"add_private_key", {"mnemonic": " ".join(["abandon"] * 24), "passphrase": ""}, "test", "daemon"
)
)
# Expect: Failure due to checksum error
await check_invalid_mnemonic_case(await ws.receive())
40 changes: 37 additions & 3 deletions tests/core/full_node/test_block_height_map.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
import pytest
import struct
from chia.full_node.block_height_map import BlockHeightMap
from chia.full_node.block_height_map import BlockHeightMap, SesCache
from chia.types.blockchain_format.sub_epoch_summary import SubEpochSummary
from chia.util.db_wrapper import DBWrapper

from tests.util.db_connection import DBConnection
from chia.types.blockchain_format.sized_bytes import bytes32
from typing import Optional
from chia.util.ints import uint8

# from tests.conftest import tmp_dir
from chia.util.files import write_file_async


def gen_block_hash(height: int) -> bytes32:
Expand Down Expand Up @@ -189,6 +188,41 @@ async def test_save_restore(self, tmp_dir, db_version):
with pytest.raises(KeyError) as _:
height_map.get_ses(height)

@pytest.mark.asyncio
async def test_restore_entire_chain(self, tmp_dir, db_version):

# this is a test where the height-to-hash and height-to-ses caches are
# entirely unrelated to the database. Make sure they can both be fully
# replaced
async with DBConnection(db_version) as db_wrapper:

heights = bytearray(900 * 32)
for i in range(900):
idx = i * 32
heights[idx : idx + 32] = bytes([i % 256] * 32)

await write_file_async(tmp_dir / "height-to-hash", heights)

ses_cache = []
for i in range(0, 900, 19):
ses_cache.append((i, gen_ses(i + 9999)))

await write_file_async(tmp_dir / "sub-epoch-summaries", bytes(SesCache(ses_cache)))

await setup_db(db_wrapper)
await setup_chain(db_wrapper, 10000, ses_every=20)

height_map = await BlockHeightMap.create(tmp_dir, db_wrapper)

for height in reversed(range(10000)):
assert height_map.contains_height(height)
assert height_map.get_hash(height) == gen_block_hash(height)
if (height % 20) == 0:
assert height_map.get_ses(height) == gen_ses(height)
else:
with pytest.raises(KeyError) as _:
height_map.get_ses(height)

@pytest.mark.asyncio
async def test_restore_extend(self, tmp_dir, db_version):

Expand Down

0 comments on commit 59c07ea

Please sign in to comment.