diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 038ab3fb1f..05c0bc72d0 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -19,3 +19,6 @@ repos: - id: ruff stages: [commit] args: [--fix, --exit-non-zero-on-fix] + # Run the formatter + - id: ruff-format + stages: [commit] diff --git a/aries_cloudagent/admin/request_context.py b/aries_cloudagent/admin/request_context.py index 159334c685..215a64f3bb 100644 --- a/aries_cloudagent/admin/request_context.py +++ b/aries_cloudagent/admin/request_context.py @@ -21,13 +21,13 @@ def __init__( self, profile: Profile, *, - context: InjectionContext = None, - settings: Mapping[str, object] = None, - root_profile: Profile = None, - metadata: dict = None + context: Optional[InjectionContext] = None, + settings: Optional[Mapping[str, object]] = None, + root_profile: Optional[Profile] = None, + metadata: Optional[dict] = None ): """Initialize an instance of AdminRequestContext.""" - self._context = (context or profile.context).start_scope("admin", settings) + self._context = (context or profile.context).start_scope(settings) self._profile = profile self._root_profile = root_profile self._metadata = metadata @@ -72,7 +72,7 @@ def transaction(self) -> ProfileSession: def inject( self, base_cls: Type[InjectType], - settings: Mapping[str, object] = None, + settings: Optional[Mapping[str, object]] = None, ) -> InjectType: """Get the provided instance of a given class identifier. @@ -89,7 +89,7 @@ def inject( def inject_or( self, base_cls: Type[InjectType], - settings: Mapping[str, object] = None, + settings: Optional[Mapping[str, object]] = None, default: Optional[InjectType] = None, ) -> Optional[InjectType]: """Get the provided instance of a given class identifier or default if not found. @@ -111,7 +111,7 @@ def update_settings(self, settings: Mapping[str, object]): @classmethod def test_context( - cls, session_inject: dict = None, profile: Profile = None + cls, session_inject: Optional[dict] = None, profile: Optional[Profile] = None ) -> "AdminRequestContext": """Quickly set up a new admin request context for tests.""" ctx = AdminRequestContext(profile or IN_MEM.resolved.test_profile()) diff --git a/aries_cloudagent/askar/profile.py b/aries_cloudagent/askar/profile.py index 27cec91b7c..b07dcf1f6b 100644 --- a/aries_cloudagent/askar/profile.py +++ b/aries_cloudagent/askar/profile.py @@ -114,7 +114,7 @@ def bind_providers(self): "aries_cloudagent.indy.credx.issuer.IndyCredxIssuer", ref(self) ), ) - injector.bind_provider( + injector.soft_bind_provider( VCHolder, ClassProvider( "aries_cloudagent.storage.vc_holder.askar.AskarVCHolder", diff --git a/aries_cloudagent/config/injection_context.py b/aries_cloudagent/config/injection_context.py index fdadc88224..bdd91de149 100644 --- a/aries_cloudagent/config/injection_context.py +++ b/aries_cloudagent/config/injection_context.py @@ -21,12 +21,14 @@ class InjectionContext(BaseInjector): ROOT_SCOPE = "application" def __init__( - self, *, settings: Mapping[str, object] = None, enforce_typing: bool = True + self, + *, + settings: Optional[Mapping[str, object]] = None, + enforce_typing: bool = True ): """Initialize a `ServiceConfig`.""" self._injector = Injector(settings, enforce_typing=enforce_typing) self._scope_name = InjectionContext.ROOT_SCOPE - self._scopes = [] @property def injector(self) -> Injector: @@ -38,16 +40,6 @@ def injector(self, injector: Injector): """Setter for scope-specific injector.""" self._injector = injector - @property - def scope_name(self) -> str: - """Accessor for the current scope name.""" - return self._scope_name - - @scope_name.setter - def scope_name(self, scope_name: str): - """Accessor for the current scope name.""" - self._scope_name = scope_name - @property def settings(self) -> Settings: """Accessor for scope-specific settings.""" @@ -64,7 +56,7 @@ def update_settings(self, settings: Mapping[str, object]): self.injector.settings.update(settings) def start_scope( - self, scope_name: str, settings: Optional[Mapping[str, object]] = None + self, settings: Optional[Mapping[str, object]] = None ) -> "InjectionContext": """Begin a new named scope. @@ -76,39 +68,15 @@ def start_scope( A new injection context representing the scope """ - if not scope_name: - raise InjectionContextError("Scope name must be non-empty") - if self._scope_name == scope_name: - raise InjectionContextError("Cannot re-enter scope: {}".format(scope_name)) - for scope in self._scopes: - if scope.name == scope_name: - raise InjectionContextError( - "Cannot re-enter scope: {}".format(scope_name) - ) result = self.copy() - result._scopes.append(Scope(name=self.scope_name, injector=self.injector)) - result._scope_name = scope_name if settings: result.update_settings(settings) return result - def injector_for_scope(self, scope_name: str) -> Injector: - """Fetch the injector for a specific scope. - - Args: - scope_name: The unique scope identifier - """ - if scope_name == self.scope_name: - return self.injector - for scope in self._scopes: - if scope.name == scope_name: - return scope.injector - return None - def inject( self, base_cls: Type[InjectType], - settings: Mapping[str, object] = None, + settings: Optional[Mapping[str, object]] = None, ) -> InjectType: """Get the provided instance of a given class identifier. @@ -125,7 +93,7 @@ def inject( def inject_or( self, base_cls: Type[InjectType], - settings: Mapping[str, object] = None, + settings: Optional[Mapping[str, object]] = None, default: Optional[InjectType] = None, ) -> Optional[InjectType]: """Get the provided instance of a given class identifier or default if not found. @@ -145,5 +113,4 @@ def copy(self) -> "InjectionContext": """Produce a copy of the injector instance.""" result = copy.copy(self) result._injector = self.injector.copy() - result._scopes = self._scopes.copy() return result diff --git a/aries_cloudagent/config/injector.py b/aries_cloudagent/config/injector.py index 4d99f8d09c..26130623db 100644 --- a/aries_cloudagent/config/injector.py +++ b/aries_cloudagent/config/injector.py @@ -1,6 +1,6 @@ """Standard Injector implementation.""" -from typing import Mapping, Optional, Type +from typing import Dict, Mapping, Optional, Type from .base import BaseProvider, BaseInjector, InjectionError, InjectType from .provider import InstanceProvider, CachedProvider @@ -11,11 +11,14 @@ class Injector(BaseInjector): """Injector implementation with static and dynamic bindings.""" def __init__( - self, settings: Mapping[str, object] = None, *, enforce_typing: bool = True + self, + settings: Optional[Mapping[str, object]] = None, + *, + enforce_typing: bool = True, ): """Initialize an `Injector`.""" self.enforce_typing = enforce_typing - self._providers = {} + self._providers: Dict[Type, BaseProvider] = {} self._settings = Settings(settings) @property @@ -42,6 +45,24 @@ def bind_provider( provider = CachedProvider(provider) self._providers[base_cls] = provider + def soft_bind_instance(self, base_cls: Type[InjectType], instance: InjectType): + """Add a static instance as a soft class binding. + + The binding occurs only if a provider for the same type does not already exist. + """ + if not self.get_provider(base_cls): + self.bind_instance(base_cls, instance) + + def soft_bind_provider( + self, base_cls: Type[InjectType], provider: BaseProvider, *, cache: bool = False + ): + """Add a dynamic instance resolver as a soft class binding. + + The binding occurs only if a provider for the same type does not already exist. + """ + if not self.get_provider(base_cls): + self.bind_provider(base_cls, provider, cache=cache) + def clear_binding(self, base_cls: Type[InjectType]): """Remove a previously-added binding.""" if base_cls in self._providers: @@ -54,7 +75,7 @@ def get_provider(self, base_cls: Type[InjectType]): def inject_or( self, base_cls: Type[InjectType], - settings: Mapping[str, object] = None, + settings: Optional[Mapping[str, object]] = None, default: Optional[InjectType] = None, ) -> Optional[InjectType]: """Get the provided instance of a given class identifier or default if not found. @@ -92,7 +113,7 @@ def inject_or( def inject( self, base_cls: Type[InjectType], - settings: Mapping[str, object] = None, + settings: Optional[Mapping[str, object]] = None, ) -> InjectType: """Get the provided instance of a given class identifier. diff --git a/aries_cloudagent/config/tests/test_injection_context.py b/aries_cloudagent/config/tests/test_injection_context.py index e6bd1fd4f0..f68c7f6948 100644 --- a/aries_cloudagent/config/tests/test_injection_context.py +++ b/aries_cloudagent/config/tests/test_injection_context.py @@ -1,7 +1,7 @@ from unittest import IsolatedAsyncioTestCase from ..base import InjectionError -from ..injection_context import InjectionContext, InjectionContextError +from ..injection_context import InjectionContext class TestInjectionContext(IsolatedAsyncioTestCase): @@ -14,39 +14,16 @@ def setUp(self): def test_settings_init(self): """Test settings initialization.""" - assert self.test_instance.scope_name == self.test_instance.ROOT_SCOPE for key in self.test_settings: assert key in self.test_instance.settings assert self.test_instance.settings[key] == self.test_settings[key] - def test_simple_scope(self): - """Test scope entrance and exit.""" - with self.assertRaises(InjectionContextError): - self.test_instance.start_scope(None) - with self.assertRaises(InjectionContextError): - self.test_instance.start_scope(self.test_instance.ROOT_SCOPE) - - injector = self.test_instance.injector_for_scope(self.test_instance.ROOT_SCOPE) - assert injector == self.test_instance.injector - assert self.test_instance.injector_for_scope("no such scope") is None - - context = self.test_instance.start_scope(self.test_scope) - assert context.scope_name == self.test_scope - context.scope_name = "Bob" - assert context.scope_name == "Bob" - - with self.assertRaises(InjectionContextError): - context.start_scope(self.test_instance.ROOT_SCOPE) - assert self.test_instance.scope_name == self.test_instance.ROOT_SCOPE - def test_settings_scope(self): """Test scoped settings.""" upd_settings = {self.test_key: "NEWVAL"} - context = self.test_instance.start_scope(self.test_scope, upd_settings) + context = self.test_instance.start_scope(upd_settings) assert context.settings[self.test_key] == "NEWVAL" assert self.test_instance.settings[self.test_key] == self.test_value - root = context.injector_for_scope(context.ROOT_SCOPE) - assert root.settings[self.test_key] == self.test_value context.settings = upd_settings assert context.settings == upd_settings @@ -64,11 +41,8 @@ async def test_inject_simple(self): async def test_inject_scope(self): """Test a scoped injection.""" - context = self.test_instance.start_scope(self.test_scope) + context = self.test_instance.start_scope() assert context.inject_or(str) is None context.injector.bind_instance(str, self.test_value) assert context.inject(str) is self.test_value assert self.test_instance.inject_or(str) is None - root = context.injector_for_scope(context.ROOT_SCOPE) - assert root.inject_or(str) is None - assert self.test_instance.inject_or(str) is None diff --git a/aries_cloudagent/config/tests/test_injector.py b/aries_cloudagent/config/tests/test_injector.py index 76da5f7992..3b5023307f 100644 --- a/aries_cloudagent/config/tests/test_injector.py +++ b/aries_cloudagent/config/tests/test_injector.py @@ -70,6 +70,39 @@ def test_inject_provider(self): assert mock_provider.settings[self.test_key] == override_settings[self.test_key] assert mock_provider.injector is self.test_instance + def test_inject_soft_provider_bindings(self): + """Test injecting providers with soft binding.""" + provider = MockProvider(self.test_value) + override = MockProvider("Override") + + self.test_instance.soft_bind_provider(str, provider) + assert self.test_instance.inject(str) == self.test_value + + self.test_instance.clear_binding(str) + # Bound by a plugin on startup, for example + self.test_instance.bind_provider(str, override) + + # Bound later in Profile.bind_providerse + self.test_instance.soft_bind_provider(str, provider) + + # We want the plugin value, not the Profile bound value + assert self.test_instance.inject(str) == "Override" + + def test_inject_soft_instance_bindings(self): + """Test injecting providers with soft binding.""" + self.test_instance.soft_bind_instance(str, self.test_value) + assert self.test_instance.inject(str) == self.test_value + + self.test_instance.clear_binding(str) + # Bound by a plugin on startup, for example + self.test_instance.bind_instance(str, "Override") + + # Bound later in Profile.bind_providerse + self.test_instance.soft_bind_instance(str, self.test_value) + + # We want the plugin value, not the Profile bound value + assert self.test_instance.inject(str) == "Override" + def test_bad_provider(self): """Test empty and invalid provider results.""" self.test_instance.bind_provider(str, MockProvider(None)) diff --git a/aries_cloudagent/connections/base_manager.py b/aries_cloudagent/connections/base_manager.py index a35a44bfaf..8d2796eeea 100644 --- a/aries_cloudagent/connections/base_manager.py +++ b/aries_cloudagent/connections/base_manager.py @@ -390,8 +390,6 @@ async def find_did_for_key(self, key: str) -> str: storage: BaseStorage = session.inject(BaseStorage) record = await storage.find_record(self.RECORD_TYPE_DID_KEY, {"key": key}) ret_did = record.tags["did"] - if ret_did.startswith("did:peer:4"): - ret_did = self.long_did_peer_to_short(ret_did) return ret_did async def remove_keys_for_did(self, did: str): @@ -418,9 +416,7 @@ async def resolve_didcomm_services( doc_dict: dict = await resolver.resolve(self._profile, did, service_accept) doc: ResolvedDocument = pydid.deserialize_document(doc_dict, strict=True) except ResolverError as error: - raise BaseConnectionManagerError( - "Failed to resolve DID services" - ) from error + raise BaseConnectionManagerError("Failed to resolve DID services") from error if not doc.service: raise BaseConnectionManagerError( @@ -480,10 +476,7 @@ async def resolve_invitation( return ( endpoint, - [ - self._extract_key_material_in_base58_format(key) - for key in recipient_keys - ], + [self._extract_key_material_in_base58_format(key) for key in recipient_keys], [self._extract_key_material_in_base58_format(key) for key in routing_keys], ) @@ -752,9 +745,7 @@ async def get_connection_targets( async with cache.acquire(cache_key) as entry: if entry.result: self._logger.debug("Connection targets retrieved from cache") - targets = [ - ConnectionTarget.deserialize(row) for row in entry.result - ] + targets = [ConnectionTarget.deserialize(row) for row in entry.result] else: if not connection: async with self._profile.session() as session: @@ -769,9 +760,7 @@ async def get_connection_targets( # Otherwise, a replica that participated early in exchange # may have bad data set in cache. self._logger.debug("Caching connection targets") - await entry.set_result( - [row.serialize() for row in targets], 3600 - ) + await entry.set_result([row.serialize() for row in targets], 3600) else: self._logger.debug( "Not caching connection targets for connection in " @@ -830,12 +819,8 @@ def diddoc_connection_targets( did=doc.did, endpoint=service.endpoint, label=their_label, - recipient_keys=[ - key.value for key in (service.recip_keys or ()) - ], - routing_keys=[ - key.value for key in (service.routing_keys or ()) - ], + recipient_keys=[key.value for key in (service.recip_keys or ())], + routing_keys=[key.value for key in (service.routing_keys or ())], sender_key=sender_verkey, ) ) @@ -872,7 +857,18 @@ async def find_connection( """ connection = None - if their_did: + if their_did and their_did.startswith("did:peer:4"): + # did:peer:4 always recorded as long + long = their_did + short = self.long_did_peer_to_short(their_did) + try: + async with self._profile.session() as session: + connection = await ConnRecord.retrieve_by_did_peer_4( + session, long, short, my_did + ) + except StorageNotFoundError: + pass + elif their_did: try: async with self._profile.session() as session: connection = await ConnRecord.retrieve_by_did( diff --git a/aries_cloudagent/connections/models/conn_record.py b/aries_cloudagent/connections/models/conn_record.py index 116a91568f..9817c18025 100644 --- a/aries_cloudagent/connections/models/conn_record.py +++ b/aries_cloudagent/connections/models/conn_record.py @@ -215,7 +215,9 @@ def __init__( self.their_role = ( ConnRecord.Role.get(their_role).rfc160 if isinstance(their_role, str) - else None if their_role is None else their_role.rfc160 + else None + if their_role is None + else their_role.rfc160 ) self.invitation_key = invitation_key self.invitation_msg_id = invitation_msg_id @@ -290,6 +292,44 @@ async def retrieve_by_did( return await cls.retrieve_by_tag_filter(session, tag_filter, post_filter) + @classmethod + async def retrieve_by_did_peer_4( + cls, + session: ProfileSession, + their_did_long: Optional[str] = None, + their_did_short: Optional[str] = None, + my_did: Optional[str] = None, + their_role: Optional[str] = None, + ) -> "ConnRecord": + """Retrieve a connection record by target DID. + + Args: + session: The active profile session + their_did_long: The target DID to filter by, in long form + their_did_short: The target DID to filter by, in short form + my_did: One of our DIDs to filter by + my_role: Filter connections by their role + their_role: Filter connections by their role + """ + tag_filter = {} + if their_did_long and their_did_short: + tag_filter["$or"] = [ + {"their_did": their_did_long}, + {"their_did": their_did_short}, + ] + elif their_did_short: + tag_filter["their_did"] = their_did_short + elif their_did_long: + tag_filter["their_did"] = their_did_long + if my_did: + tag_filter["my_did"] = my_did + + post_filter = {} + if their_role: + post_filter["their_role"] = cls.Role.get(their_role).rfc160 + + return await cls.retrieve_by_tag_filter(session, tag_filter, post_filter) + @classmethod async def retrieve_by_invitation_key( cls, session: ProfileSession, invitation_key: str, their_role: str = None @@ -371,9 +411,7 @@ async def retrieve_by_request_id( return await cls.retrieve_by_tag_filter(session, tag_filter) @classmethod - async def retrieve_by_alias( - cls, session: ProfileSession, alias: str - ) -> "ConnRecord": + async def retrieve_by_alias(cls, session: ProfileSession, alias: str) -> "ConnRecord": """Retrieve a connection record from an alias. Args: diff --git a/aries_cloudagent/connections/models/tests/test_conn_record.py b/aries_cloudagent/connections/models/tests/test_conn_record.py index 125330c70a..b396a9b111 100644 --- a/aries_cloudagent/connections/models/tests/test_conn_record.py +++ b/aries_cloudagent/connections/models/tests/test_conn_record.py @@ -25,6 +25,15 @@ def setUp(self): self.test_target_did = "GbuDUYXaUZRfHD2jeDuQuP" self.test_target_verkey = "9WCgWKUaAJj3VWxxtzvvMQN3AoFxoBtBDo9ntwJnVVCC" + self.test_did_peer_4_a = "did:peer:4zQmV3Hf1TT4Xn73MBVf2NAWdMwrzUabpEvwtV3RoZc17Vxr:z2pfttj3xn6tJ7wpHV9ZSwpQVMNtHC7EtM36r1mC5fDpZ25882Yitk21QPbqzuefKPrbFsexWmQtE78vYWweckXtKeu5BhuFDvCjMUf8SC5z7cMPvp8SCdcbqWnHxygjBH9zAAs9myGRnZYXuAkq6CfBdn6ZiNmdRf65TdVfE3cYfS4jNzVZDs1abwytn4jdFJ2fwVegPB3vLY8XxeUEx12a4rtjkqMhs6zBQbJvc4PVUM9rvMbPM2QeXDy7ovkkHaKLUbNUxjQrcQeiR8MTLe1iaVtUv6RpBf4z7ioqfa4VDRmAZT7isVM3NvENUceeUfDZoFbM8PZqGkCbFvfoKiK3SrmTsvPtpXaBAfR4z7w18cFjsvvLBNMZbPnARn4oZijCkYwgaNmAUthgDP4XBFetdUo8728w25FUwTWjAPc1BdSSWPWMRKwCqyAP1Q1hM8dU6otT27MQaQ1rozKncn3U48CXEi2Ef26EDBrSozEWR273ancFojNXBbZVghZG5b6xdypjQir9PgTF94dsygtu47hNxQweVKLUM1p9umqHLhjvLhpS1aGQkGZNnKUHjLDHdToigo15F7TAf8RfMaducHBThFzEp9TUJmiZFTUYQ1uaBgSPMSaWnvTfUoFmLoGbdrWj1vVEsRrARq37u1SJGLqBx7FM2SUd8nxPsChP5jY8ka8F8r7j8qZLHZqvUXbynPUsViwwdFFk8SCsBWfiQgvq7sRiTdLnYv3H5DSwA1uW2GNYXGgkT9aJza4Sk1gvag5iAbQZgxbU594enjVSTjiWsFw2oYQ75JJwiSEgsP2rhpGsNhXxfECNLUtb7FQbDQPtUvLHCJATf7QXJEoWjpfAywmB6NyQcXfskco6FKJNNHeZBnST6U1meH98Ku66vha1k8hAc72iBhXQBnWUjaGRyzELsh2LkBH2UNwW9TuFhxz3SKtL5pGShVQ5XGQhmdrkWP68d6h7c1JqsfogcDBnmWS4VSbJwgtsPNTSsTHGX8hpGvg" + self.test_did_peer_4_short_a = ( + "did:peer:4zQmV3Hf1TT4Xn73MBVf2NAWdMwrzUabpEvwtV3RoZc17Vxr" + ) + self.test_did_peer_4_b = "did:peer:4zQmQ4dEtoGcivpiH6gtWwhWJY2ENVWuZifb62uzR76HGPPw:z7p4QX8zEXt2sMjv1Tqq8Lv8Nx8oGo2uRczBe21vyfMhQzsWDnwGmjriYfUX75WDq622czcdHjWGhh2VTbzKhLXUjY8Ma7g64dKAVcy8SaxN5QVdjwpXgD7htKCgCjah8jHEzyBZFrtdfTHiVXfSUz1BiURQf1Z3NfxW5cWYsvDJVvQzVmdHb8ekzCnvxCqL2UV1v9SBb1DsU66N3PCp9HVpSrqUJQyFU2Ddc8bb6u8SJfBU1nyCkNMgfA1zAyKnSBrzZWyyNzAm9oBV36qjC1Qjfcpq4FBnGr7foh5sLXppBwu2ES8U2nxdGrQzAbN47DKBoKJqPVxNh5tTuBdYjDGt7PcvZQjHQGNXXuhJctM5besZci2saGefCHzoZ87vSsFuKq6oXEsW512eadiNZWjHSdG9J4ToMEMK9WT66vGGLFdZszB3xhdFqEDnAMcpnoFUL5WN243aH6492jPC2Zjdi1BvHC1J8bUuvyihAKXF3WmFz7gJWmh6MrTEWNqb17K6tqbyXjFmfnS2RbAi8xBFj3sSsXkSs6TRTXAZD9DenYaQq4RMa2Kqh6VKGvkXAjVHKcPh9Ncpt6rU9ZYttNHbDJFgahwB8KisVBK8FBpG" + self.test_did_peer_4_short_b = ( + "did:peer:4zQmQ4dEtoGcivpiH6gtWwhWJY2ENVWuZifb62uzR76HGPPw" + ) + self.test_conn_record = ConnRecord( my_did=self.test_did, their_did=self.test_target_did, @@ -39,9 +48,7 @@ async def test_get_enums(self): assert ConnRecord.Role.get("Larry") is None assert ConnRecord.State.get("a suffusion of yellow") is None - assert ( - ConnRecord.Role.get(ConnRecord.Role.REQUESTER) is ConnRecord.Role.REQUESTER - ) + assert ConnRecord.Role.get(ConnRecord.Role.REQUESTER) is ConnRecord.Role.REQUESTER assert ( ConnRecord.State.get(ConnRecord.State.RESPONSE) is ConnRecord.State.RESPONSE @@ -133,6 +140,71 @@ async def test_retrieve_by_did(self): ) assert result == record + async def test_retrieve_by_did_peer_4_by_long(self): + record = ConnRecord( + my_did=self.test_did, + their_did=self.test_did_peer_4_a, + their_role=ConnRecord.Role.RESPONDER.rfc23, + state=ConnRecord.State.COMPLETED.rfc23, + ) + rec_id = await record.save(self.session) + result = await ConnRecord.retrieve_by_did_peer_4( + session=self.session, + my_did=self.test_did, + their_did_long=self.test_did_peer_4_a, + their_role=ConnRecord.Role.RESPONDER.rfc160, + ) + assert result == record + + async def test_retrieve_by_did_peer_4_by_short(self): + record = ConnRecord( + my_did=self.test_did, + their_did=self.test_did_peer_4_short_b, + their_role=ConnRecord.Role.RESPONDER.rfc23, + state=ConnRecord.State.COMPLETED.rfc23, + ) + await record.save(self.session) + result = await ConnRecord.retrieve_by_did_peer_4( + session=self.session, + my_did=self.test_did, + their_did_short=self.test_did_peer_4_short_b, + their_role=ConnRecord.Role.RESPONDER.rfc160, + ) + assert result == record + + async def test_retrieve_by_did_peer_4_by_either(self): + record_short = ConnRecord( + my_did=self.test_did, + their_did=self.test_did_peer_4_short_a, + their_role=ConnRecord.Role.RESPONDER.rfc23, + state=ConnRecord.State.COMPLETED.rfc23, + ) + await record_short.save(self.session) + record_long = ConnRecord( + my_did=self.test_did, + their_did=self.test_did_peer_4_b, + their_role=ConnRecord.Role.RESPONDER.rfc23, + state=ConnRecord.State.COMPLETED.rfc23, + ) + await record_long.save(self.session) + + result = await ConnRecord.retrieve_by_did_peer_4( + session=self.session, + my_did=self.test_did, + their_did_short=self.test_did_peer_4_short_a, + their_did_long=self.test_did_peer_4_a, + their_role=ConnRecord.Role.RESPONDER.rfc160, + ) + assert result == record_short + result = await ConnRecord.retrieve_by_did_peer_4( + session=self.session, + my_did=self.test_did, + their_did_short=self.test_did_peer_4_short_b, + their_did_long=self.test_did_peer_4_b, + their_role=ConnRecord.Role.RESPONDER.rfc160, + ) + assert result == record_long + async def test_from_storage_with_initiator_old(self): record = ConnRecord(my_did=self.test_did, state=ConnRecord.State.COMPLETED) ser = record.serialize() @@ -300,9 +372,7 @@ async def test_attach_retrieve_request(self): connection_id = await record.save(self.session) req = ConnectionRequest( - connection=ConnectionDetail( - did=self.test_did, did_doc=DIDDoc(self.test_did) - ), + connection=ConnectionDetail(did=self.test_did, did_doc=DIDDoc(self.test_did)), label="abc123", ) await record.attach_request(self.session, req) @@ -317,9 +387,7 @@ async def test_attach_request_abstain_on_alien_deco(self): connection_id = await record.save(self.session) req = ConnectionRequest( - connection=ConnectionDetail( - did=self.test_did, did_doc=DIDDoc(self.test_did) - ), + connection=ConnectionDetail(did=self.test_did, did_doc=DIDDoc(self.test_did)), label="abc123", ) ser = req.serialize() diff --git a/aries_cloudagent/core/profile.py b/aries_cloudagent/core/profile.py index 7b2b2f50da..c5bd599acb 100644 --- a/aries_cloudagent/core/profile.py +++ b/aries_cloudagent/core/profile.py @@ -1,8 +1,9 @@ """Classes for managing profile information within a request context.""" -import logging from abc import ABC, abstractmethod +import logging from typing import Any, Mapping, Optional, Type +from weakref import ref from ..config.base import InjectionError from ..config.injection_context import InjectionContext @@ -30,10 +31,13 @@ def __init__( created: bool = False, ): """Initialize a base profile.""" - self._context = context or InjectionContext() self._created = created self._name = name or Profile.DEFAULT_NAME + context = context or InjectionContext() + self._context = context.start_scope() + self._context.injector.bind_instance(Profile, ref(self)) + @property def backend(self) -> str: """Accessor for the backend implementation name.""" @@ -159,10 +163,12 @@ def __init__( self._active = False self._awaited = False self._entered = 0 - self._context = (context or profile.context).start_scope("session", settings) + self._context = (context or profile.context).start_scope(settings) self._profile = profile self._events = [] + self._context.injector.bind_instance(ProfileSession, ref(self)) + async def _setup(self): """Create the underlying session or transaction.""" diff --git a/aries_cloudagent/messaging/request_context.py b/aries_cloudagent/messaging/request_context.py index c96fb2e334..171085c4bc 100644 --- a/aries_cloudagent/messaging/request_context.py +++ b/aries_cloudagent/messaging/request_context.py @@ -26,13 +26,13 @@ def __init__( self, profile: Profile, *, - context: InjectionContext = None, - settings: Mapping[str, object] = None + context: Optional[InjectionContext] = None, + settings: Optional[Mapping[str, object]] = None ): """Initialize an instance of RequestContext.""" self._connection_ready = False self._connection_record = None - self._context = (context or profile.context).start_scope("request", settings) + self._context = (context or profile.context).start_scope(settings) self._message = None self._message_receipt = None self._profile = profile diff --git a/aries_cloudagent/protocols/actionmenu/v1_0/handlers/menu_handler.py b/aries_cloudagent/protocols/actionmenu/v1_0/handlers/menu_handler.py index 30ddd66fa3..b7cb04e35d 100644 --- a/aries_cloudagent/protocols/actionmenu/v1_0/handlers/menu_handler.py +++ b/aries_cloudagent/protocols/actionmenu/v1_0/handlers/menu_handler.py @@ -3,9 +3,9 @@ from .....messaging.base_handler import ( BaseHandler, BaseResponder, + HandlerException, RequestContext, ) - from ..messages.menu import Menu from ..util import save_connection_menu @@ -23,6 +23,9 @@ async def handle(self, context: RequestContext, responder: BaseResponder): self._logger.debug("MenuHandler called with context %s", context) assert isinstance(context.message, Menu) + if not context.connection_ready: + raise HandlerException("No connection established") + self._logger.info("Received action menu: %s", context.message) await save_connection_menu( diff --git a/aries_cloudagent/protocols/actionmenu/v1_0/handlers/menu_request_handler.py b/aries_cloudagent/protocols/actionmenu/v1_0/handlers/menu_request_handler.py index 905e039ce8..4aef949495 100644 --- a/aries_cloudagent/protocols/actionmenu/v1_0/handlers/menu_request_handler.py +++ b/aries_cloudagent/protocols/actionmenu/v1_0/handlers/menu_request_handler.py @@ -3,9 +3,9 @@ from .....messaging.base_handler import ( BaseHandler, BaseResponder, + HandlerException, RequestContext, ) - from ..base_service import BaseMenuService from ..messages.menu_request import MenuRequest @@ -23,6 +23,9 @@ async def handle(self, context: RequestContext, responder: BaseResponder): self._logger.debug("MenuRequestHandler called with context %s", context) assert isinstance(context.message, MenuRequest) + if not context.connection_ready: + raise HandlerException("No connection established") + self._logger.info("Received action menu request") service: BaseMenuService = context.inject_or(BaseMenuService) diff --git a/aries_cloudagent/protocols/actionmenu/v1_0/handlers/perform_handler.py b/aries_cloudagent/protocols/actionmenu/v1_0/handlers/perform_handler.py index 5e38bc90f3..4f6a5b1387 100644 --- a/aries_cloudagent/protocols/actionmenu/v1_0/handlers/perform_handler.py +++ b/aries_cloudagent/protocols/actionmenu/v1_0/handlers/perform_handler.py @@ -3,9 +3,9 @@ from .....messaging.base_handler import ( BaseHandler, BaseResponder, + HandlerException, RequestContext, ) - from ..base_service import BaseMenuService from ..messages.perform import Perform @@ -23,6 +23,9 @@ async def handle(self, context: RequestContext, responder: BaseResponder): self._logger.debug("PerformHandler called with context %s", context) assert isinstance(context.message, Perform) + if not context.connection_ready: + raise HandlerException("No connection established") + self._logger.info("Received action menu perform request") service: BaseMenuService = context.inject_or(BaseMenuService) diff --git a/aries_cloudagent/protocols/actionmenu/v1_0/handlers/tests/test_menu_handler.py b/aries_cloudagent/protocols/actionmenu/v1_0/handlers/tests/test_menu_handler.py index 4034bf5ea0..392bc3f8cd 100644 --- a/aries_cloudagent/protocols/actionmenu/v1_0/handlers/tests/test_menu_handler.py +++ b/aries_cloudagent/protocols/actionmenu/v1_0/handlers/tests/test_menu_handler.py @@ -1,9 +1,9 @@ from unittest import IsolatedAsyncioTestCase + from aries_cloudagent.tests import mock from ......messaging.request_context import RequestContext from ......messaging.responder import MockResponder - from .. import menu_handler as handler @@ -12,6 +12,7 @@ async def test_called(self): request_context = RequestContext.test_context() request_context.connection_record = mock.MagicMock() request_context.connection_record.connection_id = "dummy" + request_context.connection_ready = True handler.save_connection_menu = mock.CoroutineMock() responder = MockResponder() diff --git a/aries_cloudagent/protocols/actionmenu/v1_0/handlers/tests/test_menu_request_handler.py b/aries_cloudagent/protocols/actionmenu/v1_0/handlers/tests/test_menu_request_handler.py index 30d97e65f4..63214fe409 100644 --- a/aries_cloudagent/protocols/actionmenu/v1_0/handlers/tests/test_menu_request_handler.py +++ b/aries_cloudagent/protocols/actionmenu/v1_0/handlers/tests/test_menu_request_handler.py @@ -1,9 +1,9 @@ from unittest import IsolatedAsyncioTestCase + from aries_cloudagent.tests import mock from ......messaging.request_context import RequestContext from ......messaging.responder import MockResponder - from .. import menu_request_handler as handler @@ -18,6 +18,7 @@ async def test_called(self): self.context.connection_record = mock.MagicMock() self.context.connection_record.connection_id = "dummy" + self.context.connection_ready = True responder = MockResponder() self.context.message = handler.MenuRequest() @@ -39,6 +40,7 @@ async def test_called_no_active_menu(self): self.context.connection_record = mock.MagicMock() self.context.connection_record.connection_id = "dummy" + self.context.connection_ready = True responder = MockResponder() self.context.message = handler.MenuRequest() diff --git a/aries_cloudagent/protocols/actionmenu/v1_0/handlers/tests/test_perform_handler.py b/aries_cloudagent/protocols/actionmenu/v1_0/handlers/tests/test_perform_handler.py index a8decf96ab..7af6672ee8 100644 --- a/aries_cloudagent/protocols/actionmenu/v1_0/handlers/tests/test_perform_handler.py +++ b/aries_cloudagent/protocols/actionmenu/v1_0/handlers/tests/test_perform_handler.py @@ -1,9 +1,9 @@ from unittest import IsolatedAsyncioTestCase + from aries_cloudagent.tests import mock from ......messaging.request_context import RequestContext from ......messaging.responder import MockResponder - from .. import perform_handler as handler @@ -18,12 +18,11 @@ async def test_called(self): self.context.connection_record = mock.MagicMock() self.context.connection_record.connection_id = "dummy" + self.context.connection_ready = True responder = MockResponder() self.context.message = handler.Perform() - self.menu_service.perform_menu_action = mock.CoroutineMock( - return_value="perform" - ) + self.menu_service.perform_menu_action = mock.CoroutineMock(return_value="perform") handler_inst = handler.PerformHandler() await handler_inst.handle(self.context, responder) @@ -41,6 +40,7 @@ async def test_called_no_active_menu(self): self.context.connection_record = mock.MagicMock() self.context.connection_record.connection_id = "dummy" + self.context.connection_ready = True responder = MockResponder() self.context.message = handler.Perform() diff --git a/aries_cloudagent/protocols/basicmessage/v1_0/handlers/basicmessage_handler.py b/aries_cloudagent/protocols/basicmessage/v1_0/handlers/basicmessage_handler.py index 93bd91760f..286e62a63e 100644 --- a/aries_cloudagent/protocols/basicmessage/v1_0/handlers/basicmessage_handler.py +++ b/aries_cloudagent/protocols/basicmessage/v1_0/handlers/basicmessage_handler.py @@ -3,9 +3,9 @@ from .....messaging.base_handler import ( BaseHandler, BaseResponder, + HandlerException, RequestContext, ) - from ..messages.basicmessage import BasicMessage @@ -22,6 +22,9 @@ async def handle(self, context: RequestContext, responder: BaseResponder): self._logger.debug("BasicMessageHandler called with context %s", context) assert isinstance(context.message, BasicMessage) + if not context.connection_ready: + raise HandlerException("No connection established") + self._logger.info("Received basic message: %s", context.message.content) body = context.message.content diff --git a/aries_cloudagent/protocols/did_rotate/v1_0/handlers/ack_handler.py b/aries_cloudagent/protocols/did_rotate/v1_0/handlers/ack_handler.py index 5d0942a5e9..4b8746e2bb 100644 --- a/aries_cloudagent/protocols/did_rotate/v1_0/handlers/ack_handler.py +++ b/aries_cloudagent/protocols/did_rotate/v1_0/handlers/ack_handler.py @@ -1,6 +1,6 @@ """Rotate ack handler.""" -from .....messaging.base_handler import BaseHandler +from .....messaging.base_handler import BaseHandler, HandlerException from .....messaging.request_context import RequestContext from .....messaging.responder import BaseResponder from ..manager import DIDRotateManager @@ -20,6 +20,9 @@ async def handle(self, context: RequestContext, responder: BaseResponder): self._logger.debug("RotateAckHandler called with context %s", context) assert isinstance(context.message, RotateAck) + if not context.connection_ready: + raise HandlerException("No connection established") + connection_record = context.connection_record ack = context.message diff --git a/aries_cloudagent/protocols/did_rotate/v1_0/handlers/hangup_handler.py b/aries_cloudagent/protocols/did_rotate/v1_0/handlers/hangup_handler.py index 1fc7dff655..8e1cb7b102 100644 --- a/aries_cloudagent/protocols/did_rotate/v1_0/handlers/hangup_handler.py +++ b/aries_cloudagent/protocols/did_rotate/v1_0/handlers/hangup_handler.py @@ -1,6 +1,6 @@ """Rotate hangup handler.""" -from .....messaging.base_handler import BaseHandler +from .....messaging.base_handler import BaseHandler, HandlerException from .....messaging.request_context import RequestContext from .....messaging.responder import BaseResponder from ..manager import DIDRotateManager @@ -20,6 +20,9 @@ async def handle(self, context: RequestContext, responder: BaseResponder): self._logger.debug("HangupHandler called with context %s", context) assert isinstance(context.message, Hangup) + if not context.connection_ready: + raise HandlerException("No connection established") + connection_record = context.connection_record profile = context.profile diff --git a/aries_cloudagent/protocols/did_rotate/v1_0/handlers/problem_report_handler.py b/aries_cloudagent/protocols/did_rotate/v1_0/handlers/problem_report_handler.py index 4d0bbc8a75..199952526d 100644 --- a/aries_cloudagent/protocols/did_rotate/v1_0/handlers/problem_report_handler.py +++ b/aries_cloudagent/protocols/did_rotate/v1_0/handlers/problem_report_handler.py @@ -1,6 +1,6 @@ """Rotate problem report handler.""" -from .....messaging.base_handler import BaseHandler +from .....messaging.base_handler import BaseHandler, HandlerException from .....messaging.request_context import RequestContext from .....messaging.responder import BaseResponder from ..manager import DIDRotateManager @@ -20,6 +20,9 @@ async def handle(self, context: RequestContext, responder: BaseResponder): self._logger.debug("ProblemReportHandler called with context %s", context) assert isinstance(context.message, RotateProblemReport) + if not context.connection_ready: + raise HandlerException("No connection established") + problem_report = context.message profile = context.profile diff --git a/aries_cloudagent/protocols/did_rotate/v1_0/handlers/rotate_handler.py b/aries_cloudagent/protocols/did_rotate/v1_0/handlers/rotate_handler.py index a63b848fd2..e5fd8f0f83 100644 --- a/aries_cloudagent/protocols/did_rotate/v1_0/handlers/rotate_handler.py +++ b/aries_cloudagent/protocols/did_rotate/v1_0/handlers/rotate_handler.py @@ -1,6 +1,6 @@ """Rotate handler.""" -from .....messaging.base_handler import BaseHandler +from .....messaging.base_handler import BaseHandler, HandlerException from .....messaging.request_context import RequestContext from .....messaging.responder import BaseResponder from ..manager import DIDRotateManager @@ -20,6 +20,9 @@ async def handle(self, context: RequestContext, responder: BaseResponder): self._logger.debug("RotateHandler called with context %s", context) assert isinstance(context.message, Rotate) + if not context.connection_ready: + raise HandlerException("No connection established") + connection_record = context.connection_record rotate = context.message diff --git a/aries_cloudagent/protocols/did_rotate/v1_0/handlers/tests/test_ack_handler.py b/aries_cloudagent/protocols/did_rotate/v1_0/handlers/tests/test_ack_handler.py index 4fc5806a35..2daf059db8 100644 --- a/aries_cloudagent/protocols/did_rotate/v1_0/handlers/tests/test_ack_handler.py +++ b/aries_cloudagent/protocols/did_rotate/v1_0/handlers/tests/test_ack_handler.py @@ -23,6 +23,7 @@ async def test_handle(self, MockDIDRotateManager, request_context): request_context.message = RotateAck() request_context.connection_record = mock.MagicMock() + request_context.connection_ready = True handler = test_module.RotateAckHandler() responder = MockResponder() diff --git a/aries_cloudagent/protocols/did_rotate/v1_0/handlers/tests/test_hangup_handler.py b/aries_cloudagent/protocols/did_rotate/v1_0/handlers/tests/test_hangup_handler.py index 764a5ee5b8..c945f5a427 100644 --- a/aries_cloudagent/protocols/did_rotate/v1_0/handlers/tests/test_hangup_handler.py +++ b/aries_cloudagent/protocols/did_rotate/v1_0/handlers/tests/test_hangup_handler.py @@ -23,6 +23,7 @@ async def test_handle(self, MockDIDRotateManager, request_context): request_context.message = Hangup() request_context.connection_record = mock.MagicMock() + request_context.connection_ready = True handler = test_module.HangupHandler() responder = MockResponder() diff --git a/aries_cloudagent/protocols/did_rotate/v1_0/handlers/tests/test_problem_report_handler.py b/aries_cloudagent/protocols/did_rotate/v1_0/handlers/tests/test_problem_report_handler.py index e461d6c8dc..06ecfe5e3e 100644 --- a/aries_cloudagent/protocols/did_rotate/v1_0/handlers/tests/test_problem_report_handler.py +++ b/aries_cloudagent/protocols/did_rotate/v1_0/handlers/tests/test_problem_report_handler.py @@ -27,6 +27,7 @@ async def test_handle(self, MockDIDRotateManager, request_context): request_context.message = RotateProblemReport() request_context.connection_record = mock.MagicMock() + request_context.connection_ready = True handler = test_module.ProblemReportHandler() responder = MockResponder() diff --git a/aries_cloudagent/protocols/did_rotate/v1_0/handlers/tests/test_rotate_handler.py b/aries_cloudagent/protocols/did_rotate/v1_0/handlers/tests/test_rotate_handler.py index 7eb9b70fdc..1ef31a30b1 100644 --- a/aries_cloudagent/protocols/did_rotate/v1_0/handlers/tests/test_rotate_handler.py +++ b/aries_cloudagent/protocols/did_rotate/v1_0/handlers/tests/test_rotate_handler.py @@ -28,6 +28,7 @@ async def test_handle(self, MockDIDRotateManager, request_context): request_context.message = Rotate(**test_valid_rotate_request) request_context.connection_record = mock.MagicMock() + request_context.connection_ready = True handler = test_module.RotateHandler() responder = MockResponder() diff --git a/aries_cloudagent/protocols/didexchange/v1_0/manager.py b/aries_cloudagent/protocols/didexchange/v1_0/manager.py index 2b5f9b984e..703a37ba03 100644 --- a/aries_cloudagent/protocols/didexchange/v1_0/manager.py +++ b/aries_cloudagent/protocols/didexchange/v1_0/manager.py @@ -798,6 +798,8 @@ async def create_response( use_did_method = "did:peer:2" elif conn_rec.their_did and conn_rec.their_did.startswith("did:peer:4"): use_did_method = "did:peer:4" + elif conn_rec.their_did and conn_rec.their_did.startswith("did:peer:1"): + use_did_method = "did:peer:4" else: use_did_method = None diff --git a/aries_cloudagent/protocols/didexchange/v1_0/tests/test_manager.py b/aries_cloudagent/protocols/didexchange/v1_0/tests/test_manager.py index f3c5d6d969..f3e8fc806b 100644 --- a/aries_cloudagent/protocols/didexchange/v1_0/tests/test_manager.py +++ b/aries_cloudagent/protocols/didexchange/v1_0/tests/test_manager.py @@ -51,6 +51,7 @@ class TestConfig: test_target_did = "GbuDUYXaUZRfHD2jeDuQuP" test_target_verkey = "9WCgWKUaAJj3VWxxtzvvMQN3AoFxoBtBDo9ntwJnVVCC" + test_did_peer_1 = "did:peer:1zQmNa1NAgFNxoPu5XN7NmUfHk2mF6MnnysiNVDd7X72oPvm" test_did_peer_2 = "did:peer:2.Vz6MkeobNdKHDnMXhob5GPWmpEyNx3r9j6gqiKYJQ9J2wEPvx.SeyJpZCI6IiNkaWRjb21tLTAiLCJ0IjoiZGlkLWNvbW11bmljYXRpb24iLCJwcmlvcml0eSI6MCwicmVjaXBpZW50S2V5cyI6WyIja2V5LTEiXSwiciI6W10sInMiOiJodHRwOi8vaG9zdC5kb2NrZXIuaW50ZXJuYWw6OTA3MCJ9" test_did_peer_4 = "did:peer:4zQmd8CpeFPci817KDsbSAKWcXAE2mjvCQSasRewvbSF54Bd:z2M1k7h4psgp4CmJcnQn2Ljp7Pz7ktsd7oBhMU3dWY5s4fhFNj17qcRTQ427C7QHNT6cQ7T3XfRh35Q2GhaNFZmWHVFq4vL7F8nm36PA9Y96DvdrUiRUaiCuXnBFrn1o7mxFZAx14JL4t8vUWpuDPwQuddVo1T8myRiVH7wdxuoYbsva5x6idEpCQydJdFjiHGCpNc2UtjzPQ8awSXkctGCnBmgkhrj5gto3D4i3EREXYq4Z8r2cWGBr2UzbSmnxW2BuYddFo9Yfm6mKjtJyLpF74ytqrF5xtf84MnGFg1hMBmh1xVx1JwjZ2BeMJs7mNS8DTZhKC7KH38EgqDtUZzfjhpjmmUfkXg2KFEA3EGbbVm1DPqQXayPYKAsYPS9AyKkcQ3fzWafLPP93UfNhtUPL8JW5pMcSV3P8v6j3vPXqnnGknNyBprD6YGUVtgLiAqDBDUF3LSxFQJCVYYtghMTv8WuSw9h1a1SRFrDQLGHE4UrkgoRvwaGWr64aM87T1eVGkP5Dt4L1AbboeK2ceLArPScrdYGTpi3BpTkLwZCdjdiFSfTy9okL1YNRARqUf2wm8DvkVGUU7u5nQA3ZMaXWJAewk6k1YUxKd7LvofGUK4YEDtoxN5vb6r1Q2godrGqaPkjfL3RoYPpDYymf9XhcgG8Kx3DZaA6cyTs24t45KxYAfeCw4wqUpCH9HbpD78TbEUr9PPAsJgXBvBj2VVsxnr7FKbK4KykGcg1W8M1JPz21Z4Y72LWgGQCmixovrkHktcTX1uNHjAvKBqVD5C7XmVfHgXCHj7djCh3vzLNuVLtEED8J1hhqsB1oCBGiuh3xXr7fZ9wUjJCQ1HYHqxLJKdYKtoCiPmgKM7etVftXkmTFETZmpM19aRyih3bao76LdpQtbw636r7a3qt8v4WfxsXJetSL8c7t24SqQBcAY89FBsbEnFNrQCMK3JEseKHVaU388ctvRD45uQfe5GndFxthj4iSDomk4uRFd1uRbywoP1tRuabHTDX42UxPjz" @@ -99,7 +100,7 @@ async def asyncSetUp(self): "debug.auto_accept_invites": True, "debug.auto_accept_requests": True, "multitenant.enabled": True, - "wallet.id": True, + "wallet.id": "test-wallet-id", }, bind={ BaseResponder: self.responder, @@ -1702,6 +1703,36 @@ async def test_create_response_inkind_peer_did_4(self): mock_create_did_peer_4.assert_called_once() assert response.did.startswith("did:peer:4") + async def test_create_response_peer_1_gets_peer_4(self): + # created did:peer:4 when receiving a did:peer:4, even if setting is False + conn_rec = ConnRecord( + connection_id="dummy", + their_did=TestConfig.test_did_peer_1, + state=ConnRecord.State.REQUEST.rfc23, + ) + + self.profile.context.update_settings({"emit_did_peer_4": False}) + + with mock.patch.object( + self.manager, "create_did_peer_4", mock.CoroutineMock() + ) as mock_create_did_peer_4, mock.patch.object( + test_module.ConnRecord, "retrieve_request", mock.CoroutineMock() + ) as mock_retrieve_req, mock.patch.object( + conn_rec, "save", mock.CoroutineMock() + ) as mock_save: + mock_create_did_peer_4.return_value = DIDInfo( + TestConfig.test_did_peer_4, + TestConfig.test_verkey, + None, + method=PEER4, + key_type=ED25519, + ) + response = await self.manager.create_response( + conn_rec, "http://10.20.30.40:5060/" + ) + mock_create_did_peer_4.assert_called_once() + assert response.did.startswith("did:peer:4") + async def test_create_response_bad_state(self): with self.assertRaises(DIDXManagerError): await self.manager.create_response( diff --git a/aries_cloudagent/protocols/discovery/v1_0/handlers/disclose_handler.py b/aries_cloudagent/protocols/discovery/v1_0/handlers/disclose_handler.py index 7c5a313fc4..9fe504873b 100644 --- a/aries_cloudagent/protocols/discovery/v1_0/handlers/disclose_handler.py +++ b/aries_cloudagent/protocols/discovery/v1_0/handlers/disclose_handler.py @@ -3,10 +3,9 @@ from .....messaging.base_handler import ( BaseHandler, BaseResponder, - RequestContext, HandlerException, + RequestContext, ) - from ..manager import V10DiscoveryMgr from ..messages.disclose import Disclose @@ -18,10 +17,12 @@ async def handle(self, context: RequestContext, responder: BaseResponder): """Message handler implementation.""" self._logger.debug("DiscloseHandler called with context %s", context) assert isinstance(context.message, Disclose) + if not context.connection_ready: raise HandlerException( "Received disclosures message from inactive connection" ) + profile = context.profile mgr = V10DiscoveryMgr(profile) await mgr.receive_disclose( diff --git a/aries_cloudagent/protocols/discovery/v1_0/handlers/query_handler.py b/aries_cloudagent/protocols/discovery/v1_0/handlers/query_handler.py index 0336b01351..c0276d5272 100644 --- a/aries_cloudagent/protocols/discovery/v1_0/handlers/query_handler.py +++ b/aries_cloudagent/protocols/discovery/v1_0/handlers/query_handler.py @@ -3,9 +3,9 @@ from .....messaging.base_handler import ( BaseHandler, BaseResponder, + HandlerException, RequestContext, ) - from ..manager import V10DiscoveryMgr from ..messages.query import Query @@ -17,6 +17,10 @@ async def handle(self, context: RequestContext, responder: BaseResponder): """Message handler implementation.""" self._logger.debug("QueryHandler called with context %s", context) assert isinstance(context.message, Query) + + if not context.connection_ready: + raise HandlerException("No connection established") + profile = context.profile mgr = V10DiscoveryMgr(profile) reply = await mgr.receive_query(context.message) diff --git a/aries_cloudagent/protocols/discovery/v1_0/handlers/tests/test_query_handler.py b/aries_cloudagent/protocols/discovery/v1_0/handlers/tests/test_query_handler.py index 16d5b7345a..14854eb0b7 100644 --- a/aries_cloudagent/protocols/discovery/v1_0/handlers/tests/test_query_handler.py +++ b/aries_cloudagent/protocols/discovery/v1_0/handlers/tests/test_query_handler.py @@ -5,7 +5,6 @@ from ......core.protocol_registry import ProtocolRegistry from ......messaging.request_context import RequestContext from ......messaging.responder import MockResponder - from ...handlers.query_handler import QueryHandler from ...messages.disclose import Disclose from ...messages.query import Query @@ -30,6 +29,7 @@ async def test_query_all(self, request_context): query_msg = Query(query="*") query_msg.assign_thread_id("test123") request_context.message = query_msg + request_context.connection_ready = True handler = QueryHandler() responder = MockResponder() await handler.handle(request_context, responder) @@ -50,6 +50,7 @@ async def test_query_all_disclose_list_settings(self, request_context): query_msg = Query(query="*") query_msg.assign_thread_id("test123") request_context.message = query_msg + request_context.connection_ready = True handler = QueryHandler() responder = MockResponder() await handler.handle(request_context, responder) @@ -65,6 +66,7 @@ async def test_receive_query_process_disclosed(self, request_context): query_msg = Query(query="*") query_msg.assign_thread_id("test123") request_context.message = query_msg + request_context.connection_ready = True handler = QueryHandler() responder = MockResponder() with mock.patch.object( diff --git a/aries_cloudagent/protocols/discovery/v2_0/handlers/disclosures_handler.py b/aries_cloudagent/protocols/discovery/v2_0/handlers/disclosures_handler.py index adab14bf3e..6b9f4047ba 100644 --- a/aries_cloudagent/protocols/discovery/v2_0/handlers/disclosures_handler.py +++ b/aries_cloudagent/protocols/discovery/v2_0/handlers/disclosures_handler.py @@ -3,10 +3,9 @@ from .....messaging.base_handler import ( BaseHandler, BaseResponder, - RequestContext, HandlerException, + RequestContext, ) - from ..manager import V20DiscoveryMgr from ..messages.disclosures import Disclosures @@ -18,10 +17,12 @@ async def handle(self, context: RequestContext, responder: BaseResponder): """Message handler implementation.""" self._logger.debug("DiscloseHandler called with context %s", context) assert isinstance(context.message, Disclosures) + if not context.connection_ready: raise HandlerException( "Received disclosures message from inactive connection" ) + profile = context.profile mgr = V20DiscoveryMgr(profile) await mgr.receive_disclose( diff --git a/aries_cloudagent/protocols/discovery/v2_0/handlers/queries_handler.py b/aries_cloudagent/protocols/discovery/v2_0/handlers/queries_handler.py index e970ad0c6f..95c26d8866 100644 --- a/aries_cloudagent/protocols/discovery/v2_0/handlers/queries_handler.py +++ b/aries_cloudagent/protocols/discovery/v2_0/handlers/queries_handler.py @@ -3,9 +3,9 @@ from .....messaging.base_handler import ( BaseHandler, BaseResponder, + HandlerException, RequestContext, ) - from ..manager import V20DiscoveryMgr from ..messages.queries import Queries @@ -17,6 +17,10 @@ async def handle(self, context: RequestContext, responder: BaseResponder): """Message handler implementation.""" self._logger.debug("QueryHandler called with context %s", context) assert isinstance(context.message, Queries) + + if not context.connection_ready: + raise HandlerException("No connection established") + profile = context.profile mgr = V20DiscoveryMgr(profile) reply = await mgr.receive_query(context.message) diff --git a/aries_cloudagent/protocols/discovery/v2_0/handlers/tests/test_queries_handler.py b/aries_cloudagent/protocols/discovery/v2_0/handlers/tests/test_queries_handler.py index 9560fbd0d8..2ae1786d51 100644 --- a/aries_cloudagent/protocols/discovery/v2_0/handlers/tests/test_queries_handler.py +++ b/aries_cloudagent/protocols/discovery/v2_0/handlers/tests/test_queries_handler.py @@ -1,9 +1,11 @@ +from typing import Generator + import pytest from aries_cloudagent.tests import mock -from ......core.protocol_registry import ProtocolRegistry from ......core.goal_code_registry import GoalCodeRegistry +from ......core.protocol_registry import ProtocolRegistry from ......messaging.request_context import RequestContext from ......messaging.responder import MockResponder from ......protocols.issue_credential.v1_0.controller import ( @@ -16,7 +18,6 @@ from ......protocols.present_proof.v1_0.message_types import ( CONTROLLERS as pres_proof_v1_controller, ) - from ...handlers.queries_handler import QueriesHandler from ...manager import V20DiscoveryMgr from ...messages.disclosures import Disclosures @@ -27,7 +28,7 @@ @pytest.fixture() -def request_context() -> RequestContext: +def request_context() -> Generator[RequestContext, None, None]: ctx = RequestContext.test_context() protocol_registry = ProtocolRegistry() goal_code_registry = GoalCodeRegistry() @@ -48,6 +49,7 @@ async def test_queries_all(self, request_context): queries = Queries(queries=test_queries) queries.assign_thread_id("test123") request_context.message = queries + request_context.connection_ready = True handler = QueriesHandler() responder = MockResponder() await handler.handle(request_context, responder) @@ -67,6 +69,7 @@ async def test_queries_protocol_goal_code_all(self, request_context): queries = Queries(queries=test_queries) queries.assign_thread_id("test123") request_context.message = queries + request_context.connection_ready = True handler = QueriesHandler() responder = MockResponder() await handler.handle(request_context, responder) @@ -105,6 +108,7 @@ async def test_queries_protocol_goal_code_all_disclose_list_settings( queries = Queries(queries=test_queries) queries.assign_thread_id("test123") request_context.message = queries + request_context.connection_ready = True handler = QueriesHandler() responder = MockResponder() await handler.handle(request_context, responder) @@ -129,6 +133,7 @@ async def test_receive_query_process_disclosed(self, request_context): queries_msg = Queries(queries=test_queries) queries_msg.assign_thread_id("test123") request_context.message = queries_msg + request_context.connection_ready = True handler = QueriesHandler() responder = MockResponder() with mock.patch.object( diff --git a/aries_cloudagent/protocols/endorse_transaction/v1_0/handlers/tests/test_transaction_job_to_send_handler.py b/aries_cloudagent/protocols/endorse_transaction/v1_0/handlers/tests/test_transaction_job_to_send_handler.py index 8bb0867548..39019ba56d 100644 --- a/aries_cloudagent/protocols/endorse_transaction/v1_0/handlers/tests/test_transaction_job_to_send_handler.py +++ b/aries_cloudagent/protocols/endorse_transaction/v1_0/handlers/tests/test_transaction_job_to_send_handler.py @@ -26,7 +26,7 @@ async def test_called(self): await handler.handle(request_context, responder) mock_tran_mgr.return_value.set_transaction_their_job.assert_called_once_with( - request_context.message, request_context.message_receipt + request_context.message, request_context.connection_record ) assert not responder.messages @@ -48,6 +48,6 @@ async def test_called_x(self): await handler.handle(request_context, responder) mock_tran_mgr.return_value.set_transaction_their_job.assert_called_once_with( - request_context.message, request_context.message_receipt + request_context.message, request_context.connection_record ) assert not responder.messages diff --git a/aries_cloudagent/protocols/endorse_transaction/v1_0/handlers/transaction_job_to_send_handler.py b/aries_cloudagent/protocols/endorse_transaction/v1_0/handlers/transaction_job_to_send_handler.py index 35e1d50bed..9e466a909e 100644 --- a/aries_cloudagent/protocols/endorse_transaction/v1_0/handlers/transaction_job_to_send_handler.py +++ b/aries_cloudagent/protocols/endorse_transaction/v1_0/handlers/transaction_job_to_send_handler.py @@ -3,9 +3,9 @@ from .....messaging.base_handler import ( BaseHandler, BaseResponder, + HandlerException, RequestContext, ) - from ..manager import TransactionManager, TransactionManagerError from ..messages.transaction_job_to_send import TransactionJobToSend @@ -24,10 +24,14 @@ async def handle(self, context: RequestContext, responder: BaseResponder): self._logger.debug(f"TransactionJobToSendHandler called with context {context}") assert isinstance(context.message, TransactionJobToSend) + if not context.connection_ready: + raise HandlerException("No connection established") + assert context.connection_record + mgr = TransactionManager(context.profile) try: await mgr.set_transaction_their_job( - context.message, context.message_receipt + context.message, context.connection_record ) except TransactionManagerError: self._logger.exception("Error receiving transaction jobs") diff --git a/aries_cloudagent/protocols/endorse_transaction/v1_0/manager.py b/aries_cloudagent/protocols/endorse_transaction/v1_0/manager.py index a96a2e8fe9..0917c60c67 100644 --- a/aries_cloudagent/protocols/endorse_transaction/v1_0/manager.py +++ b/aries_cloudagent/protocols/endorse_transaction/v1_0/manager.py @@ -20,7 +20,6 @@ notify_revocation_reg_endorsed_event, ) from ....storage.error import StorageError, StorageNotFoundError -from ....transport.inbound.receipt import MessageReceipt from ....wallet.base import BaseWallet from ....wallet.util import ( notify_endorse_did_attrib_event, @@ -293,9 +292,7 @@ async def create_endorse_response( ) # we don't have an endorsed transaction so just return did meta-data ledger_response = { - "result": { - "txn": {"type": "1", "data": {"dest": meta_data["did"]}} - }, + "result": {"txn": {"type": "1", "data": {"dest": meta_data["did"]}}}, "meta_data": meta_data, } endorsed_msg = json.dumps(ledger_response) @@ -412,9 +409,10 @@ async def complete_transaction( # if we are the author, we need to write the endorsed ledger transaction ... # ... EXCEPT for DID transactions, which the endorser will write - if (not endorser) and ( - txn_goal_code != TransactionRecord.WRITE_DID_TRANSACTION - ): + if (not endorser) and (txn_goal_code != TransactionRecord.WRITE_DID_TRANSACTION): + ledger = self.profile.inject(BaseLedger) + if not ledger: + raise TransactionManagerError("No ledger available") if ( self._profile.context.settings.get_value("wallet.type") == "askar-anoncreds" @@ -752,20 +750,17 @@ async def set_transaction_my_job(self, record: ConnRecord, transaction_my_job: s return tx_job_to_send async def set_transaction_their_job( - self, tx_job_received: TransactionJobToSend, receipt: MessageReceipt + self, tx_job_received: TransactionJobToSend, connection: ConnRecord ): """Set transaction_their_job. Args: tx_job_received: The transaction job that is received from the other agent - receipt: The Message Receipt Object + connection: connection to set metadata on """ try: async with self._profile.session() as session: - connection = await ConnRecord.retrieve_by_did( - session, receipt.sender_did, receipt.recipient_did - ) value = await connection.metadata_get(session, "transaction_jobs") if value: value["transaction_their_job"] = tx_job_received.job @@ -871,9 +866,7 @@ async def endorsed_txn_post_processing( elif ledger_response["result"]["txn"]["type"] == "114": # revocation entry transaction rev_reg_id = ledger_response["result"]["txn"]["data"]["revocRegDefId"] - revoked = ledger_response["result"]["txn"]["data"]["value"].get( - "revoked", [] - ) + revoked = ledger_response["result"]["txn"]["data"]["value"].get("revoked", []) meta_data["context"]["rev_reg_id"] = rev_reg_id if is_anoncreds: await AnonCredsRevocation(self._profile).finish_revocation_list( diff --git a/aries_cloudagent/protocols/endorse_transaction/v1_0/tests/test_manager.py b/aries_cloudagent/protocols/endorse_transaction/v1_0/tests/test_manager.py index f2cbb7aa2d..8bf244c27b 100644 --- a/aries_cloudagent/protocols/endorse_transaction/v1_0/tests/test_manager.py +++ b/aries_cloudagent/protocols/endorse_transaction/v1_0/tests/test_manager.py @@ -11,7 +11,6 @@ from .....cache.in_memory import InMemoryCache from .....connections.models.conn_record import ConnRecord from .....ledger.base import BaseLedger -from .....storage.error import StorageNotFoundError from .....tests import mock from .....wallet.base import BaseWallet from .....wallet.did_method import SOV, DIDMethods @@ -156,9 +155,7 @@ async def test_create_record(self): transaction_record.messages_attach[0]["data"]["json"] == self.test_messages_attach ) - assert ( - transaction_record.state == TransactionRecord.STATE_TRANSACTION_CREATED - ) + assert transaction_record.state == TransactionRecord.STATE_TRANSACTION_CREATED async def test_txn_rec_retrieve_by_connection_and_thread_caching(self): async with self.profile.session() as sesn: @@ -603,8 +600,7 @@ async def test_create_refuse_response(self): assert transaction_record.state == TransactionRecord.STATE_TRANSACTION_REFUSED assert ( - refused_transaction_response.transaction_id - == self.test_author_transaction_id + refused_transaction_response.transaction_id == self.test_author_transaction_id ) assert refused_transaction_response.thread_id == transaction_record._id assert refused_transaction_response.signature_response == { @@ -640,9 +636,7 @@ async def test_receive_refuse_response(self): mock_response.endorser_did = self.test_refuser_did with mock.patch.object(TransactionRecord, "save", autospec=True) as save_record: - transaction_record = await self.manager.receive_refuse_response( - mock_response - ) + transaction_record = await self.manager.receive_refuse_response(mock_response) save_record.assert_called_once() assert transaction_record._type == TransactionRecord.SIGNATURE_RESPONSE @@ -688,9 +682,7 @@ async def test_cancel_transaction(self): assert transaction_record.state == TransactionRecord.STATE_TRANSACTION_CANCELLED - assert ( - cancelled_transaction_response.thread_id == self.test_author_transaction_id - ) + assert cancelled_transaction_response.thread_id == self.test_author_transaction_id assert ( cancelled_transaction_response.state == TransactionRecord.STATE_TRANSACTION_CANCELLED @@ -809,35 +801,17 @@ async def test_set_transaction_my_job(self): async def test_set_transaction_their_job(self): mock_job = mock.MagicMock() - mock_receipt = mock.MagicMock() - - with mock.patch.object( - ConnRecord, "retrieve_by_did", mock.CoroutineMock() - ) as mock_retrieve: - mock_retrieve.return_value = mock.MagicMock( - metadata_get=mock.CoroutineMock( - side_effect=[ - None, - {"meta": "data"}, - ] - ), - metadata_set=mock.CoroutineMock(), - ) - - for i in range(2): - await self.manager.set_transaction_their_job(mock_job, mock_receipt) - - async def test_set_transaction_their_job_conn_not_found(self): - mock_job = mock.MagicMock() - mock_receipt = mock.MagicMock() - - with mock.patch.object( - ConnRecord, "retrieve_by_did", mock.CoroutineMock() - ) as mock_retrieve: - mock_retrieve.side_effect = StorageNotFoundError() + mock_conn = mock.MagicMock() + mock_conn.metadata_get = mock.CoroutineMock( + side_effect=[ + None, + {"meta": "data"}, + ] + ) + mock_conn.metadata_set = mock.CoroutineMock() - with self.assertRaises(TransactionManagerError): - await self.manager.set_transaction_their_job(mock_job, mock_receipt) + for i in range(2): + await self.manager.set_transaction_their_job(mock_job, mock_conn) @mock.patch.object(AnonCredsIssuer, "finish_schema") @mock.patch.object(AnonCredsIssuer, "finish_cred_def") diff --git a/aries_cloudagent/protocols/notification/v1_0/handlers/ack_handler.py b/aries_cloudagent/protocols/notification/v1_0/handlers/ack_handler.py index e5f1fedc77..71ba1b29a3 100644 --- a/aries_cloudagent/protocols/notification/v1_0/handlers/ack_handler.py +++ b/aries_cloudagent/protocols/notification/v1_0/handlers/ack_handler.py @@ -1,10 +1,9 @@ """Generic ack message handler.""" -from .....messaging.base_handler import BaseHandler +from .....messaging.base_handler import BaseHandler, HandlerException from .....messaging.request_context import RequestContext from .....messaging.responder import BaseResponder -from .....utils.tracing import trace_event, get_timer - +from .....utils.tracing import get_timer, trace_event from ..messages.ack import V10Ack @@ -22,6 +21,10 @@ async def handle(self, context: RequestContext, responder: BaseResponder): self._logger.debug("V20PresAckHandler called with context %s", context) assert isinstance(context.message, V10Ack) + + if not context.connection_ready: + raise HandlerException("No connection established") + self._logger.info( "Received v1.0 notification ack message: %s", context.message.serialize(as_string=True), diff --git a/aries_cloudagent/protocols/notification/v1_0/handlers/tests/test_ack_handler.py b/aries_cloudagent/protocols/notification/v1_0/handlers/tests/test_ack_handler.py index 2389b5aade..59b9effd43 100644 --- a/aries_cloudagent/protocols/notification/v1_0/handlers/tests/test_ack_handler.py +++ b/aries_cloudagent/protocols/notification/v1_0/handlers/tests/test_ack_handler.py @@ -1,12 +1,9 @@ -from unittest import mock -from unittest import IsolatedAsyncioTestCase +from unittest import IsolatedAsyncioTestCase, mock from ......messaging.request_context import RequestContext from ......messaging.responder import MockResponder from ......transport.inbound.receipt import MessageReceipt - from ...messages.ack import V10Ack - from .. import ack_handler as test_module @@ -15,6 +12,7 @@ async def test_called(self): request_context = RequestContext.test_context() request_context.message_receipt = MessageReceipt() request_context.connection_record = mock.MagicMock() + request_context.connection_ready = True request_context.message = V10Ack(status="OK") handler = test_module.V10AckHandler() diff --git a/aries_cloudagent/protocols/out_of_band/v1_0/handlers/reuse_handler.py b/aries_cloudagent/protocols/out_of_band/v1_0/handlers/reuse_handler.py index a312a2af1f..d5851df485 100644 --- a/aries_cloudagent/protocols/out_of_band/v1_0/handlers/reuse_handler.py +++ b/aries_cloudagent/protocols/out_of_band/v1_0/handlers/reuse_handler.py @@ -1,9 +1,8 @@ """Handshake Reuse Message Handler under RFC 0434.""" -from .....messaging.base_handler import BaseHandler +from .....messaging.base_handler import BaseHandler, HandlerException from .....messaging.request_context import RequestContext from .....messaging.responder import BaseResponder - from ..manager import OutOfBandManager, OutOfBandManagerError from ..messages.reuse import HandshakeReuse @@ -18,11 +17,12 @@ async def handle(self, context: RequestContext, responder: BaseResponder): context: Request context responder: Responder callback """ - self._logger.debug( - f"HandshakeReuseMessageHandler called with context {context}" - ) + self._logger.debug(f"HandshakeReuseMessageHandler called with context {context}") assert isinstance(context.message, HandshakeReuse) + if not context.connection_ready: + raise HandlerException("No connection established") + profile = context.profile mgr = OutOfBandManager(profile) try: diff --git a/aries_cloudagent/protocols/out_of_band/v1_0/handlers/tests/test_reuse_handler.py b/aries_cloudagent/protocols/out_of_band/v1_0/handlers/tests/test_reuse_handler.py index 2600c1db64..9a958b11ed 100644 --- a/aries_cloudagent/protocols/out_of_band/v1_0/handlers/tests/test_reuse_handler.py +++ b/aries_cloudagent/protocols/out_of_band/v1_0/handlers/tests/test_reuse_handler.py @@ -1,5 +1,7 @@ """Test Reuse Message Handler.""" +from typing import AsyncGenerator, Generator + import pytest from aries_cloudagent.tests import mock @@ -9,7 +11,6 @@ from ......messaging.request_context import RequestContext from ......messaging.responder import MockResponder from ......transport.inbound.receipt import MessageReceipt - from ...handlers import reuse_handler as test_module from ...manager import OutOfBandManagerError from ...messages.reuse import HandshakeReuse @@ -17,14 +18,14 @@ @pytest.fixture() -async def request_context() -> RequestContext: +def request_context() -> Generator[RequestContext, None, None]: ctx = RequestContext.test_context() ctx.message_receipt = MessageReceipt() yield ctx @pytest.fixture() -async def session(request_context) -> ProfileSession: +async def session(request_context) -> AsyncGenerator[ProfileSession, None]: yield await request_context.session() @@ -36,6 +37,7 @@ async def test_called(self, mock_oob_mgr, request_context): request_context.message = HandshakeReuse() handler = test_module.HandshakeReuseMessageHandler() request_context.connection_record = ConnRecord() + request_context.connection_ready = True responder = MockResponder() await handler.handle(request_context, responder) mock_oob_mgr.return_value.receive_reuse_message.assert_called_once_with( @@ -53,6 +55,7 @@ async def test_reuse_accepted(self, mock_oob_mgr, request_context): request_context.message = HandshakeReuse() handler = test_module.HandshakeReuseMessageHandler() request_context.connection_record = ConnRecord() + request_context.connection_ready = True responder = MockResponder() await handler.handle(request_context, responder) mock_oob_mgr.return_value.receive_reuse_message.assert_called_once_with( @@ -73,6 +76,7 @@ async def test_exception( request_context.message = HandshakeReuse() handler = test_module.HandshakeReuseMessageHandler() request_context.connection_record = ConnRecord() + request_context.connection_ready = True responder = MockResponder() with caplog.at_level("ERROR"): await handler.handle(request_context, responder) diff --git a/aries_cloudagent/protocols/present_proof/dif/pres_exch.py b/aries_cloudagent/protocols/present_proof/dif/pres_exch.py index cc7ae20645..de3def8b78 100644 --- a/aries_cloudagent/protocols/present_proof/dif/pres_exch.py +++ b/aries_cloudagent/protocols/present_proof/dif/pres_exch.py @@ -824,12 +824,10 @@ class Meta: id = fields.Str( required=False, - validate=UUID4_VALIDATE, metadata={"description": "ID", "example": UUID4_EXAMPLE}, ) definition_id = fields.Str( required=False, - validate=UUID4_VALIDATE, metadata={"description": "DefinitionID", "example": UUID4_EXAMPLE}, ) descriptor_maps = fields.List( diff --git a/aries_cloudagent/protocols/present_proof/v2_0/formats/dif/handler.py b/aries_cloudagent/protocols/present_proof/v2_0/formats/dif/handler.py index af8e6f00de..82726993bd 100644 --- a/aries_cloudagent/protocols/present_proof/v2_0/formats/dif/handler.py +++ b/aries_cloudagent/protocols/present_proof/v2_0/formats/dif/handler.py @@ -5,7 +5,6 @@ from typing import Mapping, Optional, Sequence, Tuple from uuid import uuid4 -from marshmallow import RAISE from ......messaging.base_handler import BaseResponder from ......messaging.decorators.attach_decorator import AttachDecorator @@ -75,7 +74,7 @@ def validate_fields(cls, message_type: str, attachment_data: Mapping): Schema = mapping[message_type] # Validate, throw if not valid - Schema(unknown=RAISE).load(attachment_data) + Schema().load(attachment_data) def get_format_identifier(self, message_type: str) -> str: """Get attachment format identifier for format and message combination. diff --git a/aries_cloudagent/protocols/revocation_notification/v1_0/handlers/revoke_handler.py b/aries_cloudagent/protocols/revocation_notification/v1_0/handlers/revoke_handler.py index 32697992da..9913af525b 100644 --- a/aries_cloudagent/protocols/revocation_notification/v1_0/handlers/revoke_handler.py +++ b/aries_cloudagent/protocols/revocation_notification/v1_0/handlers/revoke_handler.py @@ -1,9 +1,8 @@ """Handler for revoke message.""" -from .....messaging.base_handler import BaseHandler +from .....messaging.base_handler import BaseHandler, HandlerException from .....messaging.request_context import RequestContext from .....messaging.responder import BaseResponder - from ..messages.revoke import Revoke @@ -16,6 +15,10 @@ class RevokeHandler(BaseHandler): async def handle(self, context: RequestContext, responder: BaseResponder): """Handle revoke message.""" assert isinstance(context.message, Revoke) + + if not context.connection_ready: + raise HandlerException("No connection established") + self._logger.debug( "Received notification of revocation for cred issued in thread %s " "with comment: %s", diff --git a/aries_cloudagent/protocols/revocation_notification/v1_0/handlers/tests/test_revoke_handler.py b/aries_cloudagent/protocols/revocation_notification/v1_0/handlers/tests/test_revoke_handler.py index 342aaf782d..f3b897c413 100644 --- a/aries_cloudagent/protocols/revocation_notification/v1_0/handlers/tests/test_revoke_handler.py +++ b/aries_cloudagent/protocols/revocation_notification/v1_0/handlers/tests/test_revoke_handler.py @@ -1,12 +1,14 @@ """Test RevokeHandler.""" +from typing import Generator + import pytest from ......core.event_bus import EventBus, MockEventBus from ......core.in_memory import InMemoryProfile from ......core.profile import Profile from ......messaging.request_context import RequestContext -from ......messaging.responder import MockResponder, BaseResponder +from ......messaging.responder import BaseResponder, MockResponder from ...messages.revoke import Revoke from ..revoke_handler import RevokeHandler @@ -32,7 +34,7 @@ def message(): @pytest.fixture -def context(profile: Profile, message: Revoke): +def context(profile: Profile, message: Revoke) -> Generator[RequestContext, None, None]: request_context = RequestContext(profile) request_context.message = message yield request_context @@ -42,6 +44,7 @@ def context(profile: Profile, message: Revoke): async def test_handle( context: RequestContext, responder: BaseResponder, event_bus: MockEventBus ): + context.connection_ready = True await RevokeHandler().handle(context, responder) assert event_bus.events [(_, received)] = event_bus.events @@ -55,6 +58,7 @@ async def test_handle_monitor( context: RequestContext, responder: BaseResponder, event_bus: MockEventBus ): context.settings["revocation.monitor_notification"] = True + context.connection_ready = True await RevokeHandler().handle(context, responder) [(_, webhook), (_, received)] = event_bus.events diff --git a/aries_cloudagent/protocols/revocation_notification/v2_0/handlers/revoke_handler.py b/aries_cloudagent/protocols/revocation_notification/v2_0/handlers/revoke_handler.py index 4332440761..f6969fc357 100644 --- a/aries_cloudagent/protocols/revocation_notification/v2_0/handlers/revoke_handler.py +++ b/aries_cloudagent/protocols/revocation_notification/v2_0/handlers/revoke_handler.py @@ -1,9 +1,8 @@ """Handler for revoke message.""" -from .....messaging.base_handler import BaseHandler +from .....messaging.base_handler import BaseHandler, HandlerException from .....messaging.request_context import RequestContext from .....messaging.responder import BaseResponder - from ..messages.revoke import Revoke @@ -16,6 +15,10 @@ class RevokeHandler(BaseHandler): async def handle(self, context: RequestContext, responder: BaseResponder): """Handle revoke message.""" assert isinstance(context.message, Revoke) + + if not context.connection_ready: + raise HandlerException("No connection established") + self._logger.debug( "Received notification of revocation for %s cred %s with comment: %s", context.message.revocation_format, diff --git a/aries_cloudagent/protocols/revocation_notification/v2_0/handlers/tests/test_revoke_handler.py b/aries_cloudagent/protocols/revocation_notification/v2_0/handlers/tests/test_revoke_handler.py index ea523a8a26..93ef0b430a 100644 --- a/aries_cloudagent/protocols/revocation_notification/v2_0/handlers/tests/test_revoke_handler.py +++ b/aries_cloudagent/protocols/revocation_notification/v2_0/handlers/tests/test_revoke_handler.py @@ -1,12 +1,14 @@ """Test RevokeHandler.""" +from typing import Generator + import pytest from ......core.event_bus import EventBus, MockEventBus from ......core.in_memory import InMemoryProfile from ......core.profile import Profile from ......messaging.request_context import RequestContext -from ......messaging.responder import MockResponder, BaseResponder +from ......messaging.responder import BaseResponder, MockResponder from ...messages.revoke import Revoke from ..revoke_handler import RevokeHandler @@ -36,7 +38,7 @@ def message(): @pytest.fixture -def context(profile: Profile, message: Revoke): +def context(profile: Profile, message: Revoke) -> Generator[RequestContext, None, None]: request_context = RequestContext(profile) request_context.message = message yield request_context @@ -46,6 +48,7 @@ def context(profile: Profile, message: Revoke): async def test_handle( context: RequestContext, responder: BaseResponder, event_bus: MockEventBus ): + context.connection_ready = True await RevokeHandler().handle(context, responder) assert event_bus.events [(_, received)] = event_bus.events @@ -60,6 +63,7 @@ async def test_handle_monitor( context: RequestContext, responder: BaseResponder, event_bus: MockEventBus ): context.settings["revocation.monitor_notification"] = True + context.connection_ready = True await RevokeHandler().handle(context, responder) [(_, webhook), (_, received)] = event_bus.events diff --git a/aries_cloudagent/vc/vc_ld/manager.py b/aries_cloudagent/vc/vc_ld/manager.py index de3f578158..25e8827619 100644 --- a/aries_cloudagent/vc/vc_ld/manager.py +++ b/aries_cloudagent/vc/vc_ld/manager.py @@ -57,9 +57,7 @@ Ed25519Signature2018: ED25519, Ed25519Signature2020: ED25519, } -PROOF_KEY_TYPE_MAPPING = cast( - Dict[ProofTypes, KeyType], SIGNATURE_SUITE_KEY_TYPE_MAPPING -) +PROOF_KEY_TYPE_MAPPING = cast(Dict[ProofTypes, KeyType], SIGNATURE_SUITE_KEY_TYPE_MAPPING) # We only want to add bbs suites to supported if the module is installed diff --git a/aries_cloudagent/vc/vc_ld/models/linked_data_proof.py b/aries_cloudagent/vc/vc_ld/models/linked_data_proof.py index 40e5a2b7db..6787e82be7 100644 --- a/aries_cloudagent/vc/vc_ld/models/linked_data_proof.py +++ b/aries_cloudagent/vc/vc_ld/models/linked_data_proof.py @@ -105,9 +105,6 @@ class Meta: domain = fields.Str( required=False, - # TODO the domain can be more than a Uri, provide a less restrictive validation - # https://www.w3.org/TR/vc-data-integrity/#defn-domain - validate=Uri(), metadata={ "description": ( "A string value specifying the restricted domain of the signature."