diff --git a/bittensor/axon.py b/bittensor/axon.py index 38093c2fde..959e009c3a 100644 --- a/bittensor/axon.py +++ b/bittensor/axon.py @@ -30,13 +30,15 @@ import threading import time import traceback +import typing import uuid from inspect import signature, Signature, Parameter -from typing import List, Optional, Tuple, Callable, Any, Dict +from typing import List, Optional, Tuple, Callable, Any, Dict, Awaitable import uvicorn from fastapi import FastAPI, APIRouter, Depends from fastapi.responses import JSONResponse +from fastapi.routing import serialize_response from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint from starlette.requests import Request from starlette.responses import Response @@ -51,9 +53,8 @@ NotVerifiedException, BlacklistedException, PriorityException, - RunException, PostProcessException, - InternalServerError, + SynapseException, ) from bittensor.threadpool import PriorityThreadPoolExecutor @@ -376,7 +377,8 @@ def __init__( self.app.include_router(self.router) # Build ourselves as the middleware. - self.app.add_middleware(AxonMiddleware, axon=self) + self.middleware_cls = AxonMiddleware + self.app.add_middleware(self.middleware_cls, axon=self) # Attach default forward. def ping(r: bittensor.Synapse) -> bittensor.Synapse: @@ -465,89 +467,72 @@ def verify_custom(synapse: MyCustomSynapse): offered by this method allows developers to tailor the Axon's behavior to specific requirements and use cases. """ - - # Assert 'forward_fn' has exactly one argument forward_sig = signature(forward_fn) - assert ( - len(list(forward_sig.parameters)) == 1 - ), "The passed function must have exactly one argument" - - # Obtain the class of the first argument of 'forward_fn' - request_class = forward_sig.parameters[ - list(forward_sig.parameters)[0] - ].annotation + try: + first_param = next(iter(forward_sig.parameters.values())) + except StopIteration: + raise ValueError( + "The forward_fn first argument must be a subclass of bittensor.Synapse, but it has no arguments" + ) - # Assert that the first argument of 'forward_fn' is a subclass of 'bittensor.Synapse' + param_class = first_param.annotation assert issubclass( - request_class, bittensor.Synapse - ), "The argument of forward_fn must inherit from bittensor.Synapse" + param_class, bittensor.Synapse + ), "The first argument of forward_fn must inherit from bittensor.Synapse" + request_name = param_class.__name__ + + async def endpoint(*args, **kwargs): + start_time = time.time() + response_synapse = forward_fn(*args, **kwargs) + if isinstance(response_synapse, Awaitable): + response_synapse = await response_synapse + return await self.middleware_cls.synapse_to_response( + synapse=response_synapse, start_time=start_time + ) - # Obtain the class name of the first argument of 'forward_fn' - request_name = forward_sig.parameters[ - list(forward_sig.parameters)[0] - ].annotation.__name__ + # replace the endpoint signature, but set return annotation to JSONResponse + endpoint.__signature__ = Signature( # type: ignore + parameters=list(forward_sig.parameters.values()), + return_annotation=JSONResponse, + ) # Add the endpoint to the router, making it available on both GET and POST methods self.router.add_api_route( f"/{request_name}", - forward_fn, + endpoint, methods=["GET", "POST"], dependencies=[Depends(self.verify_body_integrity)], ) self.app.include_router(self.router) - # Expected signatures for 'blacklist_fn', 'priority_fn' and 'verify_fn' - blacklist_sig = Signature( - [ - Parameter( - "synapse", - Parameter.POSITIONAL_OR_KEYWORD, - annotation=forward_sig.parameters[ - list(forward_sig.parameters)[0] - ].annotation, - ) - ], - return_annotation=Tuple[bool, str], - ) - priority_sig = Signature( - [ - Parameter( - "synapse", - Parameter.POSITIONAL_OR_KEYWORD, - annotation=forward_sig.parameters[ - list(forward_sig.parameters)[0] - ].annotation, - ) - ], - return_annotation=float, - ) - verify_sig = Signature( - [ - Parameter( - "synapse", - Parameter.POSITIONAL_OR_KEYWORD, - annotation=forward_sig.parameters[ - list(forward_sig.parameters)[0] - ].annotation, - ) - ], - return_annotation=None, - ) - # Check the signature of blacklist_fn, priority_fn and verify_fn if they are provided + expected_params = [ + Parameter( + "synapse", + Parameter.POSITIONAL_OR_KEYWORD, + annotation=forward_sig.parameters[ + list(forward_sig.parameters)[0] + ].annotation, + ) + ] if blacklist_fn: + blacklist_sig = Signature( + expected_params, return_annotation=Tuple[bool, str] + ) assert ( signature(blacklist_fn) == blacklist_sig ), "The blacklist_fn function must have the signature: blacklist( synapse: {} ) -> Tuple[bool, str]".format( request_name ) if priority_fn: + priority_sig = Signature(expected_params, return_annotation=float) assert ( signature(priority_fn) == priority_sig ), "The priority_fn function must have the signature: priority( synapse: {} ) -> float".format( request_name ) if verify_fn: + verify_sig = Signature(expected_params, return_annotation=None) assert ( signature(verify_fn) == verify_sig ), "The verify_fn function must have the signature: verify( synapse: {} ) -> None".format( @@ -555,9 +540,7 @@ def verify_custom(synapse: MyCustomSynapse): ) # Store functions in appropriate attribute dictionaries - self.forward_class_types[request_name] = forward_sig.parameters[ - list(forward_sig.parameters)[0] - ].annotation + self.forward_class_types[request_name] = param_class self.blacklist_fns[request_name] = blacklist_fn self.priority_fns[request_name] = priority_fn self.verify_fns[request_name] = ( @@ -933,7 +916,7 @@ async def default_verify(self, synapse: bittensor.Synapse): # Success self.nonces[endpoint_key] = synapse.dendrite.nonce # type: ignore else: - raise SynapseDendriteNoneException() + raise SynapseDendriteNoneException(synapse=synapse) def create_error_response(synapse: bittensor.Synapse): @@ -954,28 +937,53 @@ def create_error_response(synapse: bittensor.Synapse): def log_and_handle_error( synapse: bittensor.Synapse, exception: Exception, - status_code: int, - start_time: float, + status_code: typing.Optional[int] = None, + start_time: typing.Optional[float] = None, ): + if isinstance(exception, SynapseException): + synapse = exception.synapse or synapse # Display the traceback for user clarity. bittensor.logging.trace(f"Forward exception: {traceback.format_exc()}") + if synapse.axon is None: + synapse.axon = bittensor.TerminalInfo() + # Set the status code of the synapse to the given status code. + error_id = str(uuid.uuid4()) error_type = exception.__class__.__name__ - error_message = str(exception) - detailed_error_message = f"{error_type}: {error_message}" # Log the detailed error message for internal use - bittensor.logging.error(detailed_error_message) + bittensor.logging.error(f"{error_type}#{error_id}: {exception}") + + if not status_code and synapse.axon.status_code != 100: + status_code = synapse.axon.status_code + status_message = synapse.axon.status_message + if isinstance(exception, SynapseException): + if not status_code: + if isinstance(exception, PriorityException): + status_code = 503 + elif isinstance(exception, UnknownSynapseError): + status_code = 404 + elif isinstance(exception, BlacklistedException): + status_code = 403 + elif isinstance(exception, NotVerifiedException): + status_code = 401 + elif isinstance(exception, (InvalidRequestNameError, SynapseParsingError)): + status_code = 400 + else: + status_code = 500 + status_message = status_message or str(exception) + else: + status_code = status_code or 500 + status_message = status_message or f"Internal Server Error #{error_id}" - if synapse.axon is None: - raise SynapseParsingError(detailed_error_message) # Set a user-friendly error message synapse.axon.status_code = status_code - synapse.axon.status_message = error_message + synapse.axon.status_message = status_message - # Calculate the processing time by subtracting the start time from the current time. - synapse.axon.process_time = str(time.time() - start_time) # type: ignore + if start_time: + # Calculate the processing time by subtracting the start time from the current time. + synapse.axon.process_time = str(time.time() - start_time) # type: ignore return synapse @@ -1045,7 +1053,14 @@ async def dispatch( try: # Set up the synapse from its headers. - synapse: bittensor.Synapse = await self.preprocess(request) + try: + synapse: bittensor.Synapse = await self.preprocess(request) + except Exception as exc: + if isinstance(exc, SynapseException) and exc.synapse is not None: + synapse = exc.synapse + else: + synapse = bittensor.Synapse() + raise # Logs the start of the request processing if synapse.dendrite is not None: @@ -1069,56 +1084,22 @@ async def dispatch( # Call the run function response = await self.run(synapse, call_next, request) - # Call the postprocess function - response = await self.postprocess(synapse, response, start_time) - # Handle errors related to preprocess. except InvalidRequestNameError as e: - if "synapse" not in locals(): - synapse: bittensor.Synapse = bittensor.Synapse() # type: ignore - log_and_handle_error(synapse, e, 400, start_time) - response = create_error_response(synapse) - - except SynapseParsingError as e: - if "synapse" not in locals(): - synapse = bittensor.Synapse() - log_and_handle_error(synapse, e, 400, start_time) - response = create_error_response(synapse) - - except UnknownSynapseError as e: - if "synapse" not in locals(): - synapse = bittensor.Synapse() - log_and_handle_error(synapse, e, 404, start_time) - response = create_error_response(synapse) - - # Handle errors related to verify. - except NotVerifiedException as e: - log_and_handle_error(synapse, e, 401, start_time) - response = create_error_response(synapse) - - # Handle errors related to blacklist. - except BlacklistedException as e: - log_and_handle_error(synapse, e, 403, start_time) - response = create_error_response(synapse) - - # Handle errors related to priority. - except PriorityException as e: - log_and_handle_error(synapse, e, 503, start_time) + if synapse.axon is None: + synapse.axon = bittensor.TerminalInfo() + synapse.axon.status_code = 400 + synapse.axon.status_message = str(e) + synapse = log_and_handle_error(synapse, e, start_time=start_time) response = create_error_response(synapse) - - # Handle errors related to run. - except RunException as e: - log_and_handle_error(synapse, e, 500, start_time) - response = create_error_response(synapse) - - # Handle errors related to postprocess. - except PostProcessException as e: - log_and_handle_error(synapse, e, 500, start_time) + except SynapseException as e: + synapse = e.synapse or synapse + synapse = log_and_handle_error(synapse, e, start_time=start_time) response = create_error_response(synapse) # Handle all other errors. except Exception as e: - log_and_handle_error(synapse, InternalServerError(str(e)), 500, start_time) + synapse = log_and_handle_error(synapse, e, start_time=start_time) response = create_error_response(synapse) # Logs the end of request processing and returns the response @@ -1193,8 +1174,7 @@ async def preprocess(self, request: Request) -> bittensor.Synapse: "version": str(bittensor.__version_as_int__), "uuid": str(self.axon.uuid), "nonce": f"{time.time_ns()}", - "status_message": "Success", - "status_code": "100", + "status_code": 100, } ) @@ -1263,7 +1243,9 @@ async def verify(self, synapse: bittensor.Synapse): # We raise an exception to stop the process and return the error to the requester. # The error message includes the original exception message. - raise NotVerifiedException(f"Not Verified with error: {str(e)}") + raise NotVerifiedException( + f"Not Verified with error: {str(e)}", synapse=synapse + ) async def blacklist(self, synapse: bittensor.Synapse): """ @@ -1317,7 +1299,9 @@ async def blacklist(self, synapse: bittensor.Synapse): raise Exception("Synapse.axon object is None") # We raise an exception to halt the process and return the error message to the requester. - raise BlacklistedException(f"Forbidden. Key is blacklisted: {reason}.") + raise BlacklistedException( + f"Forbidden. Key is blacklisted: {reason}.", synapse=synapse + ) async def priority(self, synapse: bittensor.Synapse): """ @@ -1380,7 +1364,9 @@ async def submit_task( synapse.axon.status_code = 408 # Raise an exception to stop the process and return an appropriate error message to the requester. - raise PriorityException(f"Response timeout after: {synapse.timeout}s") + raise PriorityException( + f"Response timeout after: {synapse.timeout}s", synapse=synapse + ) async def run( self, @@ -1410,32 +1396,22 @@ async def run( response = await call_next(request) except Exception as e: - # If an exception occurs during the execution of the requested function, - # it is caught and handled here. - # Log the exception for debugging purposes. bittensor.logging.trace(f"Run exception: {str(e)}") - - # Set the status code of the synapse to "500" which indicates an internal server error. - if synapse.axon is not None: - synapse.axon.status_code = 500 - - # Raise an exception to stop the process and return an appropriate error message to the requester. - raise RunException(f"Internal server error with error: {str(e)}") + raise # Return the starlet response return response - async def postprocess( - self, synapse: bittensor.Synapse, response: Response, start_time: float - ) -> Response: + @classmethod + async def synapse_to_response( + cls, synapse: bittensor.Synapse, start_time: float + ) -> JSONResponse: """ - Performs the final processing on the response before sending it back to the client. This method - updates the response headers and logs the end of the request processing. + Converts the Synapse object into a JSON response with HTTP headers. Args: synapse (bittensor.Synapse): The Synapse object representing the request. - response (Response): The response generated by processing the request. start_time (float): The timestamp when the request processing started. Returns: @@ -1444,24 +1420,37 @@ async def postprocess( Postprocessing is the last step in the request handling process, ensuring that the response is properly formatted and contains all necessary information. """ - # Set the status code of the synapse to "200" which indicates a successful response. - if synapse.axon is not None: + if synapse.axon is None: + synapse.axon = bittensor.TerminalInfo() + + if synapse.axon.status_code is None: synapse.axon.status_code = 200 - # Set the status message of the synapse to "Success". + if synapse.axon.status_code == 200 and not synapse.axon.status_message: synapse.axon.status_message = "Success" + synapse.axon.process_time = time.time() - start_time + + serialized_synapse = await serialize_response(response_content=synapse) + response = JSONResponse( + status_code=synapse.axon.status_code, + content=serialized_synapse, + ) + try: - # Update the response headers with the headers from the synapse. updated_headers = synapse.to_headers() - response.headers.update(updated_headers) except Exception as e: - # If there is an exception during the response header update, we log the exception. raise PostProcessException( - f"Error while parsing or updating response headers. Postprocess exception: {str(e)}." - ) + f"Error while parsing response headers. Postprocess exception: {str(e)}.", + synapse=synapse, + ) from e - # Calculate the processing time by subtracting the start time from the current time. - synapse.axon.process_time = str(time.time() - start_time) # type: ignore + try: + response.headers.update(updated_headers) + except Exception as e: + raise PostProcessException( + f"Error while updating response headers. Postprocess exception: {str(e)}.", + synapse=synapse, + ) from e return response diff --git a/bittensor/dendrite.py b/bittensor/dendrite.py index 130e3a7d42..47a3ba6f95 100644 --- a/bittensor/dendrite.py +++ b/bittensor/dendrite.py @@ -708,6 +708,12 @@ def process_server_response( except: # Ignore errors during attribute setting pass + else: + # If the server responded with an error, update the local synapse state + if local_synapse.axon is None: + local_synapse.axon = bittensor.TerminalInfo() + local_synapse.axon.status_code = server_response.status + local_synapse.axon.status_message = json_response.get("message") # Extract server headers and overwrite None values in local synapse headers server_headers = bittensor.Synapse.from_headers(server_response.headers) # type: ignore diff --git a/bittensor/errors.py b/bittensor/errors.py index de51b5d48a..b8366ee681 100644 --- a/bittensor/errors.py +++ b/bittensor/errors.py @@ -14,6 +14,12 @@ # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION # OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER # DEALINGS IN THE SOFTWARE. +from __future__ import annotations + +import typing + +if typing.TYPE_CHECKING: + import bittensor class ChainError(BaseException): @@ -112,7 +118,16 @@ class InvalidRequestNameError(Exception): pass -class UnknownSynapseError(Exception): +class SynapseException(Exception): + def __init__( + self, message="Synapse Exception", synapse: "bittensor.Synapse" | None = None + ): + self.message = message + self.synapse = synapse + super().__init__(self.message) + + +class UnknownSynapseError(SynapseException): r"""This exception is raised when the request name is not found in the Axon's forward_fns dictionary.""" pass @@ -124,43 +139,47 @@ class SynapseParsingError(Exception): pass -class NotVerifiedException(Exception): +class NotVerifiedException(SynapseException): r"""This exception is raised when the request is not verified.""" pass -class BlacklistedException(Exception): +class BlacklistedException(SynapseException): r"""This exception is raised when the request is blacklisted.""" pass -class PriorityException(Exception): +class PriorityException(SynapseException): r"""This exception is raised when the request priority is not met.""" pass -class PostProcessException(Exception): +class PostProcessException(SynapseException): r"""This exception is raised when the response headers cannot be updated.""" pass -class RunException(Exception): +class RunException(SynapseException): r"""This exception is raised when the requested function cannot be executed. Indicates a server error.""" pass -class InternalServerError(Exception): +class InternalServerError(SynapseException): r"""This exception is raised when the requested function fails on the server. Indicates a server error.""" pass -class SynapseDendriteNoneException(Exception): - def __init__(self, message="Synapse Dendrite is None"): +class SynapseDendriteNoneException(SynapseException): + def __init__( + self, + message="Synapse Dendrite is None", + synapse: "bittensor.Synapse" | None = None, + ): self.message = message - super().__init__(self.message) + super().__init__(self.message, synapse) diff --git a/requirements/dev.txt b/requirements/dev.txt index c87082dbf0..2fe7007484 100644 --- a/requirements/dev.txt +++ b/requirements/dev.txt @@ -12,4 +12,6 @@ mypy==1.8.0 types-retry==0.9.9.4 freezegun==1.5.0 torch>=1.13.1 +httpx==0.27.0 +aioresponses==0.7.6 factory-boy==3.3.0 diff --git a/tests/unit_tests/conftest.py b/tests/unit_tests/conftest.py new file mode 100644 index 0000000000..21a091f7af --- /dev/null +++ b/tests/unit_tests/conftest.py @@ -0,0 +1,8 @@ +import pytest +from aioresponses import aioresponses + + +@pytest.fixture +def mock_aioresponse(): + with aioresponses() as m: + yield m diff --git a/tests/unit_tests/test_axon.py b/tests/unit_tests/test_axon.py index 33a3724643..cc2eb8824b 100644 --- a/tests/unit_tests/test_axon.py +++ b/tests/unit_tests/test_axon.py @@ -18,17 +18,20 @@ # DEALINGS IN THE SOFTWARE. # Standard Lib -import pytest -import unittest +import re +from dataclasses import dataclass from typing import Any from unittest import IsolatedAsyncioTestCase from unittest.mock import AsyncMock, MagicMock, patch # Third Party +import pytest from starlette.requests import Request +from fastapi.testclient import TestClient # Bittensor import bittensor +from bittensor import Synapse, RunException from bittensor.axon import AxonMiddleware from bittensor.axon import axon as Axon @@ -117,7 +120,7 @@ def test_log_and_handle_error(): synapse = log_and_handle_error(synapse, Exception("Error"), 500, 100) assert synapse.axon.status_code == 500 - assert synapse.axon.status_message == "Error" + assert re.match(r"Internal Server Error #[\da-f\-]+", synapse.axon.status_message) assert synapse.axon.process_time is not None @@ -161,15 +164,20 @@ def axon_instance(): # Mocks +@dataclass class MockWallet: - def __init__(self, hotkey): - self.hotkey = hotkey + hotkey: Any + coldkey: Any = None + coldkeypub: Any = None class MockHotkey: def __init__(self, ss58_address): self.ss58_address = ss58_address + def sign(self, *args, **kwargs): + return f"Signed: {args!r} {kwargs!r}".encode() + class MockInfo: def to_string(self): @@ -428,8 +436,8 @@ async def test_preprocess(self): assert synapse.axon.version == str(bittensor.__version_as_int__) assert synapse.axon.uuid == "1234" assert synapse.axon.nonce is not None - assert synapse.axon.status_message == "Success" - assert synapse.axon.status_code == "100" + assert synapse.axon.status_message is None + assert synapse.axon.status_code == 100 assert synapse.axon.signature == "0xaabbccdd" # Check if the preprocess function fills the dendrite information into the synapse @@ -440,5 +448,115 @@ async def test_preprocess(self): assert synapse.name == "request_name" -if __name__ == "__main__": - unittest.main() +class SynapseHTTPClient(TestClient): + def post_synapse(self, synapse: Synapse): + return self.post( + f"/{synapse.__class__.__name__}", + json=synapse.dict(), + headers={"computed_body_hash": synapse.body_hash}, + ) + + +@pytest.mark.asyncio +class TestAxonHTTPAPIResponses: + @pytest.fixture + def axon(self): + return Axon( + ip="192.0.2.1", + external_ip="192.0.2.1", + wallet=MockWallet(MockHotkey("A"), MockHotkey("B"), MockHotkey("PUB")), + ) + + @pytest.fixture + def no_verify_axon(self, axon): + axon.default_verify = self.no_verify_fn + return axon + + @pytest.fixture + def http_client(self, axon): + return SynapseHTTPClient(axon.app) + + async def no_verify_fn(self, synapse): + return + + async def test_unknown_path(self, http_client): + response = http_client.get("/no_such_path") + assert (response.status_code, response.json()) == ( + 404, + { + "message": "Synapse name 'no_such_path' not found. Available synapses ['Synapse']" + }, + ) + + async def test_ping__no_dendrite(self, http_client): + response = http_client.post_synapse(bittensor.Synapse()) + assert (response.status_code, response.json()) == ( + 401, + { + "message": "Not Verified with error: No SS58 formatted address or public key provided" + }, + ) + + async def test_ping__without_verification(self, http_client, axon): + axon.verify_fns["Synapse"] = self.no_verify_fn + request_synapse = Synapse() + response = http_client.post_synapse(request_synapse) + assert response.status_code == 200 + response_synapse = Synapse(**response.json()) + assert response_synapse.axon.status_code == 200 + + @pytest.fixture + def custom_synapse_cls(self): + class CustomSynapse(Synapse): + pass + + return CustomSynapse + + async def test_synapse__explicitly_set_status_code( + self, http_client, axon, custom_synapse_cls, no_verify_axon + ): + error_message = "Essential resource for CustomSynapse not found" + + async def forward_fn(synapse: custom_synapse_cls): + synapse.axon.status_code = 404 + synapse.axon.status_message = error_message + return synapse + + axon.attach(forward_fn) + + response = http_client.post_synapse(custom_synapse_cls()) + assert response.status_code == 404 + response_synapse = custom_synapse_cls(**response.json()) + assert ( + response_synapse.axon.status_code, + response_synapse.axon.status_message, + ) == (404, error_message) + + async def test_synapse__exception_with_set_status_code( + self, http_client, axon, custom_synapse_cls, no_verify_axon + ): + error_message = "Conflicting request" + + async def forward_fn(synapse: custom_synapse_cls): + synapse.axon.status_code = 409 + raise RunException(message=error_message, synapse=synapse) + + axon.attach(forward_fn) + + response = http_client.post_synapse(custom_synapse_cls()) + assert response.status_code == 409 + assert response.json() == {"message": error_message} + + async def test_synapse__internal_error( + self, http_client, axon, custom_synapse_cls, no_verify_axon + ): + async def forward_fn(synapse: custom_synapse_cls): + raise ValueError("error with potentially sensitive information") + + axon.attach(forward_fn) + + response = http_client.post_synapse(custom_synapse_cls()) + assert response.status_code == 500 + response_data = response.json() + assert sorted(response_data.keys()) == ["message"] + assert re.match(r"Internal Server Error #[\da-f\-]+", response_data["message"]) diff --git a/tests/unit_tests/test_dendrite.py b/tests/unit_tests/test_dendrite.py index 9b0b6d7ddf..36ccb2ecb2 100644 --- a/tests/unit_tests/test_dendrite.py +++ b/tests/unit_tests/test_dendrite.py @@ -46,6 +46,23 @@ def setup_dendrite(): return dendrite_obj +@pytest.fixture +def dendrite_obj(setup_dendrite): + return setup_dendrite + + +@pytest.fixture +def axon_info(): + return bittensor.AxonInfo( + version=1, + ip="127.0.0.1", + port=666, + ip_type=4, + hotkey="hot", + coldkey="cold", + ) + + @pytest.fixture(scope="session") def setup_axon(): axon = bittensor.axon() @@ -61,21 +78,18 @@ def test_init(setup_dendrite): assert dendrite_obj.keypair == setup_dendrite.keypair -def test_str(setup_dendrite): - dendrite_obj = setup_dendrite - expected_string = "dendrite({})".format(setup_dendrite.keypair.ss58_address) +def test_str(dendrite_obj): + expected_string = "dendrite({})".format(dendrite_obj.keypair.ss58_address) assert str(dendrite_obj) == expected_string -def test_repr(setup_dendrite): - dendrite_obj = setup_dendrite - expected_string = "dendrite({})".format(setup_dendrite.keypair.ss58_address) +def test_repr(dendrite_obj): + expected_string = "dendrite({})".format(dendrite_obj.keypair.ss58_address) assert repr(dendrite_obj) == expected_string -def test_close(setup_dendrite, setup_axon): +def test_close(dendrite_obj, setup_axon): axon = setup_axon - dendrite_obj = setup_dendrite # Query the axon to open a session dendrite_obj.query(axon, SynapseDummy(input=1)) # Session should be automatically closed after query @@ -83,9 +97,8 @@ def test_close(setup_dendrite, setup_axon): @pytest.mark.asyncio -async def test_aclose(setup_dendrite, setup_axon): +async def test_aclose(dendrite_obj, setup_axon): axon = setup_axon - dendrite_obj = setup_dendrite # Use context manager to open an async session async with dendrite_obj: resp = await dendrite_obj([axon], SynapseDummy(input=1), deserialize=False) @@ -272,3 +285,52 @@ def test_terminal_info_error_cases( version=version, nonce=nonce, ) + + +@pytest.mark.asyncio +async def test_dendrite__call__success_response( + axon_info, dendrite_obj, mock_aioresponse +): + input_synapse = SynapseDummy(input=1) + expected_synapse = SynapseDummy( + **( + input_synapse.dict() + | dict( + output=2, + axon=TerminalInfo( + status_code=200, + status_message="Success", + process_time=0.1, + ), + ) + ) + ) + mock_aioresponse.post( + f"http://127.0.0.1:666/SynapseDummy", + body=expected_synapse.json(), + ) + synapse = await dendrite_obj.call(axon_info, synapse=input_synapse) + + assert synapse.input == 1 + assert synapse.output == 2 + assert synapse.dendrite.status_code == 200 + assert synapse.dendrite.status_message == "Success" + assert synapse.dendrite.process_time >= 0 + + +@pytest.mark.asyncio +async def test_dendrite__call__handles_http_error_response( + axon_info, dendrite_obj, mock_aioresponse +): + status_code = 414 + message = "Custom Error" + + mock_aioresponse.post( + f"http://127.0.0.1:666/SynapseDummy", + status=status_code, + payload={"message": message}, + ) + synapse = await dendrite_obj.call(axon_info, synapse=SynapseDummy(input=1)) + + assert synapse.axon.status_code == synapse.dendrite.status_code == status_code + assert synapse.axon.status_message == synapse.dendrite.status_message == message