diff --git a/altair/utils/compiler.py b/altair/utils/compiler.py index 0944c92fd..cbe0dd62d 100644 --- a/altair/utils/compiler.py +++ b/altair/utils/compiler.py @@ -7,5 +7,5 @@ VegaLiteCompilerType = Callable[[dict], dict] -class VegaLiteCompilerRegistry(PluginRegistry[VegaLiteCompilerType]): +class VegaLiteCompilerRegistry(PluginRegistry[VegaLiteCompilerType, dict]): pass diff --git a/altair/utils/data.py b/altair/utils/data.py index daf37393d..920b8bd83 100644 --- a/altair/utils/data.py +++ b/altair/utils/data.py @@ -16,8 +16,9 @@ Dict, overload, runtime_checkable, + Callable, ) -from typing_extensions import TypeAlias +from typing_extensions import TypeAlias, ParamSpec, Concatenate from pathlib import Path from functools import partial import sys @@ -82,17 +83,14 @@ def is_data_type(obj: Any) -> TypeIs[DataType]: # VegaLite spec, after the Data model has been put into a schema compliant # form. # ============================================================================== -class DataTransformerType(Protocol): - @overload - def __call__(self, data: None = None, **kwargs) -> DataTransformerType: ... - @overload - def __call__(self, data: DataType, **kwargs) -> VegaLiteDataDict: ... - def __call__( - self, data: DataType | None = None, **kwargs - ) -> DataTransformerType | VegaLiteDataDict: ... +P = ParamSpec("P") +# NOTE: `Any` required due to the complexity of existing signatures imported in `altair.vegalite.v5.data.py` +R = TypeVar("R", VegaLiteDataDict, Any) +DataTransformerType = Callable[Concatenate[DataType, P], R] -class DataTransformerRegistry(PluginRegistry[DataTransformerType]): + +class DataTransformerRegistry(PluginRegistry[DataTransformerType, R]): _global_settings = {"consolidate_datasets": True} @property diff --git a/altair/utils/display.py b/altair/utils/display.py index 2078c61de..2eee30564 100644 --- a/altair/utils/display.py +++ b/altair/utils/display.py @@ -31,7 +31,7 @@ ] -class RendererRegistry(PluginRegistry[RendererType]): +class RendererRegistry(PluginRegistry[RendererType, MimeBundleType]): entrypoint_err_messages = { "notebook": textwrap.dedent( """ diff --git a/altair/utils/plugin_registry.py b/altair/utils/plugin_registry.py index 53853f174..e81d79649 100644 --- a/altair/utils/plugin_registry.py +++ b/altair/utils/plugin_registry.py @@ -1,15 +1,33 @@ from __future__ import annotations from functools import partial -from typing import Any, Generic, TypeVar, cast, Callable, TYPE_CHECKING +from typing import Any, Generic, cast, Callable, TYPE_CHECKING +from typing_extensions import TypeAliasType, TypeVar, TypeIs from importlib.metadata import entry_points +from altair.utils.deprecation import deprecated_warn + if TYPE_CHECKING: from types import TracebackType +T = TypeVar("T") +R = TypeVar("R") +Plugin = TypeAliasType("Plugin", Callable[..., R], type_params=(R,)) +PluginT = TypeVar("PluginT", bound=Plugin[Any]) +IsPlugin = Callable[[object], TypeIs[Plugin[Any]]] + + +def _is_type(tp: type[T], /) -> Callable[[object], TypeIs[type[T]]]: + """Converts a type to guard function. + + Added for compatibility with original `PluginRegistry` default. + """ + + def func(obj: object, /) -> TypeIs[type[T]]: + return isinstance(obj, tp) -PluginType = TypeVar("PluginType") + return func class NoSuchEntryPoint(Exception): @@ -49,7 +67,7 @@ def __repr__(self) -> str: return f"{self.registry.__class__.__name__}.enable({self.name!r})" -class PluginRegistry(Generic[PluginType]): +class PluginRegistry(Generic[PluginT, R]): """A registry for plugins. This is a plugin registry that allows plugins to be loaded/registered @@ -74,26 +92,44 @@ class PluginRegistry(Generic[PluginType]): # in the registry rather than passed to the plugins _global_settings: dict[str, Any] = {} - def __init__(self, entry_point_group: str = "", plugin_type: type = Callable): # type: ignore[assignment] + def __init__( + self, entry_point_group: str = "", plugin_type: IsPlugin = callable + ) -> None: """Create a PluginRegistry for a named entry point group. Parameters ========== entry_point_group: str The name of the entry point group. - plugin_type: object - A type that will optionally be used for runtime type checking of - loaded plugins using isinstance. + plugin_type + A type narrowing function that will optionally be used for runtime + type checking loaded plugins. + + References + ========== + https://typing.readthedocs.io/en/latest/spec/narrowing.html """ self.entry_point_group: str = entry_point_group - self.plugin_type: type[Any] = plugin_type - self._active: PluginType | None = None + self.plugin_type: IsPlugin + if plugin_type is not callable and isinstance(plugin_type, type): + msg = ( + f"Pass a callable `TypeIs` function to `plugin_type` instead.\n" + f"{type(self).__name__!r}(plugin_type)\n\n" + f"See also:\n" + f"https://typing.readthedocs.io/en/latest/spec/narrowing.html\n" + f"https://docs.astral.sh/ruff/rules/assert/" + ) + deprecated_warn(msg, version="5.4.0") + self.plugin_type = cast(IsPlugin, _is_type(plugin_type)) + else: + self.plugin_type = plugin_type + self._active: Plugin[R] | None = None self._active_name: str = "" - self._plugins: dict[str, PluginType] = {} + self._plugins: dict[str, PluginT] = {} self._options: dict[str, Any] = {} self._global_settings: dict[str, Any] = self.__class__._global_settings.copy() - def register(self, name: str, value: PluginType | Any | None) -> PluginType | None: + def register(self, name: str, value: PluginT | None) -> PluginT | None: """Register a plugin by name and value. This method is used for explicit registration of a plugin and shouldn't be @@ -113,12 +149,12 @@ def register(self, name: str, value: PluginType | Any | None) -> PluginType | No """ if value is None: return self._plugins.pop(name, None) - else: - assert isinstance( - value, self.plugin_type - ) # Should ideally be fixed by better annotating plugin_type + elif self.plugin_type(value): self._plugins[name] = value return value + else: + msg = f"{type(value).__name__!r} is not compatible with {type(self).__name__!r}" + raise TypeError(msg) def names(self) -> list[str]: """List the names of the registered and entry points plugins.""" @@ -163,7 +199,7 @@ def _enable(self, name: str, **options) -> None: raise ValueError(self.entrypoint_err_messages[name]) from err else: raise NoSuchEntryPoint(self.entry_point_group, name) from err - value = cast(PluginType, ep.load()) + value = cast(PluginT, ep.load()) self.register(name, value) self._active_name = name self._active = self._plugins[name] @@ -204,18 +240,21 @@ def options(self) -> dict[str, Any]: """Return the current options dictionary""" return self._options - def get(self) -> PluginType | Callable[..., Any] | None: + def get(self) -> partial[R] | Plugin[R] | None: """Return the currently active plugin.""" - if self._options: - if func := self._active: - # NOTE: Fully do not understand this one - # error: Argument 1 to "partial" has incompatible type "PluginType"; expected "Callable[..., Never]" - return partial(func, **self._options) # type: ignore[arg-type] - else: - msg = "Unclear what this meant by passing to curry." - raise TypeError(msg) - else: - return self._active + if (func := self._active) and self.plugin_type(func): + return partial(func, **self._options) if self._options else func + elif self._active is not None: + msg = ( + f"{type(self).__name__!r} requires all plugins to be callable objects, " + f"but {type(self._active).__name__!r} is not callable." + ) + raise TypeError(msg) + elif TYPE_CHECKING: + # NOTE: The `None` return is implicit, but `mypy` isn't satisfied + # - `ruff` will factor out explicit `None` return + # - `pyright` has no issue + raise NotImplementedError def __repr__(self) -> str: return f"{type(self).__name__}(active={self.active!r}, registered={self.names()!r})" @@ -228,6 +267,6 @@ def importlib_metadata_get(group): # also get compatibility with the importlib_metadata package which had a different # deprecation cycle for 'get' if hasattr(ep, "select"): - return ep.select(group=group) + return ep.select(group=group) # pyright: ignore else: return ep.get(group, []) diff --git a/altair/utils/theme.py b/altair/utils/theme.py index 10dc6fa8a..4ad0209da 100644 --- a/altair/utils/theme.py +++ b/altair/utils/theme.py @@ -6,5 +6,5 @@ ThemeType = Callable[..., dict] -class ThemeRegistry(PluginRegistry[ThemeType]): +class ThemeRegistry(PluginRegistry[ThemeType, dict]): pass diff --git a/altair/vegalite/v5/data.py b/altair/vegalite/v5/data.py index fbc339a46..3b69a92c5 100644 --- a/altair/vegalite/v5/data.py +++ b/altair/vegalite/v5/data.py @@ -25,7 +25,8 @@ data_transformers = DataTransformerRegistry(entry_point_group=ENTRY_POINT_GROUP) data_transformers.register("default", default_data_transformer) data_transformers.register("json", to_json) -data_transformers.register("csv", to_csv) +# FIXME: `to_csv` cannot accept all `DataType` https://github.com/vega/altair/issues/3441 +data_transformers.register("csv", to_csv) # type: ignore[arg-type] data_transformers.register("vegafusion", vegafusion_data_transformer) data_transformers.enable("default") diff --git a/tests/utils/test_plugin_registry.py b/tests/utils/test_plugin_registry.py index b776a118a..632cad027 100644 --- a/tests/utils/test_plugin_registry.py +++ b/tests/utils/test_plugin_registry.py @@ -2,7 +2,7 @@ from typing import Callable -class TypedCallableRegistry(PluginRegistry[Callable[[int], int]]): +class TypedCallableRegistry(PluginRegistry[Callable[[int], int], int]): pass