Skip to content

Commit

Permalink
[commands] Don't raise ExtensionNotFound for ImportErrors in modules
Browse files Browse the repository at this point in the history
Now loading an extension that _contains_ a failed import will fail
with ExtensionFailed, rather than ExtensionNotFound.
  • Loading branch information
ioistired authored and Rapptz committed Jun 29, 2019
1 parent 3961e7e commit 0a21591
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 17 deletions.
24 changes: 14 additions & 10 deletions discord/ext/commands/bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
import asyncio
import collections
import inspect
import importlib
import importlib.util
import sys
import traceback
import re
Expand Down Expand Up @@ -588,12 +588,17 @@ def _call_module_finalizers(self, lib, key):
if _is_submodule(name, module):
del sys.modules[module]

def _load_from_module_spec(self, lib, key):
def _load_from_module_spec(self, spec, key):
# precondition: key not in self.__extensions
lib = importlib.util.module_from_spec(spec)
try:
spec.loader.exec_module(lib)
except Exception as e:
raise errors.ExtensionFailed(key, e) from e

try:
setup = getattr(lib, 'setup')
except AttributeError:
del sys.modules[key]
raise errors.NoEntryPointError(key)

try:
Expand All @@ -603,7 +608,7 @@ def _load_from_module_spec(self, lib, key):
self._call_module_finalizers(lib, key)
raise errors.ExtensionFailed(key, e) from e
else:
self.__extensions[key] = lib
sys.modules[key] = self.__extensions[key] = lib

def load_extension(self, name):
"""Loads an extension.
Expand Down Expand Up @@ -637,12 +642,11 @@ def load_extension(self, name):
if name in self.__extensions:
raise errors.ExtensionAlreadyLoaded(name)

try:
lib = importlib.import_module(name)
except ImportError as e:
raise errors.ExtensionNotFound(name, e) from e
else:
self._load_from_module_spec(lib, name)
spec = importlib.util.find_spec(name)
if spec is None:
raise errors.ExtensionNotFound(name)

self._load_from_module_spec(spec, name)

def unload_extension(self, name):
"""Unloads an extension.
Expand Down
16 changes: 9 additions & 7 deletions discord/ext/commands/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -503,7 +503,7 @@ def __init__(self, name):
super().__init__("Extension {!r} has no 'setup' function.".format(name), name=name)

class ExtensionFailed(ExtensionError):
"""An exception raised when an extension failed to load during execution of the ``setup`` entry point.
"""An exception raised when an extension failed to load during execution of the module or ``setup`` entry point.
This inherits from :exc:`ExtensionError`
Expand All @@ -521,19 +521,21 @@ def __init__(self, name, original):
super().__init__(fmt.format(name, original), name=name)

class ExtensionNotFound(ExtensionError):
"""An exception raised when an extension failed to be imported.
"""An exception raised when an extension is not found.
This inherits from :exc:`ExtensionError`
.. versionchanged:: 1.3.0
Made the ``original`` attribute always None.
Attributes
-----------
name: :class:`str`
The extension that had the error.
original: :exc:`ImportError`
The original exception that was raised. You can also get this via
the ``__cause__`` attribute.
original: :class:`NoneType`
Always ``None`` for backwards compatibility.
"""
def __init__(self, name, original):
self.original = original
def __init__(self, name, original=None):
self.original = None
fmt = 'Extension {0!r} could not be loaded.'
super().__init__(fmt.format(name), name=name)

0 comments on commit 0a21591

Please sign in to comment.