diff --git a/src/ape/api/networks.py b/src/ape/api/networks.py index 2e6b957e5b..0519dda253 100644 --- a/src/ape/api/networks.py +++ b/src/ape/api/networks.py @@ -237,20 +237,44 @@ def networks(self) -> dict[str, "NetworkAPI"]: networks = {**self._networks_from_plugins} # Include configured custom networks. - custom_networks = [ + custom_networks: list = [ n for n in self.config_manager.get_config("networks").custom if (n.ecosystem or self.network_manager.default_ecosystem.name) == self.name ] + + # Ensure forks are added automatically for custom networks. + forked_custom_networks = [] + for net in custom_networks: + if net.name.endswith("-fork"): + # Already a fork. + continue + + fork_network_name = f"{net.name}-fork" + if any(x.name == fork_network_name for x in custom_networks): + # The forked version of this network is already known. + continue + + # Create a forked network mirroring the custom network. + forked_net = net.model_copy(deep=True) + forked_net.name = fork_network_name + forked_custom_networks.append(forked_net) + + # NOTE: Forked custom networks are still custom networks. + custom_networks.extend(forked_custom_networks) + for custom_net in custom_networks: if custom_net.name in networks: raise NetworkError( f"More than one network named '{custom_net.name}' in ecosystem '{self.name}'." ) + is_fork = custom_net.name.endswith("-fork") network_data = custom_net.model_dump(by_alias=True, exclude=("default_provider",)) network_data["ecosystem"] = self - network_type = create_network_type(custom_net.chain_id, custom_net.chain_id) + network_type = create_network_type( + custom_net.chain_id, custom_net.chain_id, is_fork=is_fork + ) network_api = network_type.model_validate(network_data) network_api._default_provider = custom_net.default_provider network_api._is_custom = True @@ -758,6 +782,17 @@ def disconnect_all(self): self.connected_providers = {} +def _set_provider(provider: "ProviderAPI") -> "ProviderAPI": + connection_id = provider.connection_id + if connection_id in ProviderContextManager.connected_providers: + # Likely multi-chain testing or utilizing multiple on-going connections. + provider = ProviderContextManager.connected_providers[connection_id] + if not provider.is_connected: + provider.connect() + + return provider + + class NetworkAPI(BaseInterfaceModel): """ A wrapper around a provider for a specific ecosystem. @@ -1037,26 +1072,27 @@ def get_provider( provider_settings["ipc_path"] = provider_name provider_name = "node" + # Assuming any installed forking plugin can at least for Ethereum mainnet. + # NOTE: This is a bit limiting for non-EVM custom forked networks. + common_forking_providers = self.network_manager.ethereum.mainnet_fork.providers + if provider_name in self.providers: provider = self.providers[provider_name](provider_settings=provider_settings) - connection_id = provider.connection_id - if connection_id in ProviderContextManager.connected_providers: - # Likely multi-chain testing or utilizing multiple on-going connections. - provider = ProviderContextManager.connected_providers[connection_id] - if not provider.is_connected: - provider.connect() + return _set_provider(provider) - return provider - - return provider - - else: - raise ProviderNotFoundError( - provider_name, - network=self.name, - ecosystem=self.ecosystem.name, - options=self.providers, + elif self.name.endswith("-fork") and provider_name in common_forking_providers: + provider = common_forking_providers[provider_name]( + provider_settings=provider_settings, + network=self, ) + return _set_provider(provider) + + raise ProviderNotFoundError( + provider_name, + network=self.name, + ecosystem=self.ecosystem.name, + options=self.providers, + ) def use_provider( self, @@ -1276,12 +1312,13 @@ def use_upstream_provider(self) -> ProviderContextManager: return self.upstream_network.use_provider(self.upstream_provider) -def create_network_type(chain_id: int, network_id: int) -> type[NetworkAPI]: +def create_network_type(chain_id: int, network_id: int, is_fork: bool = False) -> type[NetworkAPI]: """ Easily create a :class:`~ape.api.networks.NetworkAPI` subclass. """ + BaseNetwork = ForkedNetworkAPI if is_fork else NetworkAPI - class network_def(NetworkAPI): + class network_def(BaseNetwork): # type: ignore @property def chain_id(self) -> int: return chain_id diff --git a/src/ape/cli/choices.py b/src/ape/cli/choices.py index 81558be1e3..698485fbdf 100644 --- a/src/ape/cli/choices.py +++ b/src/ape/cli/choices.py @@ -335,15 +335,15 @@ def __init__( self.base_type = base_type self.callback = callback - super().__init__( - get_networks(ecosystem=ecosystem, network=network, provider=provider), case_sensitive - ) + networks = get_networks(ecosystem=ecosystem, network=network, provider=provider) + super().__init__(networks, case_sensitive) def get_metavar(self, param): return "[ecosystem-name][:[network-name][:[provider-name]]]" def convert(self, value: Any, param: Optional[Parameter], ctx: Optional[Context]) -> Any: choice: Optional[Union[str, ProviderAPI]] + networks = ManagerAccessMixin.network_manager if not value: choice = None @@ -360,6 +360,14 @@ def convert(self, value: Any, param: Optional[Parameter], ctx: Optional[Context] # Validate result. choice = super().convert(value, param, ctx) except BadParameter as err: + # Attempt to get the provider anyway. + # Some plugins will handle networks anyway, + # such as forked-custom networks. + try: + return networks.get_provider_from_choice(network_choice=value) + except Exception: + pass # Pretend this never happened and raise BadParam. + # If an error was not raised for some reason, raise a simpler error. # NOTE: Still avoid showing the massive network options list. raise click.BadParameter( @@ -372,10 +380,7 @@ def convert(self, value: Any, param: Optional[Parameter], ctx: Optional[Context] and issubclass(self.base_type, ProviderAPI) ): # Return the provider. - - choice = ManagerAccessMixin.network_manager.get_provider_from_choice( - network_choice=value - ) + choice = networks.get_provider_from_choice(network_choice=value) return self.callback(ctx, param, choice) if self.callback else choice diff --git a/src/ape/managers/networks.py b/src/ape/managers/networks.py index d6dc5dabcd..3e492d2f1c 100644 --- a/src/ape/managers/networks.py +++ b/src/ape/managers/networks.py @@ -462,6 +462,7 @@ def get_provider_from_choice( self.get_ecosystem(ecosystem_name) if ecosystem_name else self.default_ecosystem ) network = ecosystem.get_network(network_name or ecosystem.default_network_name) + return network.get_provider( provider_name=provider_name, provider_settings=provider_settings )