diff --git a/UpgradingACA-Py.md b/UpgradingACA-Py.md index a829bc4f7f..c1ef822527 100644 --- a/UpgradingACA-Py.md +++ b/UpgradingACA-Py.md @@ -76,7 +76,19 @@ In case, running multiple tags [say test1 & test2]: ./scripts/run_docker upgrade --force-upgrade --named-tag test1 --named-tag test2 ``` +## Subwallet upgrades +With multitenant enabled, there is a subwallet associated with each tenant profile, so there is a need to upgrade those sub wallets in addition to the base wallet associated with root profile. +There are 2 options to perform such upgrades: + - `--upgrade-all-subwallets` + + This will apply the upgrade steps to all sub wallets [tenant profiles] and the base wallet [root profiles]. + + - `--upgrade-subwallet` + + This will apply the upgrade steps to specified sub wallets [identified by wallet id] and the base wallet. + + Note: multiple specification allowed ## Exceptions diff --git a/aries_cloudagent/commands/tests/test_upgrade.py b/aries_cloudagent/commands/tests/test_upgrade.py index d6c5c66e08..55451387c0 100644 --- a/aries_cloudagent/commands/tests/test_upgrade.py +++ b/aries_cloudagent/commands/tests/test_upgrade.py @@ -4,9 +4,11 @@ from ...core.in_memory import InMemoryProfile from ...connections.models.conn_record import ConnRecord -from ...storage.base import BaseStorage +from ...storage.base import BaseStorage, BaseStorageSearch +from ...storage.in_memory import InMemoryStorage from ...storage.record import StorageRecord from ...version import __version__ +from ...wallet.models.wallet_record import WalletRecord from .. 
import upgrade as test_module from ..upgrade import UpgradeError @@ -16,12 +18,33 @@ class TestUpgrade(AsyncTestCase): async def setUp(self): self.session = InMemoryProfile.test_session() self.profile = self.session.profile + self.profile.context.injector.bind_instance( + BaseStorageSearch, InMemoryStorage(self.profile) + ) self.storage = self.session.inject(BaseStorage) record = StorageRecord( "acapy_version", "v0.7.2", ) await self.storage.add_record(record) + recs = [ + WalletRecord( + key_management_mode=[ + WalletRecord.MODE_UNMANAGED, + WalletRecord.MODE_MANAGED, + ][i], + settings={ + "wallet.name": f"my-wallet-{i}", + "wallet.type": "indy", + "wallet.key": f"dummy-wallet-key-{i}", + }, + wallet_name=f"my-wallet-{i}", + ) + for i in range(2) + ] + async with self.profile.session() as session: + for rec in recs: + await rec.save(session) def test_bad_calls(self): with self.assertRaises(SystemExit): @@ -85,6 +108,63 @@ async def test_upgrade_from_version(self): profile=self.profile, ) + async def test_upgrade_all_subwallets(self): + self.profile.settings.extend( + { + "upgrade.from_version": "v0.7.2", + "upgrade.upgrade_all_subwallets": True, + "upgrade.force_upgrade": True, + "upgrade.page_size": 1, + } + ) + with async_mock.patch.object( + ConnRecord, + "query", + async_mock.CoroutineMock(return_value=[ConnRecord()]), + ), async_mock.patch.object(ConnRecord, "save", async_mock.CoroutineMock()): + await test_module.upgrade( + profile=self.profile, + ) + + async def test_upgrade_specified_subwallets(self): + wallet_ids = [] + async with self.profile.session() as session: + wallet_recs = await WalletRecord.query(session, tag_filter={}) + for wallet_rec in wallet_recs: + wallet_ids.append(wallet_rec.wallet_id) + self.profile.settings.extend( + { + "upgrade.named_tags": "fix_issue_rev_reg", + "upgrade.upgrade_subwallets": [wallet_ids[0]], + "upgrade.force_upgrade": True, + } + ) + with async_mock.patch.object( + ConnRecord, + "query", + 
async_mock.CoroutineMock(return_value=[ConnRecord()]), + ), async_mock.patch.object(ConnRecord, "save", async_mock.CoroutineMock()): + await test_module.upgrade( + profile=self.profile, + ) + + self.profile.settings.extend( + { + "upgrade.named_tags": "fix_issue_rev_reg", + "upgrade.upgrade_subwallets": wallet_ids, + "upgrade.force_upgrade": True, + "upgrade.page_size": 1, + } + ) + with async_mock.patch.object( + ConnRecord, + "query", + async_mock.CoroutineMock(return_value=[ConnRecord()]), + ), async_mock.patch.object(ConnRecord, "save", async_mock.CoroutineMock()): + await test_module.upgrade( + profile=self.profile, + ) + async def test_upgrade_callable(self): version_storage_record = await self.storage.find_record( type_filter="acapy_version", tag_query={} @@ -412,7 +492,7 @@ async def test_upgrade_x_invalid_config(self): async_mock.MagicMock(return_value={}), ): with self.assertRaises(UpgradeError) as ctx: - await test_module.upgrade(settings={}) + await test_module.upgrade(profile=self.profile) assert "No version configs found in" in str(ctx.exception) async def test_upgrade_x_params(self): diff --git a/aries_cloudagent/commands/upgrade.py b/aries_cloudagent/commands/upgrade.py index bf2d0dc95a..5d81db13fe 100644 --- a/aries_cloudagent/commands/upgrade.py +++ b/aries_cloudagent/commands/upgrade.py @@ -22,23 +22,26 @@ from ..core.profile import Profile, ProfileSession from ..config import argparse as arg +from ..config.injection_context import InjectionContext from ..config.default_context import DefaultContextBuilder from ..config.base import BaseError, BaseSettings from ..config.util import common_config from ..config.wallet import wallet_config from ..messaging.models.base import BaseModelError from ..messaging.models.base_record import BaseRecord, RecordType -from ..storage.base import BaseStorage +from ..storage.base import BaseStorage, BaseStorageSearch from ..storage.error import StorageNotFoundError from ..storage.record import StorageRecord from 
..revocation.models.issuer_rev_reg_record import IssuerRevRegRecord from ..utils.classloader import ClassLoader, ClassNotFoundError from ..version import __version__, RECORD_TYPE_ACAPY_VERSION +from ..wallet.models.wallet_record import WalletRecord from . import PROG DEFAULT_UPGRADE_CONFIG_FILE_NAME = "default_version_upgrade_config.yml" LOGGER = logging.getLogger(__name__) +BATCH_SIZE = 25 class ExplicitUpgradeOption(Enum): @@ -239,21 +242,129 @@ def _perform_upgrade( return resave_record_path_sets, executables_call_set +def get_webhook_urls( + base_context: InjectionContext, + wallet_record: WalletRecord, +) -> list: + """Get the webhook urls according to dispatch_type.""" + wallet_id = wallet_record.wallet_id + dispatch_type = wallet_record.wallet_dispatch_type + subwallet_webhook_urls = wallet_record.wallet_webhook_urls or [] + base_webhook_urls = base_context.settings.get("admin.webhook_urls", []) + + if dispatch_type == "both": + webhook_urls = list(set(base_webhook_urls) | set(subwallet_webhook_urls)) + if not webhook_urls: + LOGGER.warning( + "No webhook URLs in context configuration " + f"nor wallet record {wallet_id}, but wallet record " + f"configures dispatch type {dispatch_type}" + ) + elif dispatch_type == "default": + webhook_urls = subwallet_webhook_urls + if not webhook_urls: + LOGGER.warning( + f"No webhook URLs in nor wallet record {wallet_id}, but " + f"wallet record configures dispatch type {dispatch_type}" + ) + else: + webhook_urls = base_webhook_urls + return webhook_urls + + +async def get_wallet_profile( + base_context: InjectionContext, + wallet_record: WalletRecord, + extra_settings: dict = {}, +) -> Profile: + """Get profile for a wallet record.""" + context = base_context.copy() + reset_settings = { + "wallet.recreate": False, + "wallet.seed": None, + "wallet.rekey": None, + "wallet.name": None, + "wallet.type": None, + "mediation.open": None, + "mediation.invite": None, + "mediation.default_id": None, + "mediation.clear": None, + } + 
extra_settings["admin.webhook_urls"] = get_webhook_urls(base_context, wallet_record) + + context.settings = ( + context.settings.extend(reset_settings) + .extend(wallet_record.settings) + .extend(extra_settings) + ) + + profile, _ = await wallet_config(context, provision=False) + return profile + + async def upgrade( settings: Optional[Union[Mapping[str, Any], BaseSettings]] = None, profile: Optional[Profile] = None, +): + """Invoke upgradation process for each applicable profile.""" + profiles_to_upgrade = [] + if settings: + batch_size = settings.get("upgrade.page_size", BATCH_SIZE) + else: + batch_size = BATCH_SIZE + if profile and (settings or settings == {}): + raise UpgradeError("upgrade requires either profile or settings, not both.") + if profile: + root_profile = profile + settings = profile.settings + else: + context_builder = DefaultContextBuilder(settings) + context = await context_builder.build_context() + root_profile, _ = await wallet_config(context) + profiles_to_upgrade.append(root_profile) + base_storage_search_inst = root_profile.inject(BaseStorageSearch) + if "upgrade.upgrade_all_subwallets" in settings and settings.get( + "upgrade.upgrade_all_subwallets" + ): + search_session = base_storage_search_inst.search_records( + type_filter=WalletRecord.RECORD_TYPE, page_size=batch_size + ) + while search_session._done is False: + wallet_storage_records = await search_session.fetch() + for wallet_storage_record in wallet_storage_records: + wallet_record = WalletRecord.from_storage( + wallet_storage_record.id, + json.loads(wallet_storage_record.value), + ) + wallet_profile = await get_wallet_profile( + base_context=root_profile.context, wallet_record=wallet_record + ) + profiles_to_upgrade.append(wallet_profile) + del settings["upgrade.upgrade_all_subwallets"] + if ( + "upgrade.upgrade_subwallets" in settings + and len(settings.get("upgrade.upgrade_subwallets")) >= 1 + ): + for _wallet_id in settings.get("upgrade.upgrade_subwallets"): + async with 
root_profile.session() as session: + wallet_record = await WalletRecord.retrieve_by_id( + session, record_id=_wallet_id + ) + wallet_profile = await get_wallet_profile( + base_context=root_profile.context, wallet_record=wallet_record + ) + profiles_to_upgrade.append(wallet_profile) + del settings["upgrade.upgrade_subwallets"] + for _profile in profiles_to_upgrade: + await upgrade_per_profile(profile=_profile, settings=settings) + + +async def upgrade_per_profile( + profile: Profile, + settings: Optional[Union[Mapping[str, Any], BaseSettings]] = None, ): """Perform upgradation steps.""" try: - if profile and (settings or settings == {}): - raise UpgradeError("upgrade requires either profile or settings, not both.") - if profile: - root_profile = profile - settings = profile.settings - else: - context_builder = DefaultContextBuilder(settings) - context = await context_builder.build_context() - root_profile, _ = await wallet_config(context) version_upgrade_config_inst = VersionUpgradeConfig( settings.get("upgrade.config_path") ) @@ -273,7 +384,7 @@ async def upgrade( upgrade_from_version_storage = None upgrade_from_version_config = None upgrade_from_version = None - async with root_profile.session() as session: + async with profile.session() as session: storage = session.inject(BaseStorage) try: version_storage_record = await storage.find_record( @@ -391,8 +502,24 @@ async def upgrade( raise UpgradeError( f"Only BaseRecord can be resaved, found: {str(rec_type)}" ) - async with root_profile.session() as session: - all_records = await rec_type.query(session) + all_records = [] + if settings: + batch_size = settings.get("upgrade.page_size", BATCH_SIZE) + else: + batch_size = BATCH_SIZE + base_storage_search_inst = profile.inject(BaseStorageSearch) + search_session = base_storage_search_inst.search_records( + type_filter=rec_type.RECORD_TYPE, page_size=batch_size + ) + while search_session._done is False: + storage_records = await search_session.fetch() + for 
storage_record in storage_records: + _record = rec_type.from_storage( + storage_record.id, + json.loads(storage_record.value), + ) + all_records.append(_record) + async with profile.session() as session: for record in all_records: await record.save( session, @@ -406,11 +533,11 @@ async def upgrade( _callable = version_upgrade_config_inst.get_callable(callable_name) if not _callable: raise UpgradeError(f"No function specified for {callable_name}") - await _callable(root_profile) + await _callable(profile) # Update storage version if to_update_flag: - async with root_profile.session() as session: + async with profile.session() as session: storage = session.inject(BaseStorage) if not version_storage_record: await storage.add_record( @@ -428,7 +555,7 @@ async def upgrade( f"set to {upgrade_to_version}" ) if not profile: - await root_profile.close() + await profile.close() except BaseError as e: raise UpgradeError(f"Error during upgrade: {e}") diff --git a/aries_cloudagent/config/argparse.py b/aries_cloudagent/config/argparse.py index 9e1da6410e..2f9463d73b 100644 --- a/aries_cloudagent/config/argparse.py +++ b/aries_cloudagent/config/argparse.py @@ -2085,6 +2085,33 @@ def add_arguments(self, parser: ArgumentParser): help=("Runs upgrade steps associated with tags provided in the config"), ) + parser.add_argument( + "--upgrade-all-subwallets", + action="store_true", + env_var="ACAPY_UPGRADE_ALL_SUBWALLETS", + help="Apply upgrade to all subwallets and the base wallet", + ) + + parser.add_argument( + "--upgrade-subwallet", + action="append", + env_var="ACAPY_UPGRADE_SUBWALLETS", + help=( + "Apply upgrade to specified subwallets (identified by wallet id)" + " and the base wallet" + ), + ) + + parser.add_argument( + "--upgrade-page-size", + type=str, + env_var="ACAPY_UPGRADE_PAGE_SIZE", + help=( + "Specify page/batch size to process BaseRecords, " + "this provides a way to prevent out-of-memory issues." 
+ ), + ) + def get_settings(self, args: Namespace) -> dict: """Extract ACA-Py upgrade process settings.""" settings = {} @@ -2098,4 +2125,15 @@ def get_settings(self, args: Namespace) -> dict: settings["upgrade.named_tags"] = ( list(args.named_tag) if args.named_tag else [] ) + if args.upgrade_all_subwallets: + settings["upgrade.upgrade_all_subwallets"] = args.upgrade_all_subwallets + if args.upgrade_subwallet: + settings["upgrade.upgrade_subwallets"] = ( + list(args.upgrade_subwallet) if args.upgrade_subwallet else [] + ) + if args.upgrade_page_size: + try: + settings["upgrade.page_size"] = int(args.upgrade_page_size) + except ValueError: + raise ArgsParseError("Parameter --upgrade-page-size must be an integer") return settings diff --git a/aries_cloudagent/config/tests/test_argparse.py b/aries_cloudagent/config/tests/test_argparse.py index 0c52c12ef7..37bf4740c1 100644 --- a/aries_cloudagent/config/tests/test_argparse.py +++ b/aries_cloudagent/config/tests/test_argparse.py @@ -155,6 +155,75 @@ async def test_upgrade_config(self): == "./aries_cloudagent/config/tests/test-acapy-upgrade-config.yml" ) + result = parser.parse_args( + [ + "--named-tag", + "test_tag_1", + "--named-tag", + "test_tag_2", + "--force-upgrade", + ] + ) + + assert result.named_tag == ["test_tag_1", "test_tag_2"] + assert result.force_upgrade is True + + settings = group.get_settings(result) + + assert settings.get("upgrade.named_tags") == ["test_tag_1", "test_tag_2"] + assert settings.get("upgrade.force_upgrade") is True + + result = parser.parse_args( + [ + "--upgrade-config-path", + "./aries_cloudagent/config/tests/test-acapy-upgrade-config.yml", + "--from-version", + "v0.7.2", + "--upgrade-all-subwallets", + "--force-upgrade", + ] + ) + + assert ( + result.upgrade_config_path + == "./aries_cloudagent/config/tests/test-acapy-upgrade-config.yml" + ) + assert result.force_upgrade is True + assert result.upgrade_all_subwallets is True + + settings = group.get_settings(result) + + assert ( + 
settings.get("upgrade.config_path") + == "./aries_cloudagent/config/tests/test-acapy-upgrade-config.yml" + ) + assert settings.get("upgrade.force_upgrade") is True + assert settings.get("upgrade.upgrade_all_subwallets") is True + + result = parser.parse_args( + [ + "--named-tag", + "fix_issue_rev_reg", + "--upgrade-subwallet", + "test_wallet_id_1", + "--upgrade-subwallet", + "test_wallet_id_2", + "--force-upgrade", + ] + ) + + assert result.named_tag == ["fix_issue_rev_reg"] + assert result.force_upgrade is True + assert result.upgrade_subwallet == ["test_wallet_id_1", "test_wallet_id_2"] + + settings = group.get_settings(result) + assert settings.get("upgrade.named_tags") == ["fix_issue_rev_reg"] + assert settings.get("upgrade.force_upgrade") is True + assert settings.get("upgrade.upgrade_subwallets") == [ + "test_wallet_id_1", + "test_wallet_id_2", + ] + async def test_outbound_is_required(self): """Test that either -ot or -oq are required""" parser = argparse.create_argument_parser() diff --git a/aries_cloudagent/connections/base_manager.py b/aries_cloudagent/connections/base_manager.py index d2659d5370..907e6181f7 100644 --- a/aries_cloudagent/connections/base_manager.py +++ b/aries_cloudagent/connections/base_manager.py @@ -6,7 +6,6 @@ import logging from typing import List, Optional, Sequence, Text, Tuple, Union -from multiformats import multibase, multicodec from pydid import ( BaseDIDDocument as ResolvedDocument, DIDCommService, @@ -18,6 +17,7 @@ Ed25519VerificationKey2020, JsonWebKey2020, ) + from ..cache.base import BaseCache from ..config.base import InjectionError from ..core.error import BaseError @@ -40,6 +40,7 @@ from ..storage.error import StorageDuplicateError, StorageError, StorageNotFoundError from ..storage.record import StorageRecord from ..transport.inbound.receipt import MessageReceipt +from ..utils.multiformats import multibase, multicodec from ..wallet.base import BaseWallet from ..wallet.crypto import create_keypair, seed_to_did from 
..wallet.did_info import DIDInfo @@ -75,7 +76,6 @@ def __init__(self, profile: Profile): async def create_did_document( self, did_info: DIDInfo, - inbound_connection_id: Optional[str] = None, svc_endpoints: Optional[Sequence[str]] = None, mediation_records: Optional[List[MediationRecord]] = None, ) -> DIDDoc: @@ -83,7 +83,6 @@ async def create_did_document( Args: did_info: The DID information (DID and verkey) used in the connection - inbound_connection_id: The ID of the inbound routing connection to use svc_endpoints: Custom endpoints for the DID Document mediation_record: The record for mediation that contains routing_keys and service endpoint @@ -106,61 +105,18 @@ async def create_did_document( ) did_doc.set(pk) - router_id = inbound_connection_id - routing_keys = [] - router_idx = 1 - while router_id: - # look up routing connection information - async with self._profile.session() as session: - router = await ConnRecord.retrieve_by_id(session, router_id) - if ConnRecord.State.get(router.state) != ConnRecord.State.COMPLETED: - raise BaseConnectionManagerError( - f"Router connection not completed: {router_id}" - ) - routing_doc, _ = await self.fetch_did_document(router.their_did) - assert isinstance(routing_doc, DIDDoc) - if not routing_doc.service: - raise BaseConnectionManagerError( - f"No services defined by routing DIDDoc: {router_id}" - ) - for service in routing_doc.service.values(): - if not service.endpoint: - raise BaseConnectionManagerError( - "Routing DIDDoc service has no service endpoint" - ) - if not service.recip_keys: - raise BaseConnectionManagerError( - "Routing DIDDoc service has no recipient key(s)" - ) - rk = PublicKey( - did_info.did, - f"routing-{router_idx}", - service.recip_keys[0].value, - PublicKeyType.ED25519_SIG_2018, - did_controller, - True, - ) - routing_keys.append(rk) - svc_endpoints = [service.endpoint] - break - router_id = router.inbound_connection_id - + routing_keys: List[str] = [] if mediation_records: for mediation_record in 
mediation_records: - mediator_routing_keys = [ - PublicKey( - did_info.did, - f"routing-{idx}", - key, - PublicKeyType.ED25519_SIG_2018, - did_controller, # TODO: get correct controller did_info - True, # TODO: should this be true? - ) - for idx, key in enumerate(mediation_record.routing_keys) - ] - - routing_keys = [*routing_keys, *mediator_routing_keys] - svc_endpoints = [mediation_record.endpoint] + ( + mediator_routing_keys, + endpoint, + ) = await self._route_manager.routing_info( + self._profile, mediation_record + ) + routing_keys = [*routing_keys, *(mediator_routing_keys or [])] + if endpoint: + svc_endpoints = [endpoint] for endpoint_index, svc_endpoint in enumerate(svc_endpoints or []): endpoint_ident = "indy" if endpoint_index == 0 else f"indy{endpoint_index}" @@ -933,7 +889,6 @@ async def create_static_connection( # Synthesize their DID doc did_doc = await self.create_did_document( their_info, - None, [their_endpoint or ""], mediation_records=list( filter(None, [base_mediation_record, mediation_record]) diff --git a/aries_cloudagent/connections/models/diddoc/diddoc.py b/aries_cloudagent/connections/models/diddoc/diddoc.py index 0970a2fb0b..e4ced90108 100644 --- a/aries_cloudagent/connections/models/diddoc/diddoc.py +++ b/aries_cloudagent/connections/models/diddoc/diddoc.py @@ -22,6 +22,8 @@ from typing import List, Sequence, Union +from ....did.did_key import DIDKey + from .publickey import PublicKey, PublicKeyType from .service import Service from .util import canon_did, canon_ref, ok_did, resource @@ -116,13 +118,36 @@ def set(self, item: Union[Service, PublicKey]) -> "DIDDoc": "Cannot add item {} to DIDDoc on DID {}".format(item, self.did) ) - def serialize(self) -> dict: + @staticmethod + def _normalize_routing_keys(service: dict) -> dict: + """Normalize routing keys in service. 
+ + Args: + service: service dict + + Returns: service dict with routing keys normalized + """ + routing_keys = service.get("routingKeys") + if routing_keys: + routing_keys = [ + DIDKey.from_did(key).public_key_b58 + if key.startswith("did:key:") + else key + for key in routing_keys + ] + service["routingKeys"] = routing_keys + return service + + def serialize(self, normalize_routing_keys: bool = False) -> dict: """Dump current object to a JSON-compatible dictionary. Returns: dict representation of current DIDDoc """ + service = [service.to_dict() for service in self.service.values()] + if normalize_routing_keys: + service = [self._normalize_routing_keys(s) for s in service] return { "@context": DIDDoc.CONTEXT, @@ -136,7 +161,7 @@ def serialize(self) -> dict: for pubkey in self.pubkey.values() if pubkey.authn ], - "service": [service.to_dict() for service in self.service.values()], + "service": service, } def to_json(self) -> str: @@ -285,7 +310,7 @@ def deserialize(cls, did_doc: dict) -> "DIDDoc": ), service["type"], rv.add_service_pubkeys(service, "recipientKeys"), - rv.add_service_pubkeys(service, ["mediatorKeys", "routingKeys"]), + service.get("routingKeys", []), canon_ref(rv.did, endpoint, ";") if ";" in endpoint else endpoint, service.get("priority", None), ) diff --git a/aries_cloudagent/connections/models/diddoc/service.py b/aries_cloudagent/connections/models/diddoc/service.py index 27d9564d5e..c9d2a8f7a0 100644 --- a/aries_cloudagent/connections/models/diddoc/service.py +++ b/aries_cloudagent/connections/models/diddoc/service.py @@ -36,7 +36,7 @@ def __init__( ident: str, typ: str, recip_keys: Union[Sequence, PublicKey], - routing_keys: Union[Sequence, PublicKey], + routing_keys: List[str], endpoint: str, priority: int = 0, ): @@ -69,13 +69,7 @@ def __init__( if recip_keys else None ) - self._routing_keys = ( - [routing_keys] - if isinstance(routing_keys, PublicKey) - else list(routing_keys) - if routing_keys - else None - ) + self._routing_keys = 
routing_keys or [] self._endpoint = endpoint self._priority = priority @@ -104,7 +98,7 @@ def recip_keys(self) -> List[PublicKey]: return self._recip_keys @property - def routing_keys(self) -> List[PublicKey]: + def routing_keys(self) -> List[str]: """Accessor for the routing keys.""" return self._routing_keys @@ -128,7 +122,7 @@ def to_dict(self) -> dict: if self.recip_keys: rv["recipientKeys"] = [k.value for k in self.recip_keys] if self.routing_keys: - rv["routingKeys"] = [k.value for k in self.routing_keys] + rv["routingKeys"] = self.routing_keys rv["serviceEndpoint"] = self.endpoint return rv diff --git a/aries_cloudagent/connections/tests/test_base_manager.py b/aries_cloudagent/connections/tests/test_base_manager.py index 20d3d21f0e..aa136cdfbb 100644 --- a/aries_cloudagent/connections/tests/test_base_manager.py +++ b/aries_cloudagent/connections/tests/test_base_manager.py @@ -3,7 +3,6 @@ from unittest.mock import call from asynctest import TestCase as AsyncTestCase, mock as async_mock -from multiformats import multibase, multicodec from pydid import DID, DIDDocument, DIDDocumentBuilder from pydid.doc.builder import ServiceBuilder from pydid.verification_method import ( @@ -32,7 +31,10 @@ from ...protocols.coordinate_mediation.v1_0.models.mediation_record import ( MediationRecord, ) -from ...protocols.coordinate_mediation.v1_0.route_manager import RouteManager +from ...protocols.coordinate_mediation.v1_0.route_manager import ( + RouteManager, + CoordinateMediationV1RouteManager, +) from ...protocols.discovery.v2_0.manager import V20DiscoveryMgr from ...resolver.default.key import KeyDIDResolver from ...resolver.default.legacy_peer import LegacyPeerDIDResolver @@ -41,6 +43,7 @@ from ...storage.error import StorageNotFoundError from ...storage.record import StorageRecord from ...transport.inbound.receipt import MessageReceipt +from ...utils.multiformats import multibase, multicodec from ...wallet.base import DIDInfo from ...wallet.did_method import DIDMethods, 
SOV from ...wallet.error import WalletNotFoundError @@ -82,13 +85,7 @@ async def setUp(self): self.oob_mock = async_mock.MagicMock( clean_finished_oob_record=async_mock.CoroutineMock(return_value=None) ) - self.route_manager = async_mock.MagicMock(RouteManager) - self.route_manager.routing_info = async_mock.CoroutineMock( - return_value=([], self.test_endpoint) - ) - self.route_manager.mediation_record_if_id = async_mock.CoroutineMock( - return_value=None - ) + self.route_manager = CoordinateMediationV1RouteManager() self.resolver = DIDResolver() self.resolver.register_resolver(LegacyPeerDIDResolver()) self.resolver.register_resolver(KeyDIDResolver()) @@ -118,7 +115,7 @@ async def setUp(self): ) self.test_mediator_routing_keys = [ - "3Dn1SJNPaCXcvvJvSbsFWP2xaCjMom3can8CQNhWrTRR" + "did:key:z6Mkgg342Ycpuk263R9d8Aq6MUaxPn1DDeHyGo38EefXmgDL#z6Mkgg342Ycpuk263R9d8Aq6MUaxPn1DDeHyGo38EefXmgDL" ] self.test_mediator_conn_id = "mediator-conn-id" self.test_mediator_endpoint = "http://mediator.example.com" @@ -135,176 +132,10 @@ async def test_create_did_document(self): key_type=ED25519, ) - mock_conn = async_mock.MagicMock( - connection_id="dummy", - inbound_connection_id=None, - their_did=self.test_target_did, - state=ConnRecord.State.COMPLETED.rfc23, - ) - - did_doc = self.make_did_doc( - did=self.test_target_did, verkey=self.test_target_verkey - ) - for i in range(2): # first cover store-record, then update-value - await self.manager.store_did_document(did_doc) - - with async_mock.patch.object( - ConnRecord, "retrieve_by_id", async_mock.CoroutineMock() - ) as mock_conn_rec_retrieve_by_id: - mock_conn_rec_retrieve_by_id.return_value = mock_conn - - did_doc = await self.manager.create_did_document( - did_info=did_info, - inbound_connection_id="dummy", - svc_endpoints=[self.test_endpoint], - ) - - async def test_create_did_document_not_active(self): - did_info = DIDInfo( - self.test_did, - self.test_verkey, - None, - method=SOV, - key_type=ED25519, - ) - - mock_conn = 
async_mock.MagicMock( - connection_id="dummy", - inbound_connection_id=None, - their_did=self.test_target_did, - state=ConnRecord.State.ABANDONED.rfc23, - ) - - with async_mock.patch.object( - ConnRecord, "retrieve_by_id", async_mock.CoroutineMock() - ) as mock_conn_rec_retrieve_by_id: - mock_conn_rec_retrieve_by_id.return_value = mock_conn - - with self.assertRaises(BaseConnectionManagerError): - await self.manager.create_did_document( - did_info=did_info, - inbound_connection_id="dummy", - svc_endpoints=[self.test_endpoint], - ) - - async def test_create_did_document_no_services(self): - did_info = DIDInfo( - self.test_did, - self.test_verkey, - None, - method=SOV, - key_type=ED25519, - ) - - mock_conn = async_mock.MagicMock( - connection_id="dummy", - inbound_connection_id=None, - their_did=self.test_target_did, - state=ConnRecord.State.COMPLETED.rfc23, - ) - - x_did_doc = self.make_did_doc( - did=self.test_target_did, verkey=self.test_target_verkey - ) - x_did_doc._service = {} - for i in range(2): # first cover store-record, then update-value - await self.manager.store_did_document(x_did_doc) - - with async_mock.patch.object( - ConnRecord, "retrieve_by_id", async_mock.CoroutineMock() - ) as mock_conn_rec_retrieve_by_id: - mock_conn_rec_retrieve_by_id.return_value = mock_conn - - with self.assertRaises(BaseConnectionManagerError): - await self.manager.create_did_document( - did_info=did_info, - inbound_connection_id="dummy", - svc_endpoints=[self.test_endpoint], - ) - - async def test_create_did_document_no_service_endpoint(self): - did_info = DIDInfo( - self.test_did, - self.test_verkey, - None, - method=SOV, - key_type=ED25519, - ) - - mock_conn = async_mock.MagicMock( - connection_id="dummy", - inbound_connection_id=None, - their_did=self.test_target_did, - state=ConnRecord.State.COMPLETED.rfc23, - ) - - x_did_doc = self.make_did_doc( - did=self.test_target_did, verkey=self.test_target_verkey - ) - x_did_doc._service = {} - x_did_doc.set( - 
Service(self.test_target_did, "dummy", "IndyAgent", [], [], "", 0) - ) - for i in range(2): # first cover store-record, then update-value - await self.manager.store_did_document(x_did_doc) - - with async_mock.patch.object( - ConnRecord, "retrieve_by_id", async_mock.CoroutineMock() - ) as mock_conn_rec_retrieve_by_id: - mock_conn_rec_retrieve_by_id.return_value = mock_conn - - with self.assertRaises(BaseConnectionManagerError): - await self.manager.create_did_document( - did_info=did_info, - inbound_connection_id="dummy", - svc_endpoints=[self.test_endpoint], - ) - - async def test_create_did_document_no_service_recip_keys(self): - did_info = DIDInfo( - self.test_did, - self.test_verkey, - None, - method=SOV, - key_type=ED25519, - ) - - mock_conn = async_mock.MagicMock( - connection_id="dummy", - inbound_connection_id=None, - their_did=self.test_target_did, - state=ConnRecord.State.COMPLETED.rfc23, - ) - - x_did_doc = self.make_did_doc( - did=self.test_target_did, verkey=self.test_target_verkey - ) - x_did_doc._service = {} - x_did_doc.set( - Service( - self.test_target_did, - "dummy", - "IndyAgent", - [], - [], - self.test_endpoint, - 0, - ) + did_doc = await self.manager.create_did_document( + did_info=did_info, + svc_endpoints=[self.test_endpoint], ) - for i in range(2): # first cover store-record, then update-value - await self.manager.store_did_document(x_did_doc) - - with async_mock.patch.object( - ConnRecord, "retrieve_by_id", async_mock.CoroutineMock() - ) as mock_conn_rec_retrieve_by_id: - mock_conn_rec_retrieve_by_id.return_value = mock_conn - - with self.assertRaises(BaseConnectionManagerError): - await self.manager.create_did_document( - did_info=did_info, - inbound_connection_id="dummy", - svc_endpoints=[self.test_endpoint], - ) async def test_create_did_document_mediation(self): did_info = DIDInfo( @@ -328,8 +159,9 @@ async def test_create_did_document_mediation(self): services = list(doc.service.values()) assert len(services) == 1 (service,) = 
services - service_public_keys = service.routing_keys[0] - assert service_public_keys.value == mediation_record.routing_keys[0] + assert service.routing_keys + service_routing_key = service.routing_keys[0] + assert service_routing_key == mediation_record.routing_keys[0] assert service.endpoint == mediation_record.endpoint async def test_create_did_document_multiple_mediators(self): @@ -351,7 +183,9 @@ async def test_create_did_document_multiple_mediators(self): role=MediationRecord.ROLE_CLIENT, state=MediationRecord.STATE_GRANTED, connection_id="mediator-conn-id2", - routing_keys=["05e8afd1-b4f0-46b7-a285-7a08c8a37caf"], + routing_keys=[ + "did:key:z6Mkgg342Ycpuk263R9d8Aq6MUaxPn1DDeHyGo38EefXmgDz#z6Mkgg342Ycpuk263R9d8Aq6MUaxPn1DDeHyGo38EefXmgDz" + ], endpoint="http://mediatorw.example.com", ) doc = await self.manager.create_did_document( @@ -361,8 +195,8 @@ async def test_create_did_document_multiple_mediators(self): services = list(doc.service.values()) assert len(services) == 1 (service,) = services - assert service.routing_keys[0].value == mediation_record1.routing_keys[0] - assert service.routing_keys[1].value == mediation_record2.routing_keys[0] + assert service.routing_keys[0] == mediation_record1.routing_keys[0] + assert service.routing_keys[1] == mediation_record2.routing_keys[0] assert service.endpoint == mediation_record2.endpoint async def test_create_did_document_mediation_svc_endpoints_overwritten(self): @@ -380,6 +214,9 @@ async def test_create_did_document_mediation_svc_endpoints_overwritten(self): routing_keys=self.test_mediator_routing_keys, endpoint=self.test_mediator_endpoint, ) + self.route_manager.routing_info = async_mock.CoroutineMock( + return_value=(mediation_record.routing_keys, mediation_record.endpoint) + ) doc = await self.manager.create_did_document( did_info, svc_endpoints=[self.test_endpoint], @@ -390,7 +227,7 @@ async def test_create_did_document_mediation_svc_endpoints_overwritten(self): assert len(services) == 1 (service,) = 
services service_public_keys = service.routing_keys[0] - assert service_public_keys.value == mediation_record.routing_keys[0] + assert service_public_keys == mediation_record.routing_keys[0] assert service.endpoint == mediation_record.endpoint async def test_did_key_storage(self): @@ -436,7 +273,13 @@ async def test_store_did_document_with_routing_keys(self): "controller": "YQwDgq9vdAbB3fk1tkeXmg", "type": "Ed25519VerificationKey2018", "publicKeyBase58": "J81x9zdJa8CGSbTYpoYQaNrV6yv13M1Lgz4tmkNPKwZn", - } + }, + { + "id": "YQwDgq9vdAbB3fk1tkeXmg#1", + "controller": "YQwDgq9vdAbB3fk1tkeXmg", + "type": "Ed25519VerificationKey2018", + "publicKeyBase58": routing_key, + }, ], "service": [ { @@ -447,7 +290,7 @@ async def test_store_did_document_with_routing_keys(self): "recipientKeys": [ "J81x9zdJa8CGSbTYpoYQaNrV6yv13M1Lgz4tmkNPKwZn" ], - "routingKeys": ["cK7fwfjpakMuv8QKVv2y6qouZddVw4TxZNQPUs2fFTd"], + "routingKeys": [routing_key], } ], "authentication": [ @@ -1729,6 +1572,7 @@ async def test_create_static_connection_multitenant(self): ) self.multitenant_mgr.get_default_mediator.return_value = None + self.route_manager.route_static = async_mock.CoroutineMock() with async_mock.patch.object( ConnRecord, "save", autospec=True @@ -1761,6 +1605,7 @@ async def test_create_static_connection_multitenant_auto_disclose_features(self) } ) self.multitenant_mgr.get_default_mediator.return_value = None + self.route_manager.route_static = async_mock.CoroutineMock() with async_mock.patch.object( ConnRecord, "save", autospec=True ), async_mock.patch.object( @@ -1790,6 +1635,7 @@ async def test_create_static_connection_multitenant_mediator(self): ) default_mediator = async_mock.MagicMock() + self.route_manager.route_static = async_mock.CoroutineMock() with async_mock.patch.object( ConnRecord, "save", autospec=True @@ -1839,11 +1685,10 @@ async def test_create_static_connection_multitenant_mediator(self): [ call( their_info, - None, [self.test_endpoint], 
mediation_records=[default_mediator], ), - call(their_info, None, [self.test_endpoint], mediation_records=[]), + call(their_info, [self.test_endpoint], mediation_records=[]), ] ) diff --git a/aries_cloudagent/ledger/indy_vdr.py b/aries_cloudagent/ledger/indy_vdr.py index 1ac34a3c99..d4b7dbada5 100644 --- a/aries_cloudagent/ledger/indy_vdr.py +++ b/aries_cloudagent/ledger/indy_vdr.py @@ -8,7 +8,7 @@ import os.path import tempfile -from datetime import datetime, date +from datetime import datetime, date, timezone from io import StringIO from pathlib import Path from time import time @@ -923,7 +923,7 @@ def taa_rough_timestamp(self) -> int: """ return int( datetime.combine( - date.today(), datetime.min.time(), datetime.timezone.utc + date.today(), datetime.min.time(), timezone.utc ).timestamp() ) diff --git a/aries_cloudagent/messaging/valid.py b/aries_cloudagent/messaging/valid.py index 08bf05f8a7..0838f84af6 100644 --- a/aries_cloudagent/messaging/valid.py +++ b/aries_cloudagent/messaging/valid.py @@ -275,6 +275,37 @@ def __init__(self): ) +class DIDKeyOrRef(Regexp): + """Validate value against DID key specification.""" + + EXAMPLE = "did:key:z6MkpTHR8VNsBxYAAWHut2Geadd9jSwuBV8xRoAnwWsdvktH" + PATTERN = re.compile(rf"^did:key:z[{B58}]+(?:#z[{B58}]+)?$") + + def __init__(self): + """Initialize the instance.""" + + super().__init__( + DIDKeyOrRef.PATTERN, error="Value {input} is not a did:key or did:key ref" + ) + + +class DIDKeyRef(Regexp): + """Validate value as DID key reference.""" + + EXAMPLE = ( + "did:key:z6MkpTHR8VNsBxYAAWHut2Geadd9jSwuBV8xRoAnwWsdvktH" + "#z6MkpTHR8VNsBxYAAWHut2Geadd9jSwuBV8xRoAnwWsdvktH" + ) + PATTERN = re.compile(rf"^did:key:z[{B58}]+#z[{B58}]+$") + + def __init__(self): + """Initialize the instance.""" + + super().__init__( + DIDKeyRef.PATTERN, error="Value {input} is not a did:key reference" + ) + + class DIDWeb(Regexp): """Validate value against did:web specification.""" @@ -854,6 +885,12 @@ def __init__( DID_KEY_VALIDATE = DIDKey() 
DID_KEY_EXAMPLE = DIDKey.EXAMPLE +DID_KEY_OR_REF_VALIDATE = DIDKeyOrRef() +DID_KEY_OR_REF_EXAMPLE = DIDKeyOrRef.EXAMPLE + +DID_KEY_REF_VALIDATE = DIDKeyRef() +DID_KEY_REF_EXAMPLE = DIDKeyRef.EXAMPLE + DID_POSTURE_VALIDATE = DIDPosture() DID_POSTURE_EXAMPLE = DIDPosture.EXAMPLE diff --git a/aries_cloudagent/multitenant/route_manager.py b/aries_cloudagent/multitenant/route_manager.py index 954b3c98f9..03798f47ce 100644 --- a/aries_cloudagent/multitenant/route_manager.py +++ b/aries_cloudagent/multitenant/route_manager.py @@ -2,7 +2,7 @@ import logging -from typing import List, Optional, Tuple +from typing import List, Optional from ..connections.models.conn_record import ConnRecord from ..core.profile import Profile @@ -11,10 +11,14 @@ from ..protocols.coordinate_mediation.v1_0.models.mediation_record import ( MediationRecord, ) -from ..protocols.coordinate_mediation.v1_0.normalization import normalize_from_did_key +from ..protocols.coordinate_mediation.v1_0.normalization import ( + normalize_from_did_key, + normalize_to_did_key, +) from ..protocols.coordinate_mediation.v1_0.route_manager import ( CoordinateMediationV1RouteManager, RouteManager, + RoutingInfo, ) from ..protocols.routing.v1_0.manager import RoutingManager from ..protocols.routing.v1_0.models.route_record import RouteRecord @@ -98,17 +102,35 @@ async def _route_for_key( return keylist_updates + async def mediation_records_for_connection( + self, + profile: Profile, + conn_record: ConnRecord, + mediation_id: Optional[str] = None, + or_default: bool = False, + ) -> List[MediationRecord]: + """Determine mediation records for a connection.""" + conn_specific = await super().mediation_records_for_connection( + profile, conn_record, mediation_id, or_default + ) + base_mediation_record = await self.get_base_wallet_mediator() + return [ + record + for record in (base_mediation_record, *conn_specific) + if record is not None + ] + async def routing_info( self, profile: Profile, - my_endpoint: str, 
mediation_record: Optional[MediationRecord] = None, - ) -> Tuple[List[str], str]: + ) -> RoutingInfo: """Return routing info.""" routing_keys = [] base_mediation_record = await self.get_base_wallet_mediator() + my_endpoint = None if base_mediation_record: routing_keys = base_mediation_record.routing_keys my_endpoint = base_mediation_record.endpoint @@ -117,7 +139,9 @@ async def routing_info( routing_keys = [*routing_keys, *mediation_record.routing_keys] my_endpoint = mediation_record.endpoint - return routing_keys, my_endpoint + routing_keys = [normalize_to_did_key(key).key_id for key in routing_keys] + + return RoutingInfo(routing_keys or None, my_endpoint) class BaseWalletRouteManager(CoordinateMediationV1RouteManager): diff --git a/aries_cloudagent/multitenant/tests/test_route_manager.py b/aries_cloudagent/multitenant/tests/test_route_manager.py index e4a537b7d1..2aebc0b2e9 100644 --- a/aries_cloudagent/multitenant/tests/test_route_manager.py +++ b/aries_cloudagent/multitenant/tests/test_route_manager.py @@ -18,6 +18,8 @@ TEST_VERKEY = "did:key:z6Mkgg342Ycpuk263R9d8Aq6MUaxPn1DDeHyGo38EefXmgDL" TEST_ROUTE_RECORD_VERKEY = "9WCgWKUaAJj3VWxxtzvvMQN3AoFxoBtBDo9ntwJnVVCC" TEST_ROUTE_VERKEY = "did:key:z6MknxTj6Zj1VrDWc1ofaZtmCVv2zNXpD58Xup4ijDGoQhya" +TEST_ROUTE_VERKEY_REF = "did:key:z6MknxTj6Zj1VrDWc1ofaZtmCVv2zNXpD58Xup4ijDGoQhya#z6MknxTj6Zj1VrDWc1ofaZtmCVv2zNXpD58Xup4ijDGoQhya" +TEST_ROUTE_VERKEY_REF2 = "did:key:z6MknxTj6Zj1VrDWc1ofaZtmCVv2zNXpD58Xup4ijDGoQhyz#z6MknxTj6Zj1VrDWc1ofaZtmCVv2zNXpD58Xup4ijDGoQhyz" @pytest.fixture @@ -292,12 +294,10 @@ async def test_routing_info_with_mediator( mediation_record = MediationRecord( mediation_id="test-mediation-id", connection_id="test-mediator-conn-id", - routing_keys=["test-key-0", "test-key-1"], + routing_keys=[TEST_ROUTE_VERKEY_REF], endpoint="http://mediator.example.com", ) - keys, endpoint = await route_manager.routing_info( - sub_profile, "http://example.com", mediation_record - ) + keys, endpoint = await 
route_manager.routing_info(sub_profile, mediation_record) assert keys == mediation_record.routing_keys assert endpoint == mediation_record.endpoint @@ -307,11 +307,9 @@ async def test_routing_info_no_mediator( sub_profile: Profile, route_manager: MultitenantRouteManager, ): - keys, endpoint = await route_manager.routing_info( - sub_profile, "http://example.com", None - ) - assert keys == [] - assert endpoint == "http://example.com" + keys, endpoint = await route_manager.routing_info(sub_profile, None) + assert keys is None + assert endpoint is None @pytest.mark.asyncio @@ -322,7 +320,7 @@ async def test_routing_info_with_base_mediator( base_mediation_record = MediationRecord( mediation_id="test-base-mediation-id", connection_id="test-base-mediator-conn-id", - routing_keys=["test-key-0", "test-key-1"], + routing_keys=[TEST_ROUTE_VERKEY_REF], endpoint="http://base.mediator.example.com", ) @@ -331,9 +329,7 @@ async def test_routing_info_with_base_mediator( "get_base_wallet_mediator", mock.CoroutineMock(return_value=base_mediation_record), ): - keys, endpoint = await route_manager.routing_info( - sub_profile, "http://example.com", None - ) + keys, endpoint = await route_manager.routing_info(sub_profile, None) assert keys == base_mediation_record.routing_keys assert endpoint == base_mediation_record.endpoint @@ -346,13 +342,13 @@ async def test_routing_info_with_base_mediator_and_sub_mediator( mediation_record = MediationRecord( mediation_id="test-mediation-id", connection_id="test-mediator-conn-id", - routing_keys=["test-key-0", "test-key-1"], + routing_keys=[TEST_ROUTE_VERKEY_REF2], endpoint="http://mediator.example.com", ) base_mediation_record = MediationRecord( mediation_id="test-base-mediation-id", connection_id="test-base-mediator-conn-id", - routing_keys=["test-base-key-0", "test-base-key-1"], + routing_keys=[TEST_ROUTE_VERKEY_REF], endpoint="http://base.mediator.example.com", ) @@ -361,9 +357,7 @@ async def test_routing_info_with_base_mediator_and_sub_mediator( 
"get_base_wallet_mediator", mock.CoroutineMock(return_value=base_mediation_record), ): - keys, endpoint = await route_manager.routing_info( - sub_profile, "http://example.com", mediation_record - ) + keys, endpoint = await route_manager.routing_info(sub_profile, mediation_record) assert keys == [*base_mediation_record.routing_keys, *mediation_record.routing_keys] assert endpoint == mediation_record.endpoint diff --git a/aries_cloudagent/protocols/connections/v1_0/manager.py b/aries_cloudagent/protocols/connections/v1_0/manager.py index fe33e1dac3..7a77cd6739 100644 --- a/aries_cloudagent/protocols/connections/v1_0/manager.py +++ b/aries_cloudagent/protocols/connections/v1_0/manager.py @@ -11,7 +11,6 @@ from ....core.profile import Profile from ....messaging.responder import BaseResponder from ....messaging.valid import IndyDID -from ....multitenant.base import BaseMultitenantManager from ....storage.error import StorageNotFoundError from ....transport.inbound.receipt import MessageReceipt from ....wallet.base import BaseWallet @@ -55,16 +54,16 @@ def profile(self) -> Profile: async def create_invitation( self, - my_label: str = None, - my_endpoint: str = None, - auto_accept: bool = None, + my_label: Optional[str] = None, + my_endpoint: Optional[str] = None, + auto_accept: Optional[bool] = None, public: bool = False, multi_use: bool = False, - alias: str = None, - routing_keys: Sequence[str] = None, - recipient_keys: Sequence[str] = None, - metadata: dict = None, - mediation_id: str = None, + alias: Optional[str] = None, + routing_keys: Optional[Sequence[str]] = None, + recipient_keys: Optional[Sequence[str]] = None, + metadata: Optional[dict] = None, + mediation_id: Optional[str] = None, ) -> Tuple[ConnRecord, ConnectionInvitation]: """Generate new connection invitation. 
@@ -208,11 +207,15 @@ async def create_invitation( await self._route_manager.route_invitation( self.profile, connection, mediation_record ) - routing_keys, my_endpoint = await self._route_manager.routing_info( + routing_keys, routing_endpoint = await self._route_manager.routing_info( self.profile, - my_endpoint or cast(str, self.profile.settings.get("default_endpoint")), mediation_record, ) + my_endpoint = ( + routing_endpoint + or my_endpoint + or cast(str, self.profile.settings.get("default_endpoint")) + ) # Create connection invitation message # Note: Need to split this into two stages @@ -336,20 +339,13 @@ async def create_request( """ - mediation_record = await self._route_manager.mediation_record_for_connection( + mediation_records = await self._route_manager.mediation_records_for_connection( self.profile, connection, mediation_id, or_default=True, ) - multitenant_mgr = self.profile.inject_or(BaseMultitenantManager) - wallet_id = self.profile.settings.get("wallet.id") - - base_mediation_record = None - if multitenant_mgr and wallet_id: - base_mediation_record = await multitenant_mgr.get_default_mediator() - if connection.my_did: async with self.profile.session() as session: wallet = session.inject(BaseWallet) @@ -363,7 +359,7 @@ async def create_request( # Idempotent; if routing has already been set up, no action taken await self._route_manager.route_connection_as_invitee( - self.profile, connection, mediation_record + self.profile, connection, mediation_records ) # Create connection request message @@ -378,11 +374,8 @@ async def create_request( did_doc = await self.create_did_document( my_info, - connection.inbound_connection_id, my_endpoints, - mediation_records=list( - filter(None, [base_mediation_record, mediation_record]) - ), + mediation_records=mediation_records, ) if not my_label: @@ -587,18 +580,10 @@ async def create_response( settings=self.profile.settings, ) - mediation_record = await self._route_manager.mediation_record_for_connection( + 
mediation_records = await self._route_manager.mediation_records_for_connection( self.profile, connection, mediation_id ) - # Multitenancy setup - multitenant_mgr = self.profile.inject_or(BaseMultitenantManager) - wallet_id = self.profile.settings.get("wallet.id") - - base_mediation_record = None - if multitenant_mgr and wallet_id: - base_mediation_record = await multitenant_mgr.get_default_mediator() - if ConnRecord.State.get(connection.state) not in ( ConnRecord.State.REQUEST, ConnRecord.State.RESPONSE, @@ -622,7 +607,7 @@ async def create_response( # Idempotent; if routing has already been set up, no action taken await self._route_manager.route_connection_as_inviter( - self.profile, connection, mediation_record + self.profile, connection, mediation_records ) # Create connection response message @@ -637,11 +622,8 @@ async def create_response( did_doc = await self.create_did_document( my_info, - connection.inbound_connection_id, my_endpoints, - mediation_records=list( - filter(None, [base_mediation_record, mediation_record]) - ), + mediation_records=mediation_records, ) response = ConnectionResponse( diff --git a/aries_cloudagent/protocols/connections/v1_0/messages/connection_invitation.py b/aries_cloudagent/protocols/connections/v1_0/messages/connection_invitation.py index 79a402b356..03592102f1 100644 --- a/aries_cloudagent/protocols/connections/v1_0/messages/connection_invitation.py +++ b/aries_cloudagent/protocols/connections/v1_0/messages/connection_invitation.py @@ -3,8 +3,9 @@ from typing import Sequence from urllib.parse import parse_qs, urljoin, urlparse -from marshmallow import EXCLUDE, ValidationError, fields, validates_schema +from marshmallow import EXCLUDE, ValidationError, fields, pre_load, validates_schema +from .....did.did_key import DIDKey from .....messaging.agent_message import AgentMessage, AgentMessageSchema from .....messaging.valid import ( GENERIC_DID_EXAMPLE, @@ -58,6 +59,16 @@ def __init__( self.recipient_keys = list(recipient_keys) if 
recipient_keys else None self.endpoint = endpoint self.routing_keys = list(routing_keys) if routing_keys else None + self.routing_keys = ( + [ + DIDKey.from_did(key).public_key_b58 + if key.startswith("did:key:") + else key + for key in self.routing_keys + ] + if self.routing_keys + else None + ) self.image_url = image_url def to_url(self, base_url: str = None) -> str: @@ -157,6 +168,19 @@ class Meta: }, ) + @pre_load + def transform_routing_keys(self, data, **kwargs): + """Transform routingKeys from did:key refs, if necessary.""" + routing_keys = data.get("routingKeys") + if routing_keys: + data["routingKeys"] = [ + DIDKey.from_did(key).public_key_b58 + if key.startswith("did:key:") + else key + for key in routing_keys + ] + return data + @validates_schema def validate_fields(self, data, **kwargs): """Validate schema fields. diff --git a/aries_cloudagent/protocols/connections/v1_0/models/connection_detail.py b/aries_cloudagent/protocols/connections/v1_0/models/connection_detail.py index dfe460c464..fd842c9d13 100644 --- a/aries_cloudagent/protocols/connections/v1_0/models/connection_detail.py +++ b/aries_cloudagent/protocols/connections/v1_0/models/connection_detail.py @@ -10,7 +10,7 @@ class DIDDocWrapper(fields.Field): """Field that loads and serializes DIDDoc.""" - def _serialize(self, value, attr, obj, **kwargs): + def _serialize(self, value: DIDDoc, attr, obj, **kwargs): """Serialize the DIDDoc. Args: @@ -20,7 +20,7 @@ def _serialize(self, value, attr, obj, **kwargs): The serialized DIDDoc """ - return value.serialize() + return value.serialize(normalize_routing_keys=True) def _deserialize(self, value, attr=None, data=None, **kwargs): """Deserialize a value into a DIDDoc. 
diff --git a/aries_cloudagent/protocols/connections/v1_0/tests/test_manager.py b/aries_cloudagent/protocols/connections/v1_0/tests/test_manager.py index b881d2fd34..ee8bfd1a66 100644 --- a/aries_cloudagent/protocols/connections/v1_0/tests/test_manager.py +++ b/aries_cloudagent/protocols/connections/v1_0/tests/test_manager.py @@ -315,7 +315,7 @@ async def test_create_invitation_mediation_using_default(self): assert invite.routing_keys == self.test_mediator_routing_keys assert invite.endpoint == self.test_mediator_endpoint self.route_manager.routing_info.assert_awaited_once_with( - self.profile, self.test_endpoint, mediation_record + self.profile, mediation_record ) async def test_receive_invitation(self): @@ -426,15 +426,11 @@ async def test_create_request_multitenant(self): with async_mock.patch.object( InMemoryWallet, "create_local_did", autospec=True ) as mock_wallet_create_local_did, async_mock.patch.object( - self.multitenant_mgr, - "get_default_mediator", - async_mock.CoroutineMock(return_value=mediation_record), - ), async_mock.patch.object( ConnectionManager, "create_did_document", autospec=True ) as create_did_document, async_mock.patch.object( self.route_manager, - "mediation_record_for_connection", - async_mock.CoroutineMock(return_value=None), + "mediation_records_for_connection", + async_mock.CoroutineMock(return_value=[mediation_record]), ): mock_wallet_create_local_did.return_value = DIDInfo( self.test_did, @@ -455,7 +451,6 @@ async def test_create_request_multitenant(self): create_did_document.assert_called_once_with( self.manager, mock_wallet_create_local_did.return_value, - None, [self.test_endpoint], mediation_records=[mediation_record], ) @@ -487,8 +482,8 @@ async def test_create_request_mediation_id(self): InMemoryWallet, "create_local_did" ) as create_local_did, async_mock.patch.object( self.route_manager, - "mediation_record_for_connection", - async_mock.CoroutineMock(return_value=mediation_record), + "mediation_records_for_connection", + 
async_mock.CoroutineMock(return_value=[mediation_record]), ): did_info = DIDInfo( did=self.test_did, @@ -507,7 +502,6 @@ async def test_create_request_mediation_id(self): create_did_document.assert_called_once_with( self.manager, did_info, - None, [self.test_endpoint], mediation_records=[mediation_record], ) @@ -539,8 +533,8 @@ async def test_create_request_default_mediator(self): InMemoryWallet, "create_local_did" ) as create_local_did, async_mock.patch.object( self.route_manager, - "mediation_record_for_connection", - async_mock.CoroutineMock(return_value=mediation_record), + "mediation_records_for_connection", + async_mock.CoroutineMock(return_value=[mediation_record]), ): did_info = DIDInfo( did=self.test_did, @@ -558,7 +552,6 @@ async def test_create_request_default_mediator(self): create_did_document.assert_called_once_with( self.manager, did_info, - None, [self.test_endpoint], mediation_records=[mediation_record], ) @@ -881,10 +874,6 @@ async def test_create_response_multitenant(self): ConnRecord, "save", autospec=True ), async_mock.patch.object( ConnRecord, "metadata_get", async_mock.CoroutineMock(return_value=False) - ), async_mock.patch.object( - self.route_manager, - "mediation_record_for_connection", - async_mock.CoroutineMock(return_value=mediation_record), ), async_mock.patch.object( ConnRecord, "retrieve_request", autospec=True ), async_mock.patch.object( @@ -892,15 +881,11 @@ async def test_create_response_multitenant(self): ), async_mock.patch.object( InMemoryWallet, "create_local_did", autospec=True ) as mock_wallet_create_local_did, async_mock.patch.object( - self.multitenant_mgr, - "get_default_mediator", - async_mock.CoroutineMock(return_value=mediation_record), - ), async_mock.patch.object( ConnectionManager, "create_did_document", autospec=True ) as create_did_document, async_mock.patch.object( self.route_manager, - "mediation_record_for_connection", - async_mock.CoroutineMock(return_value=None), + "mediation_records_for_connection", + 
async_mock.CoroutineMock(return_value=[mediation_record]), ): mock_wallet_create_local_did.return_value = DIDInfo( self.test_did, @@ -918,7 +903,6 @@ async def test_create_response_multitenant(self): create_did_document.assert_called_once_with( self.manager, mock_wallet_create_local_did.return_value, - None, [self.test_endpoint], mediation_records=[mediation_record], ) @@ -970,8 +954,8 @@ async def test_create_response_mediation(self): InMemoryWallet, "create_local_did" ) as create_local_did, async_mock.patch.object( self.route_manager, - "mediation_record_for_connection", - async_mock.CoroutineMock(return_value=mediation_record), + "mediation_records_for_connection", + async_mock.CoroutineMock(return_value=[mediation_record]), ), async_mock.patch.object( record, "retrieve_request", autospec=True ), async_mock.patch.object( @@ -994,7 +978,6 @@ async def test_create_response_mediation(self): create_did_document.assert_called_once_with( self.manager, did_info, - None, [self.test_endpoint], mediation_records=[mediation_record], ) diff --git a/aries_cloudagent/protocols/coordinate_mediation/v1_0/handlers/tests/test_mediation_grant_handler.py b/aries_cloudagent/protocols/coordinate_mediation/v1_0/handlers/tests/test_mediation_grant_handler.py index e8924dbb83..d3a39dbebb 100644 --- a/aries_cloudagent/protocols/coordinate_mediation/v1_0/handlers/tests/test_mediation_grant_handler.py +++ b/aries_cloudagent/protocols/coordinate_mediation/v1_0/handlers/tests/test_mediation_grant_handler.py @@ -1,9 +1,10 @@ """Test mediate grant message handler.""" import pytest -from asynctest import TestCase as AsyncTestCase from asynctest import mock as async_mock +from aries_cloudagent.core.profile import ProfileSession + from ......connections.models.conn_record import ConnRecord from ......messaging.base_handler import HandlerException from ......messaging.request_context import RequestContext @@ -18,69 +19,86 @@ from .. 
import mediation_grant_handler as test_module TEST_CONN_ID = "conn-id" -TEST_RECORD_VERKEY = "3Dn1SJNPaCXcvvJvSbsFWP2xaCjMom3can8CQNhWrTRx" -TEST_VERKEY = "did:key:z6Mkgg342Ycpuk263R9d8Aq6MUaxPn1DDeHyGo38EefXmgDL" +TEST_BASE58_VERKEY = "3Dn1SJNPaCXcvvJvSbsFWP2xaCjMom3can8CQNhWrTRx" +TEST_VERKEY = "did:key:z6Mkgg342Ycpuk263R9d8Aq6MUaxPn1DDeHyGo38EefXmgDL#z6Mkgg342Ycpuk263R9d8Aq6MUaxPn1DDeHyGo38EefXmgDL" TEST_ENDPOINT = "https://example.com" -class TestMediationGrantHandler(AsyncTestCase): - """Test mediate grant message handler.""" +@pytest.fixture() +async def context(): + context = RequestContext.test_context() + context.message = MediationGrant(endpoint=TEST_ENDPOINT, routing_keys=[TEST_VERKEY]) + context.connection_ready = True + context.connection_record = ConnRecord(connection_id=TEST_CONN_ID) + yield context - async def setUp(self): - """Setup test dependencies.""" - self.context = RequestContext.test_context() - self.session = await self.context.session() - self.context.message = MediationGrant( - endpoint=TEST_ENDPOINT, routing_keys=[TEST_VERKEY] - ) - self.context.connection_ready = True - self.context.connection_record = ConnRecord(connection_id=TEST_CONN_ID) - async def test_handler_no_active_connection(self): +@pytest.fixture() +async def session(context: RequestContext): + yield await context.session() + + +@pytest.mark.asyncio +class TestMediationGrantHandler: + """Test mediate grant message handler.""" + + async def test_handler_no_active_connection(self, context: RequestContext): handler, responder = MediationGrantHandler(), MockResponder() - self.context.connection_ready = False + context.connection_ready = False with pytest.raises(HandlerException) as exc: - await handler.handle(self.context, responder) + await handler.handle(context, responder) assert "no active connection" in str(exc.value) - async def test_handler_no_mediation_record(self): + async def test_handler_no_mediation_record(self, context: RequestContext): handler, responder = 
MediationGrantHandler(), MockResponder() with pytest.raises(HandlerException) as exc: - await handler.handle(self.context, responder) + await handler.handle(context, responder) assert "has not been requested" in str(exc.value) - async def test_handler(self): + @pytest.mark.parametrize( + "grant", + [ + MediationGrant(endpoint=TEST_ENDPOINT, routing_keys=[TEST_VERKEY]), + MediationGrant(endpoint=TEST_ENDPOINT, routing_keys=[TEST_BASE58_VERKEY]), + ], + ) + async def test_handler( + self, grant: MediationGrant, session: ProfileSession, context: RequestContext + ): handler, responder = MediationGrantHandler(), MockResponder() - await MediationRecord(connection_id=TEST_CONN_ID).save(self.session) - await handler.handle(self.context, responder) - record = await MediationRecord.retrieve_by_connection_id( - self.session, TEST_CONN_ID - ) + await MediationRecord(connection_id=TEST_CONN_ID).save(session) + context.message = grant + await handler.handle(context, responder) + record = await MediationRecord.retrieve_by_connection_id(session, TEST_CONN_ID) assert record assert record.state == MediationRecord.STATE_GRANTED assert record.endpoint == TEST_ENDPOINT - assert record.routing_keys == [TEST_RECORD_VERKEY] + assert record.routing_keys == [TEST_VERKEY] - async def test_handler_connection_has_set_to_default_meta(self): + async def test_handler_connection_has_set_to_default_meta( + self, session: ProfileSession, context: RequestContext + ): handler, responder = MediationGrantHandler(), MockResponder() record = MediationRecord(connection_id=TEST_CONN_ID) - await record.save(self.session) + await record.save(session) with async_mock.patch.object( - self.context.connection_record, + context.connection_record, "metadata_get", async_mock.CoroutineMock(return_value=True), ), async_mock.patch.object( test_module, "MediationManager", autospec=True ) as mock_mediation_manager: - await handler.handle(self.context, responder) + await handler.handle(context, responder) 
mock_mediation_manager.return_value.set_default_mediator.assert_called_once_with( record ) - async def test_handler_multitenant_base_mediation(self): + async def test_handler_multitenant_base_mediation( + self, session: ProfileSession, context: RequestContext + ): handler, responder = MediationGrantHandler(), async_mock.CoroutineMock() responder.send = async_mock.CoroutineMock() - profile = self.context.profile + profile = context.profile profile.context.update_settings( {"multitenant.enabled": True, "wallet.id": "test_wallet"} @@ -94,28 +112,30 @@ async def test_handler_multitenant_base_mediation(self): multitenant_mgr.get_default_mediator.return_value = default_base_mediator record = MediationRecord(connection_id=TEST_CONN_ID) - await record.save(self.session) + await record.save(session) with async_mock.patch.object(MediationManager, "add_key") as add_key: keylist_updates = async_mock.MagicMock() add_key.return_value = keylist_updates - await handler.handle(self.context, responder) + await handler.handle(context, responder) add_key.assert_called_once_with("key2") responder.send.assert_called_once_with( keylist_updates, connection_id=TEST_CONN_ID ) - async def test_handler_connection_no_set_to_default(self): + async def test_handler_connection_no_set_to_default( + self, session: ProfileSession, context: RequestContext + ): handler, responder = MediationGrantHandler(), MockResponder() record = MediationRecord(connection_id=TEST_CONN_ID) - await record.save(self.session) + await record.save(session) with async_mock.patch.object( - self.context.connection_record, + context.connection_record, "metadata_get", async_mock.CoroutineMock(return_value=False), ), async_mock.patch.object( test_module, "MediationManager", autospec=True ) as mock_mediation_manager: - await handler.handle(self.context, responder) + await handler.handle(context, responder) mock_mediation_manager.return_value.set_default_mediator.assert_not_called() diff --git 
a/aries_cloudagent/protocols/coordinate_mediation/v1_0/manager.py b/aries_cloudagent/protocols/coordinate_mediation/v1_0/manager.py index 0ab45b1434..a97055da63 100644 --- a/aries_cloudagent/protocols/coordinate_mediation/v1_0/manager.py +++ b/aries_cloudagent/protocols/coordinate_mediation/v1_0/manager.py @@ -26,7 +26,10 @@ from .messages.mediate_grant import MediationGrant from .messages.mediate_request import MediationRequest from .models.mediation_record import MediationRecord -from .normalization import normalize_from_did_key +from .normalization import ( + normalize_from_did_key, + normalize_to_did_key, +) LOGGER = logging.getLogger(__name__) @@ -176,7 +179,7 @@ async def grant_request( await mediation_record.save(session, reason="Mediation request granted") grant = MediationGrant( endpoint=session.settings.get("default_endpoint"), - routing_keys=[routing_did.verkey], + routing_keys=[normalize_to_did_key(routing_did.verkey).key_id], ) return mediation_record, grant @@ -458,11 +461,9 @@ async def request_granted(self, record: MediationRecord, grant: MediationGrant): """ record.state = MediationRecord.STATE_GRANTED record.endpoint = grant.endpoint - # record.routing_keys = grant.routing_keys - routing_keys = [] - for key in grant.routing_keys: - routing_keys.append(normalize_from_did_key(key)) - record.routing_keys = routing_keys + record.routing_keys = [ + normalize_to_did_key(key).key_id for key in grant.routing_keys + ] async with self._profile.session() as session: await record.save(session, reason="Mediation request granted.") diff --git a/aries_cloudagent/protocols/coordinate_mediation/v1_0/messages/mediate_grant.py b/aries_cloudagent/protocols/coordinate_mediation/v1_0/messages/mediate_grant.py index d2595ede53..551f795eac 100644 --- a/aries_cloudagent/protocols/coordinate_mediation/v1_0/messages/mediate_grant.py +++ b/aries_cloudagent/protocols/coordinate_mediation/v1_0/messages/mediate_grant.py @@ -9,7 +9,6 @@ from .....messaging.agent_message import 
AgentMessage, AgentMessageSchema from ..message_types import MEDIATE_GRANT, PROTOCOL_PACKAGE -from ..normalization import normalize_from_public_key HANDLER_CLASS = ( f"{PROTOCOL_PACKAGE}.handlers.mediation_grant_handler.MediationGrantHandler" @@ -41,11 +40,7 @@ def __init__( """ super(MediationGrant, self).__init__(**kwargs) self.endpoint = endpoint - self.routing_keys = ( - [normalize_from_public_key(key) for key in routing_keys] - if routing_keys - else [] - ) + self.routing_keys = routing_keys or [] class MediationGrantSchema(AgentMessageSchema): diff --git a/aries_cloudagent/protocols/coordinate_mediation/v1_0/normalization.py b/aries_cloudagent/protocols/coordinate_mediation/v1_0/normalization.py index d699565367..28fee2ce89 100644 --- a/aries_cloudagent/protocols/coordinate_mediation/v1_0/normalization.py +++ b/aries_cloudagent/protocols/coordinate_mediation/v1_0/normalization.py @@ -1,4 +1,5 @@ """Normalization methods used while transitioning to DID:Key method.""" +from typing import Union from ....did.did_key import DIDKey from ....wallet.key_type import ED25519 @@ -17,3 +18,12 @@ def normalize_from_public_key(key: str): return key return DIDKey.from_public_key_b58(key, ED25519).did + + +def normalize_to_did_key(value: Union[str, DIDKey]) -> DIDKey: + """Normalize a value to a DIDKey.""" + if isinstance(value, DIDKey): + return value + if value.startswith("did:key:"): + return DIDKey.from_did(value) + return DIDKey.from_public_key_b58(value, ED25519) diff --git a/aries_cloudagent/protocols/coordinate_mediation/v1_0/route_manager.py b/aries_cloudagent/protocols/coordinate_mediation/v1_0/route_manager.py index 990252cd5a..c83e7984ec 100644 --- a/aries_cloudagent/protocols/coordinate_mediation/v1_0/route_manager.py +++ b/aries_cloudagent/protocols/coordinate_mediation/v1_0/route_manager.py @@ -6,7 +6,7 @@ from abc import ABC, abstractmethod import logging -from typing import List, Optional, Tuple +from typing import List, NamedTuple, Optional from 
....connections.models.conn_record import ConnRecord from ....core.profile import Profile @@ -20,7 +20,10 @@ from .manager import MediationManager from .messages.keylist_update import KeylistUpdate from .models.mediation_record import MediationRecord -from .normalization import normalize_from_did_key +from .normalization import ( + normalize_from_did_key, + normalize_to_did_key, +) LOGGER = logging.getLogger(__name__) @@ -30,6 +33,18 @@ class RouteManagerError(Exception): """Raised on error from route manager.""" +class RoutingInfo(NamedTuple): + """Routing info tuple containing routing keys and endpoint.""" + + routing_keys: Optional[List[str]] + endpoint: Optional[str] + + @classmethod + def empty(cls): + """Empty routing info.""" + return cls(routing_keys=None, endpoint=None) + + class RouteManager(ABC): """Base Route Manager.""" @@ -59,14 +74,15 @@ def _validate_mediation_state(self, mediation_record: MediationRecord): f"{mediation_record.mediation_id}" ) - async def mediation_record_for_connection( + async def mediation_records_for_connection( self, profile: Profile, conn_record: ConnRecord, mediation_id: Optional[str] = None, or_default: bool = False, - ): + ) -> List[MediationRecord]: """Return relevant mediator for connection.""" + # TODO Support multiple mediators? 
if conn_record.connection_id: async with profile.session() as session: mediation_metadata = await conn_record.metadata_get( @@ -83,7 +99,7 @@ async def mediation_record_for_connection( await self.save_mediator_for_connection( profile, conn_record, mediation_record ) - return mediation_record + return [mediation_record] if mediation_record else [] async def mediation_record_if_id( self, @@ -126,11 +142,13 @@ async def route_connection_as_invitee( self, profile: Profile, conn_record: ConnRecord, - mediation_record: Optional[MediationRecord] = None, + mediation_records: List[MediationRecord], ) -> Optional[KeylistUpdate]: """Set up routing for a new connection when we are the invitee.""" LOGGER.debug("Routing connection as invitee") my_info = await self.get_or_create_my_did(profile, conn_record) + # Only most destward mediator receives keylist updates + mediation_record = mediation_records[0] if mediation_records else None return await self._route_for_key( profile, my_info.verkey, mediation_record, skip_if_exists=True ) @@ -139,7 +157,7 @@ async def route_connection_as_inviter( self, profile: Profile, conn_record: ConnRecord, - mediation_record: Optional[MediationRecord] = None, + mediation_records: List[MediationRecord], ) -> Optional[KeylistUpdate]: """Set up routing for a new connection when we are the inviter.""" LOGGER.debug("Routing connection as inviter") @@ -154,6 +172,9 @@ async def route_connection_as_inviter( if public_did and public_did.verkey == conn_record.invitation_key: replace_key = None + # Only most destward mediator receives keylist updates + mediation_record = mediation_records[0] if mediation_records else None + return await self._route_for_key( profile, my_info.verkey, @@ -166,7 +187,7 @@ async def route_connection( self, profile: Profile, conn_record: ConnRecord, - mediation_record: Optional[MediationRecord] = None, + mediation_records: List[MediationRecord], ) -> Optional[KeylistUpdate]: """Set up routing for a connection. 
@@ -176,14 +197,14 @@ async def route_connection( ConnRecord.Role.RESPONDER ): return await self.route_connection_as_invitee( - profile, conn_record, mediation_record + profile, conn_record, mediation_records ) if conn_record.rfc23_state == ConnRecord.State.REQUEST.rfc23strict( ConnRecord.Role.REQUESTER ): return await self.route_connection_as_inviter( - profile, conn_record, mediation_record + profile, conn_record, mediation_records ) return None @@ -255,9 +276,8 @@ async def save_mediator_for_connection( async def routing_info( self, profile: Profile, - my_endpoint: str, mediation_record: Optional[MediationRecord] = None, - ) -> Tuple[List[str], str]: + ) -> RoutingInfo: """Retrieve routing keys.""" async def connection_from_recipient_key( @@ -321,11 +341,16 @@ async def _route_for_key( async def routing_info( self, profile: Profile, - my_endpoint: str, mediation_record: Optional[MediationRecord] = None, - ) -> Tuple[List[str], str]: + ) -> RoutingInfo: """Return routing info for mediator.""" if mediation_record: - return mediation_record.routing_keys, mediation_record.endpoint + return RoutingInfo( + routing_keys=[ + normalize_to_did_key(key).key_id + for key in mediation_record.routing_keys + ], + endpoint=mediation_record.endpoint, + ) - return [], my_endpoint + return RoutingInfo.empty() diff --git a/aries_cloudagent/protocols/coordinate_mediation/v1_0/routes.py b/aries_cloudagent/protocols/coordinate_mediation/v1_0/routes.py index 52b87058fd..11231c36d7 100644 --- a/aries_cloudagent/protocols/coordinate_mediation/v1_0/routes.py +++ b/aries_cloudagent/protocols/coordinate_mediation/v1_0/routes.py @@ -498,14 +498,14 @@ async def update_keylist_for_connection(request: web.BaseRequest): async with context.session() as session: connection_record = await ConnRecord.retrieve_by_id(session, connection_id) - mediation_record = await route_manager.mediation_record_for_connection( + mediation_records = await route_manager.mediation_records_for_connection( 
context.profile, connection_record, mediation_id, or_default=True ) # MediationRecord is permitted to be None; route manager will # ensure the correct mediator is notified. keylist_update = await route_manager.route_connection( - context.profile, connection_record, mediation_record + context.profile, connection_record, mediation_records ) results = keylist_update.serialize() if keylist_update else {} diff --git a/aries_cloudagent/protocols/coordinate_mediation/v1_0/tests/test_mediation_manager.py b/aries_cloudagent/protocols/coordinate_mediation/v1_0/tests/test_mediation_manager.py index 2ebca54939..95349e2ef8 100644 --- a/aries_cloudagent/protocols/coordinate_mediation/v1_0/tests/test_mediation_manager.py +++ b/aries_cloudagent/protocols/coordinate_mediation/v1_0/tests/test_mediation_manager.py @@ -28,10 +28,10 @@ TEST_CONN_ID = "conn-id" TEST_THREAD_ID = "thread-id" TEST_ENDPOINT = "https://example.com" -TEST_RECORD_VERKEY = "3Dn1SJNPaCXcvvJvSbsFWP2xaCjMom3can8CQNhWrTRx" -TEST_VERKEY = "did:key:z6Mkgg342Ycpuk263R9d8Aq6MUaxPn1DDeHyGo38EefXmgDL" +TEST_BASE58_VERKEY = "3Dn1SJNPaCXcvvJvSbsFWP2xaCjMom3can8CQNhWrTRx" +TEST_VERKEY = "did:key:z6Mkgg342Ycpuk263R9d8Aq6MUaxPn1DDeHyGo38EefXmgDL#z6Mkgg342Ycpuk263R9d8Aq6MUaxPn1DDeHyGo38EefXmgDL" TEST_ROUTE_RECORD_VERKEY = "9WCgWKUaAJj3VWxxtzvvMQN3AoFxoBtBDo9ntwJnVVCC" -TEST_ROUTE_VERKEY = "did:key:z6MknxTj6Zj1VrDWc1ofaZtmCVv2zNXpD58Xup4ijDGoQhya" +TEST_ROUTE_VERKEY = "did:key:z6MknxTj6Zj1VrDWc1ofaZtmCVv2zNXpD58Xup4ijDGoQhya#z6MknxTj6Zj1VrDWc1ofaZtmCVv2zNXpD58Xup4ijDGoQhya" pytestmark = pytest.mark.asyncio @@ -120,7 +120,7 @@ async def test_grant_request(self, session, manager): routing_key = await manager._retrieve_routing_did(session) routing_key = DIDKey.from_public_key_b58( routing_key.verkey, routing_key.key_type - ).did + ).key_id assert grant.routing_keys == [routing_key] async def test_deny_request(self, manager): @@ -133,7 +133,7 @@ async def test_deny_request(self, manager): async def test_update_keylist_delete(self, 
session, manager, record): """test_update_keylist_delete.""" await RouteRecord( - connection_id=TEST_CONN_ID, recipient_key=TEST_RECORD_VERKEY + connection_id=TEST_CONN_ID, recipient_key=TEST_BASE58_VERKEY ).save(session) response = await manager.update_keylist( record=record, @@ -168,7 +168,7 @@ async def test_update_keylist_create(self, manager, record): async def test_update_keylist_create_existing(self, session, manager, record): """test_update_keylist_create_existing.""" await RouteRecord( - connection_id=TEST_CONN_ID, recipient_key=TEST_RECORD_VERKEY + connection_id=TEST_CONN_ID, recipient_key=TEST_BASE58_VERKEY ).save(session) response = await manager.update_keylist( record=record, @@ -272,14 +272,25 @@ async def test_prepare_request(self, manager): assert record.connection_id == TEST_CONN_ID assert request - async def test_request_granted(self, manager): + async def test_request_granted_base58(self, manager): """test_request_granted.""" record, _ = await manager.prepare_request(TEST_CONN_ID) - grant = MediationGrant(endpoint=TEST_ENDPOINT, routing_keys=[TEST_ROUTE_VERKEY]) + grant = MediationGrant( + endpoint=TEST_ENDPOINT, routing_keys=[TEST_BASE58_VERKEY] + ) + await manager.request_granted(record, grant) + assert record.state == MediationRecord.STATE_GRANTED + assert record.endpoint == TEST_ENDPOINT + assert record.routing_keys == [TEST_VERKEY] + + async def test_request_granted_did_key(self, manager): + """test_request_granted.""" + record, _ = await manager.prepare_request(TEST_CONN_ID) + grant = MediationGrant(endpoint=TEST_ENDPOINT, routing_keys=[TEST_VERKEY]) await manager.request_granted(record, grant) assert record.state == MediationRecord.STATE_GRANTED assert record.endpoint == TEST_ENDPOINT - assert record.routing_keys == [TEST_ROUTE_RECORD_VERKEY] + assert record.routing_keys == [TEST_VERKEY] async def test_request_denied(self, manager): """test_request_denied.""" diff --git 
a/aries_cloudagent/protocols/coordinate_mediation/v1_0/tests/test_route_manager.py b/aries_cloudagent/protocols/coordinate_mediation/v1_0/tests/test_route_manager.py index 4d9efceb72..f543efe386 100644 --- a/aries_cloudagent/protocols/coordinate_mediation/v1_0/tests/test_route_manager.py +++ b/aries_cloudagent/protocols/coordinate_mediation/v1_0/tests/test_route_manager.py @@ -26,6 +26,9 @@ TEST_VERKEY = "did:key:z6Mkgg342Ycpuk263R9d8Aq6MUaxPn1DDeHyGo38EefXmgDL" TEST_ROUTE_RECORD_VERKEY = "9WCgWKUaAJj3VWxxtzvvMQN3AoFxoBtBDo9ntwJnVVCC" TEST_ROUTE_VERKEY = "did:key:z6MknxTj6Zj1VrDWc1ofaZtmCVv2zNXpD58Xup4ijDGoQhya" +TEST_ROUTE_VERKEY = "did:key:z6MknxTj6Zj1VrDWc1ofaZtmCVv2zNXpD58Xup4ijDGoQhya" +TEST_ROUTE_VERKEY_REF = "did:key:z6MknxTj6Zj1VrDWc1ofaZtmCVv2zNXpD58Xup4ijDGoQhya#z6MknxTj6Zj1VrDWc1ofaZtmCVv2zNXpD58Xup4ijDGoQhya" +TEST_ROUTE_VERKEY_REF2 = "did:key:z6MknxTj6Zj1VrDWc1ofaZtmCVv2zNXpD58Xup4ijDGoQhyz#z6MknxTj6Zj1VrDWc1ofaZtmCVv2zNXpD58Xup4ijDGoQhyz" class MockRouteManager(RouteManager): @@ -51,7 +54,7 @@ def route_manager(): manager._route_for_key = mock.CoroutineMock( return_value=mock.MagicMock(KeylistUpdate) ) - manager.routing_info = mock.CoroutineMock(return_value=([], "http://example.com")) + manager.routing_info = mock.CoroutineMock(return_value=([], None)) yield manager @@ -113,12 +116,9 @@ async def test_mediation_record_for_connection_mediation_id( ) as mock_mediation_record_if_id, mock.patch.object( route_manager, "save_mediator_for_connection", mock.CoroutineMock() ): - assert ( - await route_manager.mediation_record_for_connection( - profile, conn_record, mediation_record.mediation_id - ) - == mediation_record - ) + assert await route_manager.mediation_records_for_connection( + profile, conn_record, mediation_record.mediation_id + ) == [mediation_record] mock_mediation_record_if_id.assert_called_once_with( profile, mediation_record.mediation_id, False ) @@ -139,12 +139,9 @@ async def test_mediation_record_for_connection_mediation_metadata( ) as 
mock_mediation_record_if_id, mock.patch.object( route_manager, "save_mediator_for_connection", mock.CoroutineMock() ): - assert ( - await route_manager.mediation_record_for_connection( - profile, conn_record, "another-mediation-id" - ) - == mediation_record - ) + assert await route_manager.mediation_records_for_connection( + profile, conn_record, "another-mediation-id" + ) == [mediation_record] mock_mediation_record_if_id.assert_called_once_with( profile, mediation_record.mediation_id, False ) @@ -162,12 +159,9 @@ async def test_mediation_record_for_connection_default( ) as mock_mediation_record_if_id, mock.patch.object( route_manager, "save_mediator_for_connection", mock.CoroutineMock() ): - assert ( - await route_manager.mediation_record_for_connection( - profile, conn_record, None, or_default=True - ) - == mediation_record - ) + assert await route_manager.mediation_records_for_connection( + profile, conn_record, None, or_default=True + ) == [mediation_record] mock_mediation_record_if_id.assert_called_once_with(profile, None, True) @@ -285,7 +279,7 @@ async def test_route_connection_as_invitee( mock.CoroutineMock(return_value=mock_did_info), ): await route_manager.route_connection_as_invitee( - profile, conn_record, mediation_record + profile, conn_record, [mediation_record] ) route_manager._route_for_key.assert_called_once_with( profile, mock_did_info.verkey, mediation_record, skip_if_exists=True @@ -305,7 +299,7 @@ async def test_route_connection_as_inviter( mock.CoroutineMock(return_value=mock_did_info), ): await route_manager.route_connection_as_inviter( - profile, conn_record, mediation_record + profile, conn_record, [mediation_record] ) route_manager._route_for_key.assert_called_once_with( profile, @@ -342,7 +336,7 @@ async def test_route_connection_state_inviter_replace_key_none( ), ): await route_manager.route_connection_as_inviter( - profile, conn_record, mediation_record + profile, conn_record, [mediation_record] ) 
route_manager._route_for_key.assert_called_once_with( profile, @@ -365,7 +359,7 @@ async def test_route_connection_state_invitee( ) as mock_route_connection_as_invitee, mock.patch.object( route_manager, "route_connection_as_inviter", mock.CoroutineMock() ) as mock_route_connection_as_inviter: - await route_manager.route_connection(profile, conn_record, mediation_record) + await route_manager.route_connection(profile, conn_record, [mediation_record]) mock_route_connection_as_invitee.assert_called_once() mock_route_connection_as_inviter.assert_not_called() @@ -382,7 +376,7 @@ async def test_route_connection_state_inviter( ) as mock_route_connection_as_invitee, mock.patch.object( route_manager, "route_connection_as_inviter", mock.CoroutineMock() ) as mock_route_connection_as_inviter: - await route_manager.route_connection(profile, conn_record, mediation_record) + await route_manager.route_connection(profile, conn_record, [mediation_record]) mock_route_connection_as_inviter.assert_called_once() mock_route_connection_as_invitee.assert_not_called() @@ -395,7 +389,7 @@ async def test_route_connection_state_other( conn_record.state = "response" conn_record.their_role = "requester" assert ( - await route_manager.route_connection(profile, conn_record, mediation_record) + await route_manager.route_connection(profile, conn_record, [mediation_record]) is None ) @@ -696,11 +690,11 @@ async def test_mediation_routing_info_with_mediator( mediation_record = MediationRecord( mediation_id="test-mediation-id", connection_id="test-mediator-conn-id", - routing_keys=["test-key-0", "test-key-1"], + routing_keys=[TEST_ROUTE_VERKEY_REF], endpoint="http://mediator.example.com", ) keys, endpoint = await mediation_route_manager.routing_info( - profile, "http://example.com", mediation_record + profile, mediation_record ) assert keys == mediation_record.routing_keys assert endpoint == mediation_record.endpoint @@ -711,8 +705,6 @@ async def test_mediation_routing_info_no_mediator( profile: 
Profile, mediation_route_manager: CoordinateMediationV1RouteManager, ): - keys, endpoint = await mediation_route_manager.routing_info( - profile, "http://example.com", None - ) - assert keys == [] - assert endpoint == "http://example.com" + keys, endpoint = await mediation_route_manager.routing_info(profile, None) + assert keys is None + assert endpoint is None diff --git a/aries_cloudagent/protocols/didexchange/v1_0/manager.py b/aries_cloudagent/protocols/didexchange/v1_0/manager.py index 0db2702e51..1b36f4c50d 100644 --- a/aries_cloudagent/protocols/didexchange/v1_0/manager.py +++ b/aries_cloudagent/protocols/didexchange/v1_0/manager.py @@ -17,7 +17,6 @@ from ....did.did_key import DIDKey from ....messaging.decorators.attach_decorator import AttachDecorator from ....messaging.responder import BaseResponder -from ....multitenant.base import BaseMultitenantManager from ....resolver.base import ResolverError from ....resolver.did_resolver import DIDResolver from ....storage.error import StorageNotFoundError @@ -285,21 +284,13 @@ async def create_request( """ # Mediation Support - mediation_record = await self._route_manager.mediation_record_for_connection( + mediation_records = await self._route_manager.mediation_records_for_connection( self.profile, conn_rec, mediation_id, or_default=True, ) - # Multitenancy setup - multitenant_mgr = self.profile.inject_or(BaseMultitenantManager) - wallet_id = self.profile.settings.get("wallet.id") - - base_mediation_record = None - if multitenant_mgr and wallet_id: - base_mediation_record = await multitenant_mgr.get_default_mediator() - my_info = None if conn_rec.my_did: @@ -336,11 +327,8 @@ async def create_request( else: did_doc = await self.create_did_document( my_info, - conn_rec.inbound_connection_id, my_endpoints, - mediation_records=list( - filter(None, [base_mediation_record, mediation_record]) - ), + mediation_records=mediation_records, ) attach = AttachDecorator.data_base64(did_doc.serialize()) async with 
self.profile.session() as session: @@ -377,7 +365,7 @@ async def create_request( # Idempotent; if routing has already been set up, no action taken await self._route_manager.route_connection_as_invitee( - self.profile, conn_rec, mediation_record + self.profile, conn_rec, mediation_records ) return request @@ -599,18 +587,10 @@ async def create_response( settings=self.profile.settings, ) - mediation_record = await self._route_manager.mediation_record_for_connection( + mediation_records = await self._route_manager.mediation_records_for_connection( self.profile, conn_rec, mediation_id ) - # Multitenancy setup - multitenant_mgr = self.profile.inject_or(BaseMultitenantManager) - wallet_id = self.profile.settings.get("wallet.id") - - base_mediation_record = None - if multitenant_mgr and wallet_id: - base_mediation_record = await multitenant_mgr.get_default_mediator() - if ConnRecord.State.get(conn_rec.state) is not ConnRecord.State.REQUEST: raise DIDXManagerError( f"Connection not in state {ConnRecord.State.REQUEST.rfc23}" @@ -645,7 +625,7 @@ async def create_response( # Idempotent; if routing has already been set up, no action taken await self._route_manager.route_connection_as_inviter( - self.profile, conn_rec, mediation_record + self.profile, conn_rec, mediation_records ) # Create connection response message @@ -665,11 +645,8 @@ async def create_response( else: did_doc = await self.create_did_document( my_info, - conn_rec.inbound_connection_id, my_endpoints, - mediation_records=list( - filter(None, [base_mediation_record, mediation_record]) - ), + mediation_records=mediation_records, ) attach = AttachDecorator.data_base64(did_doc.serialize()) async with self.profile.session() as session: diff --git a/aries_cloudagent/protocols/didexchange/v1_0/tests/test_manager.py b/aries_cloudagent/protocols/didexchange/v1_0/tests/test_manager.py index 8bcbce0603..c0e1edb913 100644 --- a/aries_cloudagent/protocols/didexchange/v1_0/tests/test_manager.py +++ 
b/aries_cloudagent/protocols/didexchange/v1_0/tests/test_manager.py @@ -2198,177 +2198,10 @@ async def test_create_did_document(self): key_type=ED25519, ) - mock_conn = async_mock.MagicMock( - connection_id="dummy", - inbound_connection_id=None, - their_did=TestConfig.test_target_did, - state=ConnRecord.State.COMPLETED.rfc23, - ) - - did_doc = self.make_did_doc( - did=TestConfig.test_target_did, - verkey=TestConfig.test_target_verkey, - ) - for i in range(2): # first cover store-record, then update-value - await self.manager.store_did_document(did_doc) - - with async_mock.patch.object( - ConnRecord, "retrieve_by_id", async_mock.CoroutineMock() - ) as mock_conn_rec_retrieve_by_id: - mock_conn_rec_retrieve_by_id.return_value = mock_conn - - did_doc = await self.manager.create_did_document( - did_info=did_info, - inbound_connection_id="dummy", - svc_endpoints=[TestConfig.test_endpoint], - ) - - async def test_create_did_document_not_completed(self): - did_info = DIDInfo( - TestConfig.test_did, - TestConfig.test_verkey, - None, - method=SOV, - key_type=ED25519, - ) - - mock_conn = async_mock.MagicMock( - connection_id="dummy", - inbound_connection_id=None, - their_did=TestConfig.test_target_did, - state=ConnRecord.State.ABANDONED.rfc23, - ) - - with async_mock.patch.object( - ConnRecord, "retrieve_by_id", async_mock.CoroutineMock() - ) as mock_conn_rec_retrieve_by_id: - mock_conn_rec_retrieve_by_id.return_value = mock_conn - - with self.assertRaises(BaseConnectionManagerError): - await self.manager.create_did_document( - did_info=did_info, - inbound_connection_id="dummy", - svc_endpoints=[TestConfig.test_endpoint], - ) - - async def test_create_did_document_no_services(self): - did_info = DIDInfo( - TestConfig.test_did, - TestConfig.test_verkey, - None, - method=SOV, - key_type=ED25519, - ) - - mock_conn = async_mock.MagicMock( - connection_id="dummy", - inbound_connection_id=None, - their_did=TestConfig.test_target_did, - state=ConnRecord.State.COMPLETED.rfc23, - ) - 
- x_did_doc = self.make_did_doc( - did=TestConfig.test_target_did, verkey=TestConfig.test_target_verkey + did_doc = await self.manager.create_did_document( + did_info=did_info, + svc_endpoints=[TestConfig.test_endpoint], ) - x_did_doc._service = {} - for i in range(2): # first cover store-record, then update-value - await self.manager.store_did_document(x_did_doc) - - with async_mock.patch.object( - ConnRecord, "retrieve_by_id", async_mock.CoroutineMock() - ) as mock_conn_rec_retrieve_by_id: - mock_conn_rec_retrieve_by_id.return_value = mock_conn - - with self.assertRaises(BaseConnectionManagerError): - await self.manager.create_did_document( - did_info=did_info, - inbound_connection_id="dummy", - svc_endpoints=[TestConfig.test_endpoint], - ) - - async def test_create_did_document_no_service_endpoint(self): - did_info = DIDInfo( - TestConfig.test_did, - TestConfig.test_verkey, - None, - method=SOV, - key_type=ED25519, - ) - - mock_conn = async_mock.MagicMock( - connection_id="dummy", - inbound_connection_id=None, - their_did=TestConfig.test_target_did, - state=ConnRecord.State.COMPLETED.rfc23, - ) - - x_did_doc = self.make_did_doc( - did=TestConfig.test_target_did, verkey=TestConfig.test_target_verkey - ) - x_did_doc._service = {} - x_did_doc.set( - Service(TestConfig.test_target_did, "dummy", "IndyAgent", [], [], "", 0) - ) - for i in range(2): # first cover store-record, then update-value - await self.manager.store_did_document(x_did_doc) - - with async_mock.patch.object( - ConnRecord, "retrieve_by_id", async_mock.CoroutineMock() - ) as mock_conn_rec_retrieve_by_id: - mock_conn_rec_retrieve_by_id.return_value = mock_conn - - with self.assertRaises(BaseConnectionManagerError): - await self.manager.create_did_document( - did_info=did_info, - inbound_connection_id="dummy", - svc_endpoints=[TestConfig.test_endpoint], - ) - - async def test_create_did_document_no_service_recip_keys(self): - did_info = DIDInfo( - TestConfig.test_did, - TestConfig.test_verkey, - None, - 
method=SOV, - key_type=ED25519, - ) - - mock_conn = async_mock.MagicMock( - connection_id="dummy", - inbound_connection_id=None, - their_did=TestConfig.test_target_did, - state=ConnRecord.State.COMPLETED.rfc23, - ) - - x_did_doc = self.make_did_doc( - did=TestConfig.test_target_did, verkey=TestConfig.test_target_verkey - ) - x_did_doc._service = {} - x_did_doc.set( - Service( - TestConfig.test_target_did, - "dummy", - "IndyAgent", - [], - [], - TestConfig.test_endpoint, - 0, - ) - ) - for i in range(2): # first cover store-record, then update-value - await self.manager.store_did_document(x_did_doc) - - with async_mock.patch.object( - ConnRecord, "retrieve_by_id", async_mock.CoroutineMock() - ) as mock_conn_rec_retrieve_by_id: - mock_conn_rec_retrieve_by_id.return_value = mock_conn - - with self.assertRaises(BaseConnectionManagerError): - await self.manager.create_did_document( - did_info=did_info, - inbound_connection_id="dummy", - svc_endpoints=[TestConfig.test_endpoint], - ) async def test_did_key_storage(self): did_info = DIDInfo( diff --git a/aries_cloudagent/protocols/out_of_band/v1_0/manager.py b/aries_cloudagent/protocols/out_of_band/v1_0/manager.py index 007b43a675..2dfbe7f421 100644 --- a/aries_cloudagent/protocols/out_of_band/v1_0/manager.py +++ b/aries_cloudagent/protocols/out_of_band/v1_0/manager.py @@ -317,9 +317,10 @@ async def create_invitation( async with self.profile.session() as session: await conn_rec.save(session, reason="Created new connection") - routing_keys, my_endpoint = await self._route_manager.routing_info( - self.profile, my_endpoint, mediation_record + routing_keys, routing_endpoint = await self._route_manager.routing_info( + self.profile, mediation_record ) + my_endpoint = routing_endpoint or my_endpoint if not conn_rec: our_service = ServiceDecorator( @@ -335,8 +336,8 @@ async def create_invitation( routing_keys = [ key if len(key.split(":")) == 3 - else DIDKey.from_public_key_b58(key, ED25519).did - for key in routing_keys + else 
DIDKey.from_public_key_b58(key, ED25519).key_id + for key in routing_keys or [] ] # Create connection invitation message @@ -353,7 +354,9 @@ async def create_invitation( _id="#inline", _type="did-communication", recipient_keys=[ - DIDKey.from_public_key_b58(connection_key.verkey, ED25519).did + DIDKey.from_public_key_b58( + connection_key.verkey, ED25519 + ).key_id ], service_endpoint=my_endpoint, routing_keys=routing_keys, @@ -814,11 +817,11 @@ async def _perform_handshake( "id": "#inline", "type": "did-communication", "recipientKeys": [ - DIDKey.from_public_key_b58(key, ED25519).did + DIDKey.from_public_key_b58(key, ED25519).key_id for key in recipient_keys ], "routingKeys": [ - DIDKey.from_public_key_b58(key, ED25519).did + DIDKey.from_public_key_b58(key, ED25519).key_id for key in routing_keys ], "serviceEndpoint": endpoint, diff --git a/aries_cloudagent/protocols/out_of_band/v1_0/messages/service.py b/aries_cloudagent/protocols/out_of_band/v1_0/messages/service.py index aca81b20f7..92ac2a7b88 100644 --- a/aries_cloudagent/protocols/out_of_band/v1_0/messages/service.py +++ b/aries_cloudagent/protocols/out_of_band/v1_0/messages/service.py @@ -1,13 +1,13 @@ """Record used to represent a service block of an out of band invitation.""" -from typing import Sequence +from typing import Optional, Sequence from marshmallow import EXCLUDE, fields, post_dump from .....messaging.models.base import BaseModel, BaseModelSchema from .....messaging.valid import ( - DID_KEY_EXAMPLE, - DID_KEY_VALIDATE, + DID_KEY_OR_REF_EXAMPLE, + DID_KEY_OR_REF_VALIDATE, INDY_DID_EXAMPLE, INDY_DID_VALIDATE, ) @@ -24,12 +24,12 @@ class Meta: def __init__( self, *, - _id: str = None, - _type: str = None, - did: str = None, - recipient_keys: Sequence[str] = None, - routing_keys: Sequence[str] = None, - service_endpoint: str = None, + _id: Optional[str] = None, + _type: Optional[str] = None, + did: Optional[str] = None, + recipient_keys: Optional[Sequence[str]] = None, + routing_keys: 
Optional[Sequence[str]] = None, + service_endpoint: Optional[str] = None, ): """Initialize a Service instance. @@ -72,10 +72,10 @@ class Meta: recipient_keys = fields.List( fields.Str( - validate=DID_KEY_VALIDATE, + validate=DID_KEY_OR_REF_VALIDATE, metadata={ "description": "Recipient public key", - "example": DID_KEY_EXAMPLE, + "example": DID_KEY_OR_REF_EXAMPLE, }, ), data_key="recipientKeys", @@ -85,8 +85,8 @@ class Meta: routing_keys = fields.List( fields.Str( - validate=DID_KEY_VALIDATE, - metadata={"description": "Routing key", "example": DID_KEY_EXAMPLE}, + validate=DID_KEY_OR_REF_VALIDATE, + metadata={"description": "Routing key", "example": DID_KEY_OR_REF_EXAMPLE}, ), data_key="routingKeys", required=False, diff --git a/aries_cloudagent/protocols/out_of_band/v1_0/messages/tests/test_invitation.py b/aries_cloudagent/protocols/out_of_band/v1_0/messages/tests/test_invitation.py index 4f33cbbbe9..b37cbd102b 100644 --- a/aries_cloudagent/protocols/out_of_band/v1_0/messages/tests/test_invitation.py +++ b/aries_cloudagent/protocols/out_of_band/v1_0/messages/tests/test_invitation.py @@ -108,7 +108,10 @@ def test_url_round_trip(self): service = Service( _id="#inline", _type=DID_COMM, - recipient_keys=[DIDKey.from_public_key_b58(TEST_VERKEY, ED25519).did], + recipient_keys=[ + DIDKey.from_public_key_b58(TEST_VERKEY, ED25519).did, + DIDKey.from_public_key_b58(TEST_VERKEY, ED25519).key_id, + ], service_endpoint="http://1.2.3.4:8080/service", ) invi_msg = InvitationMessage( diff --git a/aries_cloudagent/protocols/out_of_band/v1_0/tests/test_manager.py b/aries_cloudagent/protocols/out_of_band/v1_0/tests/test_manager.py index 7bee844ce4..25abf93193 100644 --- a/aries_cloudagent/protocols/out_of_band/v1_0/tests/test_manager.py +++ b/aries_cloudagent/protocols/out_of_band/v1_0/tests/test_manager.py @@ -808,7 +808,7 @@ async def test_create_invitation_peer_did(self): service["routingKeys"][0] == DIDKey.from_public_key_b58( self.test_mediator_routing_keys[0], ED25519 - ).did 
+ ).key_id ) assert service["serviceEndpoint"] == self.test_mediator_endpoint diff --git a/aries_cloudagent/resolver/default/peer3.py b/aries_cloudagent/resolver/default/peer3.py index bfe79ce15c..4f97716462 100644 --- a/aries_cloudagent/resolver/default/peer3.py +++ b/aries_cloudagent/resolver/default/peer3.py @@ -6,27 +6,22 @@ the did:peer:2 has been replaced with the did:peer:3. """ -import re from copy import deepcopy from hashlib import sha256 +import re from typing import Optional, Pattern, Sequence, Text -from multiformats import multibase, multicodec - -from peerdid.dids import ( - DID, - MalformedPeerDIDError, - DIDDocument, -) -from peerdid.keys import to_multibase, MultibaseFormat -from ...wallet.util import bytes_to_b58 -from ...connections.base_manager import BaseConnectionManager +from peerdid.dids import DID, DIDDocument, MalformedPeerDIDError +from peerdid.keys import MultibaseFormat, to_multibase + from ...config.injection_context import InjectionContext +from ...connections.base_manager import BaseConnectionManager from ...core.profile import Profile from ...storage.base import BaseStorage from ...storage.error import StorageNotFoundError from ...storage.record import StorageRecord - +from ...utils.multiformats import multibase, multicodec +from ...wallet.util import bytes_to_b58 from ..base import BaseDIDResolver, DIDNotFound, ResolverType RECORD_TYPE_DID_DOCUMENT = "did_document" # pydid DIDDocument diff --git a/aries_cloudagent/storage/in_memory.py b/aries_cloudagent/storage/in_memory.py index 84a9ed241f..d2b0671f4f 100644 --- a/aries_cloudagent/storage/in_memory.py +++ b/aries_cloudagent/storage/in_memory.py @@ -255,6 +255,7 @@ def __init__( self.page_size = page_size or DEFAULT_PAGE_SIZE self.tag_query = tag_query self.type_filter = type_filter + self._done = False async def fetch(self, max_count: int = None) -> Sequence[StorageRecord]: """Fetch the next list of results from the store. 
@@ -270,7 +271,7 @@ async def fetch(self, max_count: int = None) -> Sequence[StorageRecord]: StorageSearchError: If the search query has not been opened """ - if self._cache is None: + if self._cache is None and self._done: raise StorageSearchError("Search query is complete") ret = [] @@ -291,9 +292,11 @@ async def fetch(self, max_count: int = None) -> Sequence[StorageRecord]: if not ret: self._cache = None + self._done = True return ret async def close(self): """Dispose of the search query.""" self._cache = None + self._done = True diff --git a/aries_cloudagent/utils/multiformats/__init__.py b/aries_cloudagent/utils/multiformats/__init__.py new file mode 100644 index 0000000000..b3aec34c53 --- /dev/null +++ b/aries_cloudagent/utils/multiformats/__init__.py @@ -0,0 +1 @@ +"""Multiformats utility functions.""" diff --git a/aries_cloudagent/utils/multiformats/multibase.py b/aries_cloudagent/utils/multiformats/multibase.py new file mode 100644 index 0000000000..5f9ee5c0ff --- /dev/null +++ b/aries_cloudagent/utils/multiformats/multibase.py @@ -0,0 +1,104 @@ +"""MultiBase encoding and decoding utilities.""" + +from abc import ABC, abstractmethod +from enum import Enum +from typing import ClassVar, Literal, Union + + +class MultibaseEncoder(ABC): + """Encoding details.""" + + name: ClassVar[str] + character: ClassVar[str] + + @abstractmethod + def encode(self, value: bytes) -> str: + """Encode a byte string using this encoding.""" + + @abstractmethod + def decode(self, value: str) -> bytes: + """Decode a string using this encoding.""" + + +class Base58BtcEncoder(MultibaseEncoder): + """Base58BTC encoding.""" + + name = "base58btc" + character = "z" + + def encode(self, value: bytes) -> str: + """Encode a byte string using the base58btc encoding.""" + import base58 + + return base58.b58encode(value).decode() + + def decode(self, value: str) -> bytes: + """Decode a multibase encoded string.""" + import base58 + + return base58.b58decode(value) + + +class Encoding(Enum): + 
"""Enum for supported encodings.""" + + base58btc = Base58BtcEncoder() + # Insert additional encodings here + + @classmethod + def from_name(cls, name: str) -> MultibaseEncoder: + """Get encoding from name.""" + for encoding in cls: + if encoding.value.name == name: + return encoding.value + raise ValueError(f"Unsupported encoding: {name}") + + @classmethod + def from_character(cls, character: str) -> MultibaseEncoder: + """Get encoding from character.""" + for encoding in cls: + if encoding.value.character == character: + return encoding.value + raise ValueError(f"Unsupported encoding: {character}") + + +EncodingStr = Literal[ + "base58btc", + # Insert additional encoding names here +] + + +def encode(value: bytes, encoding: Union[Encoding, EncodingStr]) -> str: + """Encode a byte string using the given encoding. + + Args: + value: The byte string to encode + encoding: The encoding to use + + Returns: + The encoded string + """ + if isinstance(encoding, str): + encoder = Encoding.from_name(encoding) + elif isinstance(encoding, Encoding): + encoder = encoding.value + else: + raise TypeError("encoding must be an Encoding or EncodingStr") + + return encoder.character + encoder.encode(value) + + +def decode(value: str) -> bytes: + """Decode a multibase encoded string. 
+ + Args: + value: The string to decode + + Returns: + The decoded byte string + """ + encoding = value[0] + encoded = value[1:] + encoder = Encoding.from_character(encoding) + + return encoder.decode(encoded) diff --git a/aries_cloudagent/utils/multiformats/multicodec.py b/aries_cloudagent/utils/multiformats/multicodec.py new file mode 100644 index 0000000000..465d5b3ea2 --- /dev/null +++ b/aries_cloudagent/utils/multiformats/multicodec.py @@ -0,0 +1,72 @@ +"""Multicodec wrap and unwrap functions.""" + +from enum import Enum +from typing import Literal, NamedTuple, Optional, Union + + +class Multicodec(NamedTuple): + """Multicodec base class.""" + + name: str + code: bytes + + +class SupportedCodecs(Enum): + """Enumeration of supported multicodecs.""" + + ed25519_pub = Multicodec("ed25519-pub", b"\xed\x01") + x25519_pub = Multicodec("x25519-pub", b"\xec\x01") + bls12381g1 = Multicodec("bls12_381-g1-pub", b"\xea\x01") + bls12381g2 = Multicodec("bls12_381-g2-pub", b"\xeb\x01") + bls12381g1g2 = Multicodec("bls12_381-g1g2-pub", b"\xee\x01") + secp256k1_pub = Multicodec("secp256k1-pub", b"\xe7\x01") + + @classmethod + def by_name(cls, name: str) -> Multicodec: + """Get multicodec by name.""" + for codec in cls: + if codec.value.name == name: + return codec.value + raise ValueError(f"Unsupported multicodec: {name}") + + @classmethod + def for_data(cls, data: bytes) -> Multicodec: + """Get multicodec by data.""" + for codec in cls: + if data.startswith(codec.value.code): + return codec.value + raise ValueError("Unsupported multicodec") + + +MulticodecStr = Literal[ + "ed25519-pub", + "x25519-pub", + "bls12_381-g1-pub", + "bls12_381-g2-pub", + "bls12_381-g1g2-pub", + "secp256k1-pub", +] + + +def multicodec(name: str) -> Multicodec: + """Get multicodec by name.""" + return SupportedCodecs.by_name(name) + + +def wrap(multicodec: Union[Multicodec, MulticodecStr], data: bytes) -> bytes: + """Wrap data with multicodec prefix.""" + if isinstance(multicodec, str): + multicodec = 
SupportedCodecs.by_name(multicodec) + elif isinstance(multicodec, Multicodec): + pass + else: + raise TypeError("multicodec must be Multicodec or MulticodecStr") + + return multicodec.code + data + + +def unwrap(data: bytes, codec: Optional[Multicodec] = None) -> tuple[Multicodec, bytes]: + """Unwrap data with multicodec prefix.""" + if not codec: + codec = SupportedCodecs.for_data(data) + return codec, data[len(codec.code) :] diff --git a/aries_cloudagent/utils/tests/test_multiformats.py b/aries_cloudagent/utils/tests/test_multiformats.py new file mode 100644 index 0000000000..5ef8ce4308 --- /dev/null +++ b/aries_cloudagent/utils/tests/test_multiformats.py @@ -0,0 +1,73 @@ +import pytest +from ..multiformats import multibase, multicodec + + +def test_encode_decode(): + value = b"Hello World!" + encoded = multibase.encode(value, "base58btc") + assert encoded == "z2NEpo7TZRRrLZSi2U" + decoded = multibase.decode(encoded) + assert decoded == value + + +def test_encode_decode_by_encoding(): + value = b"Hello World!" + encoded = multibase.encode(value, multibase.Encoding.base58btc) + assert encoded == "z2NEpo7TZRRrLZSi2U" + decoded = multibase.decode(encoded) + assert decoded == value + + +def test_x_unknown_encoding(): + with pytest.raises(ValueError): + multibase.encode(b"Hello World!", "fancy-encoding") + + +def test_x_unknown_character(): + with pytest.raises(ValueError): + multibase.decode("fHello World!") + + +def test_x_invalid_encoding(): + with pytest.raises(TypeError): + multibase.encode(b"Hello World!", 123) + + +def test_wrap_unwrap(): + value = b"Hello World!" + wrapped = multicodec.wrap("ed25519-pub", value) + codec, unwrapped = multicodec.unwrap(wrapped) + assert codec == multicodec.multicodec("ed25519-pub") + assert unwrapped == value + + +def test_wrap_unwrap_custom(): + value = b"Hello World!" 
+ my_codec = multicodec.Multicodec("my-codec", b"\x00\x01") + wrapped = multicodec.wrap(my_codec, value) + codec, unwrapped = multicodec.unwrap(wrapped, my_codec) + assert codec == my_codec + assert unwrapped == value + + +def test_wrap_unwrap_by_codec(): + value = b"Hello World!" + wrapped = multicodec.wrap(multicodec.multicodec("ed25519-pub"), value) + codec, unwrapped = multicodec.unwrap(wrapped, multicodec.multicodec("ed25519-pub")) + assert codec == multicodec.multicodec("ed25519-pub") + assert unwrapped == value + + +def test_x_unknown_multicodec(): + with pytest.raises(ValueError): + multicodec.wrap("fancy-multicodec", b"Hello World!") + + +def test_x_invalid_multicodec(): + with pytest.raises(TypeError): + multicodec.wrap(123, b"Hello World!") + + +def test_x_invalid_multicodec_unwrap(): + with pytest.raises(ValueError): + multicodec.unwrap(b"Hello World!") diff --git a/aries_cloudagent/vc/ld_proofs/suites/ed25519_signature_2020.py b/aries_cloudagent/vc/ld_proofs/suites/ed25519_signature_2020.py index b1b823d5c9..fee9c89084 100644 --- a/aries_cloudagent/vc/ld_proofs/suites/ed25519_signature_2020.py +++ b/aries_cloudagent/vc/ld_proofs/suites/ed25519_signature_2020.py @@ -1,14 +1,13 @@ """Ed25519Signature2018 suite.""" from datetime import datetime -from typing import Union, List +from typing import List, Union -from multiformats import multibase - -from .linked_data_signature import LinkedDataSignature +from ....utils.multiformats import multibase from ..crypto import _KeyPair as KeyPair from ..document_loader import DocumentLoaderMethod from ..error import LinkedDataProofException +from .linked_data_signature import LinkedDataSignature class Ed25519Signature2020(LinkedDataSignature): diff --git a/aries_cloudagent/wallet/routes.py b/aries_cloudagent/wallet/routes.py index ba5e7d4bd4..58903a2963 100644 --- a/aries_cloudagent/wallet/routes.py +++ b/aries_cloudagent/wallet/routes.py @@ -671,7 +671,6 @@ async def wallet_set_public_did(request: web.BaseRequest): 
routing_keys, mediator_endpoint = await route_manager.routing_info( profile, - None, mediation_record, ) diff --git a/docker/Dockerfile.run b/docker/Dockerfile.run index 8660e45d76..ad4fde5622 100644 --- a/docker/Dockerfile.run +++ b/docker/Dockerfile.run @@ -7,14 +7,12 @@ RUN apt-get update && apt-get install -y curl && apt-get clean RUN pip install --no-cache-dir poetry -ADD . . +RUN mkdir -p aries_cloudagent && touch aries_cloudagent/__init__.py +ADD pyproject.toml poetry.lock README.md ./ +RUN mkdir -p logs && chmod -R ug+rw logs RUN poetry install -E "askar bbs" -RUN mkdir -p aries_cloudagent && touch aries_cloudagent/__init__.py -ADD aries_cloudagent/version.py aries_cloudagent/version.py - -RUN mkdir -p logs && chmod -R ug+rw logs -ADD aries_cloudagent ./aries_cloudagent +ADD . . ENTRYPOINT ["/bin/bash", "-c", "poetry run aca-py \"$@\"", "--"] diff --git a/poetry.lock b/poetry.lock index 37ee86ba5d..3cebc6e7c3 100644 --- a/poetry.lock +++ b/poetry.lock @@ -265,24 +265,6 @@ files = [ [package.extras] tests = ["PyHamcrest (>=2.0.2)", "mypy", "pytest (>=4.6)", "pytest-benchmark", "pytest-cov", "pytest-flake8"] -[[package]] -name = "bases" -version = "0.2.1" -description = "Python library for general Base-N encodings." 
-optional = false -python-versions = ">=3.7" -files = [ - {file = "bases-0.2.1-py3-none-any.whl", hash = "sha256:d030b5e349773ad2a067bfaaf3a9794b70d23a1f923033c15c2e0ce869854f6d"}, - {file = "bases-0.2.1.tar.gz", hash = "sha256:b0999e14725b59bff38974b00e918629e0e29f3d80a40e022c6f0f8d5cdff9d4"}, -] - -[package.dependencies] -typing-extensions = "*" -typing-validation = "*" - -[package.extras] -dev = ["base58", "mypy", "pylint", "pytest", "pytest-cov"] - [[package]] name = "black" version = "23.7.0" @@ -1528,44 +1510,6 @@ files = [ {file = "multidict-6.0.4.tar.gz", hash = "sha256:3666906492efb76453c0e7b97f2cf459b0682e7402c0489a95484965dbc1da49"}, ] -[[package]] -name = "multiformats" -version = "0.2.1" -description = "Python implementation of multiformats protocols." -optional = false -python-versions = ">=3.7" -files = [ - {file = "multiformats-0.2.1-py3-none-any.whl", hash = "sha256:0655fad05cff4cb9eaae3a0f61c4cf8189857fef5a5d0ade6aa25d6b8439e204"}, - {file = "multiformats-0.2.1.tar.gz", hash = "sha256:4aee6eb5289c3cd00315e3f2b97e36b60db6af9bcec91d65f32d7155942dbef9"}, -] - -[package.dependencies] -bases = "*" -multiformats-config = "*" -typing-extensions = "*" -typing-validation = "*" - -[package.extras] -dev = ["blake3", "mmh3", "mypy", "pycryptodomex", "pylint", "pysha3", "pyskein", "pytest", "pytest-cov"] -full = ["blake3", "mmh3", "pycryptodomex", "pysha3", "pyskein"] - -[[package]] -name = "multiformats-config" -version = "0.2.0.post4" -description = "Pre-loading configuration module for the 'multiformats' package." 
-optional = false -python-versions = ">=3.7" -files = [ - {file = "multiformats-config-0.2.0.post4.tar.gz", hash = "sha256:3b5d0be63211681edbcf887abe149fc17b43a79d0c009fb44232af9addf7a309"}, - {file = "multiformats_config-0.2.0.post4-py3-none-any.whl", hash = "sha256:221a630732f7bb2cd6cb708b94a85da683e29618e4a2055eea1bf4f980feefce"}, -] - -[package.dependencies] -multiformats = "*" - -[package.extras] -dev = ["mypy", "pylint", "pytest", "pytest-cov"] - [[package]] name = "mypy-extensions" version = "1.0.0" @@ -2572,20 +2516,6 @@ files = [ {file = "typing_extensions-4.0.1.tar.gz", hash = "sha256:4ca091dea149f945ec56afb48dae714f21e8692ef22a395223bcd328961b6a0e"}, ] -[[package]] -name = "typing-validation" -version = "1.0.0.post2" -description = "A simple library for runtime type-checking." -optional = false -python-versions = ">=3.7" -files = [ - {file = "typing-validation-1.0.0.post2.tar.gz", hash = "sha256:6a30dec74373f9dca29db6f79ef65eb765a6934c09d87639cf422288933b2aa4"}, - {file = "typing_validation-1.0.0.post2-py3-none-any.whl", hash = "sha256:c9f5cb42435ee59fcf5a1a69dc88ccd5dc6f904436e61b5b8c276906a1c9e454"}, -] - -[package.extras] -dev = ["mypy", "pylint", "pytest", "pytest-cov", "rich"] - [[package]] name = "unflatten" version = "0.1.1" @@ -2599,13 +2529,13 @@ files = [ [[package]] name = "urllib3" -version = "2.0.6" +version = "2.0.7" description = "HTTP library with thread-safe connection pooling, file post, and more." 
optional = false python-versions = ">=3.7" files = [ - {file = "urllib3-2.0.6-py3-none-any.whl", hash = "sha256:7a7c7003b000adf9e7ca2a377c9688bbc54ed41b985789ed576570342a375cd2"}, - {file = "urllib3-2.0.6.tar.gz", hash = "sha256:b19e1a85d206b56d7df1d5e683df4a7725252a964e3993648dd0fb5a1c157564"}, + {file = "urllib3-2.0.7-py3-none-any.whl", hash = "sha256:fdb6d215c776278489906c2f8916e6e7d4f5a9b602ccbcfdf7f016fc8da0596e"}, + {file = "urllib3-2.0.7.tar.gz", hash = "sha256:c97dfde1f7bd43a71c8d2a58e369e9b2bf692d1334ea9f9cae55add7d0dd0f84"}, ] [package.extras] @@ -2869,4 +2799,4 @@ indy = ["python3-indy"] [metadata] lock-version = "2.0" python-versions = "^3.9" -content-hash = "2192fc079ef7dbcf69ea617b81e6d1f74328ac130d2979d9d51cc4ab7561b375" +content-hash = "98923c33a560d59e549895d86258aa5b0228ea7b08bd387586612bff5825a6e2" diff --git a/pyproject.toml b/pyproject.toml index 94f6685fcf..e53446ec24 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,7 +29,6 @@ jsonpath_ng="1.5.2" Markdown="~3.1.1" markupsafe="2.0.1" marshmallow="~3.20.1" -multiformats="~0.2.1" nest_asyncio="~1.5.5" packaging="~23.1" portalocker="~2.7.0"