Skip to content

Commit

Permalink
Check connection is ready in all connection required handlers
Browse files Browse the repository at this point in the history
Signed-off-by: jamshale <[email protected]>
  • Loading branch information
jamshale committed Jul 10, 2024
1 parent 811ff3b commit 531162a
Show file tree
Hide file tree
Showing 30 changed files with 119 additions and 51 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
from .....messaging.base_handler import (
BaseHandler,
BaseResponder,
HandlerException,
RequestContext,
)

from ..messages.menu import Menu
from ..util import save_connection_menu

Expand All @@ -23,6 +23,9 @@ async def handle(self, context: RequestContext, responder: BaseResponder):
self._logger.debug("MenuHandler called with context %s", context)
assert isinstance(context.message, Menu)

if not context.connection_ready:
raise HandlerException("No connection established")

self._logger.info("Received action menu: %s", context.message)

await save_connection_menu(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
from .....messaging.base_handler import (
BaseHandler,
BaseResponder,
HandlerException,
RequestContext,
)

from ..base_service import BaseMenuService
from ..messages.menu_request import MenuRequest

Expand All @@ -23,6 +23,9 @@ async def handle(self, context: RequestContext, responder: BaseResponder):
self._logger.debug("MenuRequestHandler called with context %s", context)
assert isinstance(context.message, MenuRequest)

if not context.connection_ready:
raise HandlerException("No connection established")

self._logger.info("Received action menu request")

service: BaseMenuService = context.inject_or(BaseMenuService)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
from .....messaging.base_handler import (
BaseHandler,
BaseResponder,
HandlerException,
RequestContext,
)

from ..base_service import BaseMenuService
from ..messages.perform import Perform

Expand All @@ -23,6 +23,9 @@ async def handle(self, context: RequestContext, responder: BaseResponder):
self._logger.debug("PerformHandler called with context %s", context)
assert isinstance(context.message, Perform)

if not context.connection_ready:
raise HandlerException("No connection established")

self._logger.info("Received action menu perform request")

service: BaseMenuService = context.inject_or(BaseMenuService)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from unittest import IsolatedAsyncioTestCase

from aries_cloudagent.tests import mock

from ......messaging.request_context import RequestContext
from ......messaging.responder import MockResponder

from .. import menu_handler as handler


Expand All @@ -12,6 +12,7 @@ async def test_called(self):
request_context = RequestContext.test_context()
request_context.connection_record = mock.MagicMock()
request_context.connection_record.connection_id = "dummy"
request_context.connection_ready = True

handler.save_connection_menu = mock.CoroutineMock()
responder = MockResponder()
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from unittest import IsolatedAsyncioTestCase

from aries_cloudagent.tests import mock

from ......messaging.request_context import RequestContext
from ......messaging.responder import MockResponder

from .. import menu_request_handler as handler


Expand All @@ -18,6 +18,7 @@ async def test_called(self):

self.context.connection_record = mock.MagicMock()
self.context.connection_record.connection_id = "dummy"
self.context.connection_ready = True

responder = MockResponder()
self.context.message = handler.MenuRequest()
Expand All @@ -39,6 +40,7 @@ async def test_called_no_active_menu(self):

self.context.connection_record = mock.MagicMock()
self.context.connection_record.connection_id = "dummy"
self.context.connection_ready = True

responder = MockResponder()
self.context.message = handler.MenuRequest()
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from unittest import IsolatedAsyncioTestCase

from aries_cloudagent.tests import mock

from ......messaging.request_context import RequestContext
from ......messaging.responder import MockResponder

from .. import perform_handler as handler


Expand All @@ -18,12 +18,11 @@ async def test_called(self):

self.context.connection_record = mock.MagicMock()
self.context.connection_record.connection_id = "dummy"
self.context.connection_ready = True

responder = MockResponder()
self.context.message = handler.Perform()
self.menu_service.perform_menu_action = mock.CoroutineMock(
return_value="perform"
)
self.menu_service.perform_menu_action = mock.CoroutineMock(return_value="perform")

handler_inst = handler.PerformHandler()
await handler_inst.handle(self.context, responder)
Expand All @@ -41,6 +40,7 @@ async def test_called_no_active_menu(self):

self.context.connection_record = mock.MagicMock()
self.context.connection_record.connection_id = "dummy"
self.context.connection_ready = True

responder = MockResponder()
self.context.message = handler.Perform()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
from .....messaging.base_handler import (
BaseHandler,
BaseResponder,
HandlerException,
RequestContext,
)

from ..messages.basicmessage import BasicMessage


Expand All @@ -22,6 +22,9 @@ async def handle(self, context: RequestContext, responder: BaseResponder):
self._logger.debug("BasicMessageHandler called with context %s", context)
assert isinstance(context.message, BasicMessage)

if not context.connection_ready:
raise HandlerException("No connection established")

self._logger.info("Received basic message: %s", context.message.content)

body = context.message.content
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Rotate ack handler."""

from .....messaging.base_handler import BaseHandler
from .....messaging.base_handler import BaseHandler, HandlerException
from .....messaging.request_context import RequestContext
from .....messaging.responder import BaseResponder
from ..manager import DIDRotateManager
Expand All @@ -20,6 +20,9 @@ async def handle(self, context: RequestContext, responder: BaseResponder):
self._logger.debug("RotateAckHandler called with context %s", context)
assert isinstance(context.message, RotateAck)

if not context.connection_ready:
raise HandlerException("No connection established")

connection_record = context.connection_record
ack = context.message

Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Rotate hangup handler."""

from .....messaging.base_handler import BaseHandler
from .....messaging.base_handler import BaseHandler, HandlerException
from .....messaging.request_context import RequestContext
from .....messaging.responder import BaseResponder
from ..manager import DIDRotateManager
Expand All @@ -20,6 +20,9 @@ async def handle(self, context: RequestContext, responder: BaseResponder):
self._logger.debug("HangupHandler called with context %s", context)
assert isinstance(context.message, Hangup)

if not context.connection_ready:
raise HandlerException("No connection established")

connection_record = context.connection_record

profile = context.profile
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Rotate problem report handler."""

from .....messaging.base_handler import BaseHandler
from .....messaging.base_handler import BaseHandler, HandlerException
from .....messaging.request_context import RequestContext
from .....messaging.responder import BaseResponder
from ..manager import DIDRotateManager
Expand All @@ -20,6 +20,9 @@ async def handle(self, context: RequestContext, responder: BaseResponder):
self._logger.debug("ProblemReportHandler called with context %s", context)
assert isinstance(context.message, RotateProblemReport)

if not context.connection_ready:
raise HandlerException("No connection established")

problem_report = context.message

profile = context.profile
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Rotate handler."""

from .....messaging.base_handler import BaseHandler
from .....messaging.base_handler import BaseHandler, HandlerException
from .....messaging.request_context import RequestContext
from .....messaging.responder import BaseResponder
from ..manager import DIDRotateManager
Expand All @@ -20,6 +20,9 @@ async def handle(self, context: RequestContext, responder: BaseResponder):
self._logger.debug("RotateHandler called with context %s", context)
assert isinstance(context.message, Rotate)

if not context.connection_ready:
raise HandlerException("No connection established")

connection_record = context.connection_record
rotate = context.message

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ async def test_handle(self, MockDIDRotateManager, request_context):

request_context.message = RotateAck()
request_context.connection_record = mock.MagicMock()
request_context.connection_ready = True

handler = test_module.RotateAckHandler()
responder = MockResponder()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ async def test_handle(self, MockDIDRotateManager, request_context):

request_context.message = Hangup()
request_context.connection_record = mock.MagicMock()
request_context.connection_ready = True

handler = test_module.HangupHandler()
responder = MockResponder()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ async def test_handle(self, MockDIDRotateManager, request_context):

request_context.message = RotateProblemReport()
request_context.connection_record = mock.MagicMock()
request_context.connection_ready = True

handler = test_module.ProblemReportHandler()
responder = MockResponder()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ async def test_handle(self, MockDIDRotateManager, request_context):

request_context.message = Rotate(**test_valid_rotate_request)
request_context.connection_record = mock.MagicMock()
request_context.connection_ready = True

handler = test_module.RotateHandler()
responder = MockResponder()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,9 @@
from .....messaging.base_handler import (
BaseHandler,
BaseResponder,
RequestContext,
HandlerException,
RequestContext,
)

from ..manager import V10DiscoveryMgr
from ..messages.disclose import Disclose

Expand All @@ -18,10 +17,12 @@ async def handle(self, context: RequestContext, responder: BaseResponder):
"""Message handler implementation."""
self._logger.debug("DiscloseHandler called with context %s", context)
assert isinstance(context.message, Disclose)

if not context.connection_ready:
raise HandlerException(
"Received disclosures message from inactive connection"
)

profile = context.profile
mgr = V10DiscoveryMgr(profile)
await mgr.receive_disclose(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
from .....messaging.base_handler import (
BaseHandler,
BaseResponder,
HandlerException,
RequestContext,
)

from ..manager import V10DiscoveryMgr
from ..messages.query import Query

Expand All @@ -17,6 +17,10 @@ async def handle(self, context: RequestContext, responder: BaseResponder):
"""Message handler implementation."""
self._logger.debug("QueryHandler called with context %s", context)
assert isinstance(context.message, Query)

if not context.connection_ready:
raise HandlerException("No connection established")

profile = context.profile
mgr = V10DiscoveryMgr(profile)
reply = await mgr.receive_query(context.message)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from ......core.protocol_registry import ProtocolRegistry
from ......messaging.request_context import RequestContext
from ......messaging.responder import MockResponder

from ...handlers.query_handler import QueryHandler
from ...messages.disclose import Disclose
from ...messages.query import Query
Expand All @@ -30,6 +29,7 @@ async def test_query_all(self, request_context):
query_msg = Query(query="*")
query_msg.assign_thread_id("test123")
request_context.message = query_msg
request_context.connection_ready = True
handler = QueryHandler()
responder = MockResponder()
await handler.handle(request_context, responder)
Expand All @@ -50,6 +50,7 @@ async def test_query_all_disclose_list_settings(self, request_context):
query_msg = Query(query="*")
query_msg.assign_thread_id("test123")
request_context.message = query_msg
request_context.connection_ready = True
handler = QueryHandler()
responder = MockResponder()
await handler.handle(request_context, responder)
Expand All @@ -65,6 +66,7 @@ async def test_receive_query_process_disclosed(self, request_context):
query_msg = Query(query="*")
query_msg.assign_thread_id("test123")
request_context.message = query_msg
request_context.connection_ready = True
handler = QueryHandler()
responder = MockResponder()
with mock.patch.object(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,9 @@
from .....messaging.base_handler import (
BaseHandler,
BaseResponder,
RequestContext,
HandlerException,
RequestContext,
)

from ..manager import V20DiscoveryMgr
from ..messages.disclosures import Disclosures

Expand All @@ -18,10 +17,12 @@ async def handle(self, context: RequestContext, responder: BaseResponder):
"""Message handler implementation."""
self._logger.debug("DiscloseHandler called with context %s", context)
assert isinstance(context.message, Disclosures)

if not context.connection_ready:
raise HandlerException(
"Received disclosures message from inactive connection"
)

profile = context.profile
mgr = V20DiscoveryMgr(profile)
await mgr.receive_disclose(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
from .....messaging.base_handler import (
BaseHandler,
BaseResponder,
HandlerException,
RequestContext,
)

from ..manager import V20DiscoveryMgr
from ..messages.queries import Queries

Expand All @@ -17,6 +17,10 @@ async def handle(self, context: RequestContext, responder: BaseResponder):
"""Message handler implementation."""
self._logger.debug("QueryHandler called with context %s", context)
assert isinstance(context.message, Queries)

if not context.connection_ready:
raise HandlerException("No connection established")

profile = context.profile
mgr = V20DiscoveryMgr(profile)
reply = await mgr.receive_query(context.message)
Expand Down
Loading

0 comments on commit 531162a

Please sign in to comment.