diff --git a/altair/expr/core.py b/altair/expr/core.py index cfaa9984c..52fabc1fe 100644 --- a/altair/expr/core.py +++ b/altair/expr/core.py @@ -1,5 +1,9 @@ from __future__ import annotations -from typing import Any + +from typing import Any, Union, Dict + +from typing_extensions import TypeAlias + from ..utils import SchemaBase @@ -232,3 +236,6 @@ def __init__(self, group, name) -> None: def __repr__(self) -> str: return f"{self.group}[{self.name!r}]" + + +IntoExpression: TypeAlias = Union[bool, None, str, OperatorMixin, Dict[str, Any]] diff --git a/altair/utils/_transformed_data.py b/altair/utils/_transformed_data.py index 5847ad94f..ceae9fa34 100644 --- a/altair/utils/_transformed_data.py +++ b/altair/utils/_transformed_data.py @@ -30,6 +30,7 @@ if TYPE_CHECKING: from altair.utils.core import DataFrameLike + from altair.vegalite.v5.api import ChartType Scope: TypeAlias = Tuple[int, ...] FacetMapping: TypeAlias = Dict[Tuple[str, Scope], Tuple[str, Scope]] @@ -154,9 +155,7 @@ def transformed_data(chart, row_limit=None, exclude=None): # The same error appeared when trying it with Protocols for the concat and layer charts. # This function is only used internally and so we accept this inconsistency for now. def name_views( - chart: Chart | FacetChart | LayerChart | HConcatChart | VConcatChart | ConcatChart, - i: int = 0, - exclude: Iterable[str] | None = None, + chart: ChartType, i: int = 0, exclude: Iterable[str] | None = None ) -> list[str]: """ Name unnamed chart views. @@ -193,6 +192,7 @@ def name_views( else: return [] else: + subcharts: list[Any] if isinstance(chart, _chart_class_mapping[LayerChart]): subcharts = chart.layer elif isinstance(chart, _chart_class_mapping[HConcatChart]): diff --git a/altair/utils/_vegafusion_data.py b/altair/utils/_vegafusion_data.py index 26f79d22e..1642760cb 100644 --- a/altair/utils/_vegafusion_data.py +++ b/altair/utils/_vegafusion_data.py @@ -213,7 +213,7 @@ def compile_to_vegafusion_chart_state( return chart_state -def compile_with_vegafusion(vegalite_spec: dict[str, Any]) -> dict: +def compile_with_vegafusion(vegalite_spec: dict[str, Any]) -> dict[str, Any]: """ Compile a Vega-Lite spec to Vega and pre-transform with VegaFusion. diff --git a/altair/utils/compiler.py b/altair/utils/compiler.py index cbe0dd62d..634260715 100644 --- a/altair/utils/compiler.py +++ b/altair/utils/compiler.py @@ -1,11 +1,11 @@ -from typing import Callable +from typing import Callable, Dict, Any from altair.utils import PluginRegistry # ============================================================================== # Vega-Lite to Vega compiler registry # ============================================================================== -VegaLiteCompilerType = Callable[[dict], dict] +VegaLiteCompilerType = Callable[[Dict[str, Any]], Dict[str, Any]] -class VegaLiteCompilerRegistry(PluginRegistry[VegaLiteCompilerType, dict]): +class VegaLiteCompilerRegistry(PluginRegistry[VegaLiteCompilerType, Dict[str, Any]]): pass diff --git a/altair/vegalite/v5/api.py b/altair/vegalite/v5/api.py index 9458442ca..33bb5c019 100644 --- a/altair/vegalite/v5/api.py +++ b/altair/vegalite/v5/api.py @@ -15,6 +15,7 @@ TYPE_CHECKING, TypeVar, Protocol, + Sequence, ) from typing_extensions import TypeAlias import typing as t @@ -103,12 +104,19 @@ TopLevelSelectionParameter, SelectionParameter, InlineDataset, + NamedData, + InlineData, + BindCheckbox, + BindRadioSelect, + BindRange, ) + from altair.utils.display import MimeBundleType from altair.expr.core import ( BinaryExpression, Expression, GetAttrExpression, GetItemExpression, + IntoExpression, ) from .schema._typing import ( ImputeMethod_T, @@ -122,6 +130,9 @@ MultiTimeUnit_T, SingleTimeUnit_T, OneOrSeq, + AutosizeType_T, + ColorName_T, + LayoutAlign_T, ) __all__ = [ @@ -183,7 +194,7 @@ # ------------------------------------------------------------------------ # Data Utilities -def _dataset_name(values: dict | list | InlineDataset) -> str: +def _dataset_name(values: dict[str, Any] | list | InlineDataset) -> str: """ Generate a unique hash of the data. @@ -206,7 +217,9 @@ def _dataset_name(values: dict | list | InlineDataset) -> str: return "data-" + hsh -def _consolidate_data(data: Any, context: Any) -> Any: +def _consolidate_data( + data: ChartDataType | UrlData, context: dict[str, Any] +) -> ChartDataType | NamedData | InlineData | UrlData: """ If data is specified inline, then move it to context['datasets']. @@ -235,7 +248,9 @@ def _consolidate_data(data: Any, context: Any) -> Any: return data -def _prepare_data(data, context=None): +def _prepare_data( + data: ChartDataType, context: dict[str, Any] | None = None +) -> ChartDataType | NamedData | InlineData | UrlData | Any: """ Convert input data to data for use within schema. @@ -281,10 +296,10 @@ def _prepare_data(data, context=None): class LookupData(core.LookupData): @utils.use_signature(core.LookupData) - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) - def to_dict(self, *args, **kwargs) -> dict: + def to_dict(self, *args: Any, **kwargs: Any) -> dict[str, Any]: """Convert the chart to a dictionary suitable for JSON export.""" copy = self.copy(deep=False) copy.data = _prepare_data(copy.data, kwargs.get("context")) @@ -295,10 +310,10 @@ class FacetMapping(core.FacetMapping): _class_is_valid_at_instantiation = False @utils.use_signature(core.FacetMapping) - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) - def to_dict(self, *args, **kwargs) -> dict: + def to_dict(self, *args: Any, **kwargs: Any) -> dict[str, Any]: copy = self.copy(deep=False) context = kwargs.get("context", {}) data = context.get("data", None) @@ -353,11 +368,11 @@ def __init__( alternative="to_dict", message="No need to call '.ref()' anymore.", ) - def ref(self) -> dict: + def ref(self) -> dict[str, Any]: """'ref' is deprecated. No need to call '.ref()' anymore.""" return self.to_dict() - def to_dict(self) -> dict[str, str | dict]: + def to_dict(self) -> dict[str, str | dict[str, Any]]: if self.param_type == "variable": return {"expr": self.name} elif self.param_type == "selection": @@ -367,13 +382,13 @@ def to_dict(self) -> dict[str, str | dict]: msg = f"Unrecognized parameter type: {self.param_type}" raise ValueError(msg) - def __invert__(self): + def __invert__(self) -> SelectionPredicateComposition | Any: if self.param_type == "selection": return SelectionPredicateComposition({"not": {"param": self.name}}) else: return _expr_core.OperatorMixin.__invert__(self) - def __and__(self, other): + def __and__(self, other: Any) -> SelectionPredicateComposition | Any: if self.param_type == "selection": if isinstance(other, Parameter): other = {"param": other.name} @@ -381,7 +396,7 @@ def __and__(self, other): else: return _expr_core.OperatorMixin.__and__(self, other) - def __or__(self, other): + def __or__(self, other: Any) -> SelectionPredicateComposition | Any: if self.param_type == "selection": if isinstance(other, Parameter): other = {"param": other.name} @@ -395,7 +410,7 @@ def __repr__(self) -> str: def _to_expr(self) -> str: return self.name - def _from_expr(self, expr) -> ParameterExpression: + def _from_expr(self, expr: IntoExpression) -> ParameterExpression: return ParameterExpression(expr=expr) def __getattr__(self, field_name: str) -> GetAttrExpression | SelectionExpression: @@ -416,18 +431,18 @@ def __getitem__(self, field_name: str) -> GetItemExpression: # Enables use of ~, &, | with compositions of selection objects. class SelectionPredicateComposition(core.PredicateComposition): - def __invert__(self): + def __invert__(self) -> SelectionPredicateComposition: return SelectionPredicateComposition({"not": self.to_dict()}) - def __and__(self, other): + def __and__(self, other: SchemaBase) -> SelectionPredicateComposition: return SelectionPredicateComposition({"and": [self.to_dict(), other.to_dict()]}) - def __or__(self, other): + def __or__(self, other: SchemaBase) -> SelectionPredicateComposition: return SelectionPredicateComposition({"or": [self.to_dict(), other.to_dict()]}) class ParameterExpression(_expr_core.OperatorMixin): - def __init__(self, expr) -> None: + def __init__(self, expr: IntoExpression) -> None: self.expr = expr def to_dict(self) -> dict[str, str]: @@ -436,12 +451,12 @@ def to_dict(self) -> dict[str, str]: def _to_expr(self) -> str: return repr(self.expr) - def _from_expr(self, expr) -> ParameterExpression: + def _from_expr(self, expr: IntoExpression) -> ParameterExpression: return ParameterExpression(expr=expr) class SelectionExpression(_expr_core.OperatorMixin): - def __init__(self, expr) -> None: + def __init__(self, expr: IntoExpression) -> None: self.expr = expr def to_dict(self) -> dict[str, str]: @@ -450,7 +465,7 @@ def to_dict(self) -> dict[str, str]: def _to_expr(self) -> str: return repr(self.expr) - def _from_expr(self, expr) -> SelectionExpression: + def _from_expr(self, expr: IntoExpression) -> SelectionExpression: return SelectionExpression(expr=expr) @@ -1037,7 +1052,7 @@ def when( ) raise NotImplementedError(msg) - def to_dict(self, *args, **kwds) -> _Conditional[_C]: # type: ignore[override] + def to_dict(self, *args: Any, **kwds: Any) -> _Conditional[_C]: # type: ignore[override] m = super().to_dict(*args, **kwds) return _Conditional(condition=m["condition"]) @@ -1213,7 +1228,7 @@ def when( # Top-Level Functions -def value(value, **kwargs) -> _Value: +def value(value: Any, **kwargs: Any) -> _Value: """Specify a value for use in an encoding.""" return _Value(value=value, **kwargs) # type: ignore[typeddict-item] @@ -1224,7 +1239,7 @@ def param( bind: Optional[Binding] = Undefined, empty: Optional[bool] = Undefined, expr: Optional[str | Expr | Expression] = Undefined, - **kwds, + **kwds: Any, ) -> Parameter: """ Create a named parameter, see https://altair-viz.github.io/user_guide/interactions.html for examples. @@ -1305,7 +1320,7 @@ def param( return parameter -def _selection(type: Optional[SelectionType_T] = Undefined, **kwds) -> Parameter: +def _selection(type: Optional[SelectionType_T] = Undefined, **kwds: Any) -> Parameter: # We separate out the parameter keywords from the selection keywords select_kwds = {"name", "bind", "value", "empty", "init", "views"} @@ -1335,7 +1350,7 @@ def _selection(type: Optional[SelectionType_T] = Undefined, **kwds) -> Parameter alternative="'selection_point()' or 'selection_interval()'", message="These functions also include more helpful docstrings.", ) -def selection(type: Optional[SelectionType_T] = Undefined, **kwds) -> Parameter: +def selection(type: Optional[SelectionType_T] = Undefined, **kwds: Any) -> Parameter: """'selection' is deprecated use 'selection_point' or 'selection_interval' instead, depending on the type of parameter you want to create.""" return _selection(type=type, **kwds) @@ -1353,7 +1368,7 @@ def selection_interval( mark: Optional[Mark] = Undefined, translate: Optional[str | bool] = Undefined, zoom: Optional[str | bool] = Undefined, - **kwds, + **kwds: Any, ) -> Parameter: """ Create an interval selection parameter. Selection parameters define data queries that are driven by direct manipulation from user input (e.g., mouse clicks or drags). Interval selection parameters are used to select a continuous range of data values on drag, whereas point selection parameters (`selection_point`) are used to select multiple discrete data values.). @@ -1466,7 +1481,7 @@ def selection_point( resolve: Optional[SelectionResolution_T] = Undefined, toggle: Optional[str | bool] = Undefined, nearest: Optional[bool] = Undefined, - **kwds, + **kwds: Any, ) -> Parameter: """ Create a point selection parameter. Selection parameters define data queries that are driven by direct manipulation from user input (e.g., mouse clicks or drags). Point selection parameters are used to select multiple discrete data values; the first value is selected on click and additional values toggled on shift-click. To select a continuous range of data values on drag interval selection parameters (`selection_interval`) can be used instead. @@ -1571,43 +1586,43 @@ def selection_point( @utils.deprecated(version="5.0.0", alternative="selection_point") -def selection_multi(**kwargs): +def selection_multi(**kwargs: Any) -> Parameter: """'selection_multi' is deprecated. Use 'selection_point'.""" return _selection(type="point", **kwargs) @utils.deprecated(version="5.0.0", alternative="selection_point") -def selection_single(**kwargs): +def selection_single(**kwargs: Any) -> Parameter: """'selection_single' is deprecated. Use 'selection_point'.""" return _selection(type="point", **kwargs) @utils.use_signature(core.Binding) -def binding(input, **kwargs): +def binding(input: Any, **kwargs: Any) -> Binding: """A generic binding.""" return core.Binding(input=input, **kwargs) @utils.use_signature(core.BindCheckbox) -def binding_checkbox(**kwargs): +def binding_checkbox(**kwargs: Any) -> BindCheckbox: """A checkbox binding.""" return core.BindCheckbox(input="checkbox", **kwargs) @utils.use_signature(core.BindRadioSelect) -def binding_radio(**kwargs): +def binding_radio(**kwargs: Any) -> BindRadioSelect: """A radio button binding.""" return core.BindRadioSelect(input="radio", **kwargs) @utils.use_signature(core.BindRadioSelect) -def binding_select(**kwargs): +def binding_select(**kwargs: Any) -> BindRadioSelect: """A select binding.""" return core.BindRadioSelect(input="select", **kwargs) @utils.use_signature(core.BindRange) -def binding_range(**kwargs): +def binding_range(**kwargs: Any) -> BindRange: """A range binding.""" return core.BindRange(input="range", **kwargs) @@ -1619,7 +1634,7 @@ def condition( if_false: _TSchemaBase, *, empty: Optional[bool] = ..., - **kwargs, + **kwargs: Any, ) -> _TSchemaBase: ... @overload def condition( @@ -1628,7 +1643,7 @@ def condition( if_false: Map | str, *, empty: Optional[bool] = ..., - **kwargs, + **kwargs: Any, ) -> dict[str, _ConditionType | Any]: ... @overload def condition( @@ -1637,11 +1652,11 @@ def condition( if_false: Map, *, empty: Optional[bool] = ..., - **kwargs, + **kwargs: Any, ) -> dict[str, _ConditionType | Any]: ... @overload def condition( - predicate: _PredicateType, if_true: str, if_false: str, **kwargs + predicate: _PredicateType, if_true: str, if_false: str, **kwargs: Any ) -> Never: ... # TODO: update the docstring def condition( @@ -1650,7 +1665,7 @@ def condition( if_false: _StatementType, *, empty: Optional[bool] = Undefined, - **kwargs, + **kwargs: Any, ) -> SchemaBase | dict[str, _ConditionType | Any]: """ A conditional attribute or encoding. @@ -1687,7 +1702,9 @@ def condition( # Top-level objects -def _top_schema_base(obj: Any, /): # -> +def _top_schema_base( # noqa: ANN202 + obj: Any, / +): # -> """ Enforces an intersection type w/ `SchemaBase` & `TopLevelMixin` objects. @@ -1713,7 +1730,7 @@ def to_dict( format: str = "vega-lite", ignore: list[str] | None = None, context: dict[str, Any] | None = None, - ) -> dict: + ) -> dict[str, Any]: """ Convert the chart to a dictionary suitable for JSON export. @@ -1840,7 +1857,7 @@ def to_json( format: str = "vega-lite", ignore: list[str] | None = None, context: dict[str, Any] | None = None, - **kwargs, + **kwargs: Any, ) -> str: """ Convert a chart to a JSON string. @@ -1885,7 +1902,7 @@ def to_html( fullhtml: bool = True, requirejs: bool = False, inline: bool = False, - **kwargs, + **kwargs: Any, ) -> str: """ Embed a Vega/Vega-Lite spec into an HTML page. @@ -1985,8 +2002,8 @@ def save( embed_options: dict | None = None, json_kwds: dict | None = None, engine: str | None = None, - inline=False, - **kwargs, + inline: bool = False, + **kwargs: Any, ) -> None: """ Save a chart to file in a variety of formats. @@ -2074,27 +2091,27 @@ def __repr__(self) -> str: return f"alt.{self.__class__.__name__}(...)" # Layering and stacking - def __add__(self, other) -> LayerChart: - if not isinstance(other, TopLevelMixin): + def __add__(self, other: ChartType) -> LayerChart: + if not is_chart_type(other): msg = "Only Chart objects can be layered." raise ValueError(msg) - return layer(self, other) + return layer(t.cast("ChartType", self), other) - def __and__(self, other) -> VConcatChart: - if not isinstance(other, TopLevelMixin): + def __and__(self, other: ChartType) -> VConcatChart: + if not is_chart_type(other): msg = "Only Chart objects can be concatenated." raise ValueError(msg) # Too difficult to type check this - return vconcat(self, other) + return vconcat(t.cast("ChartType", self), other) - def __or__(self, other) -> HConcatChart | ConcatChart: - if not isinstance(other, TopLevelMixin): + def __or__(self, other: ChartType) -> HConcatChart | ConcatChart: + if not is_chart_type(other): msg = "Only Chart objects can be concatenated." raise ValueError(msg) elif isinstance(self, ConcatChart): return concat(self, other) else: - return hconcat(self, other) + return hconcat(t.cast("ChartType", self), other) def repeat( self, @@ -2103,7 +2120,7 @@ def repeat( column: Optional[list[str]] = Undefined, layer: Optional[list[str]] = Undefined, columns: Optional[int] = Undefined, - **kwargs, + **kwargs: Any, ) -> RepeatChart: """ Return a RepeatChart built from the chart. @@ -2154,9 +2171,11 @@ def repeat( else: repeat_arg = core.RepeatMapping(row=row, column=column) - return RepeatChart(spec=self, repeat=repeat_arg, columns=columns, **kwargs) + return RepeatChart( + spec=t.cast("ChartType", self), repeat=repeat_arg, columns=columns, **kwargs + ) - def properties(self, **kwargs) -> Self: + def properties(self, **kwargs: Any) -> Self: """ Set top-level properties of the Chart. @@ -2204,7 +2223,7 @@ def project( translate: Optional[ list[float] | Vector2number | ExprRef | Parameter ] = Undefined, - **kwds, + **kwds: Any, ) -> Self: """ Add a geographic projection to the chart. @@ -2419,7 +2438,7 @@ def transform_bin( as_: Optional[str | FieldName | list[str | FieldName]] = Undefined, field: Optional[str | FieldName] = Undefined, bin: Literal[True] | BinParams = True, - **kwargs, + **kwargs: Any, ) -> Self: """ Add a :class:`BinTransform` to the schema. @@ -2629,7 +2648,7 @@ def transform_impute( groupby: Optional[list[str | FieldName]] = Undefined, keyvals: Optional[list[Any] | ImputeSequence] = Undefined, method: Optional[ImputeMethod_T | ImputeMethod] = Undefined, - value=Undefined, + value: Optional[Any] = Undefined, ) -> Self: """ Add an :class:`ImputeTransform` to the schema. @@ -2779,7 +2798,7 @@ def transform_filter( | Parameter | PredicateComposition | dict[str, Predicate | str | list | bool], - **kwargs, + **kwargs: Any, ) -> Self: """ Add a :class:`FilterTransform` to the schema. @@ -2918,7 +2937,7 @@ def transform_lookup( from_: Optional[LookupData | LookupSelection] = Undefined, as_: Optional[str | FieldName | list[str | FieldName]] = Undefined, default: Optional[str] = Undefined, - **kwargs, + **kwargs: Any, ) -> Self: """ Add a :class:`DataLookupTransform` or :class:`SelectionLookupTransform` to the chart. @@ -3376,7 +3395,7 @@ def transform_window( # Display-related methods - def _repr_mimebundle_(self, include=None, exclude=None): + def _repr_mimebundle_(self, *args, **kwds) -> MimeBundleType | None: # type:ignore[return] # noqa: ANN002, ANN003 """Return a MIME bundle for display in Jupyter frontends.""" # Catch errors explicitly to get around issues in Jupyter frontend # see https://github.com/ipython/ipython/issues/11038 @@ -3394,7 +3413,7 @@ def display( renderer: Optional[Literal["canvas", "svg"]] = Undefined, theme: Optional[str] = Undefined, actions: Optional[bool | dict] = Undefined, - **kwargs, + **kwargs: Any, ) -> None: """ Display chart in Jupyter notebook or JupyterLab. @@ -3436,15 +3455,15 @@ def display( @utils.deprecated(version="4.1.0", alternative="show") def serve( self, - ip="127.0.0.1", - port=8888, - n_retries=50, - files=None, - jupyter_warning=True, - open_browser=True, - http_server=None, - **kwargs, - ): + ip="127.0.0.1", # noqa: ANN001 + port=8888, # noqa: ANN001 + n_retries=50, # noqa: ANN001 + files=None, # noqa: ANN001 + jupyter_warning=True, # noqa: ANN001 + open_browser=True, # noqa: ANN001 + http_server=None, # noqa: ANN001 + **kwargs, # noqa: ANN003 + ) -> None: """'serve' is deprecated. Use 'show' instead.""" from ...utils.server import serve @@ -3476,7 +3495,7 @@ def show(self) -> None: display(self) @utils.use_signature(core.Resolve) - def _set_resolve(self, **kwargs): + def _set_resolve(self, **kwargs: Any): # noqa: ANN202 """Copy the chart and update the resolve property with kwargs.""" if not hasattr(self, "resolve"): msg = f"{self.__class__} object has no attribute " "'resolve'" @@ -3489,19 +3508,19 @@ def _set_resolve(self, **kwargs): return copy @utils.use_signature(core.AxisResolveMap) - def resolve_axis(self, *args, **kwargs) -> Self: + def resolve_axis(self, *args: Any, **kwargs: Any) -> Self: check = _top_schema_base(self) r = check._set_resolve(axis=core.AxisResolveMap(*args, **kwargs)) return t.cast("Self", r) @utils.use_signature(core.LegendResolveMap) - def resolve_legend(self, *args, **kwargs) -> Self: + def resolve_legend(self, *args: Any, **kwargs: Any) -> Self: check = _top_schema_base(self) r = check._set_resolve(legend=core.LegendResolveMap(*args, **kwargs)) return t.cast("Self", r) @utils.use_signature(core.ScaleResolveMap) - def resolve_scale(self, *args, **kwargs) -> Self: + def resolve_scale(self, *args: Any, **kwargs: Any) -> Self: check = _top_schema_base(self) r = check._set_resolve(scale=core.ScaleResolveMap(*args, **kwargs)) return t.cast("Self", r) @@ -3517,7 +3536,7 @@ def facet( column: Optional[str | FacetFieldDef | Column] = Undefined, data: Optional[ChartDataType] = Undefined, columns: Optional[int] = Undefined, - **kwargs, + **kwargs: Any, ) -> FacetChart: """ Create a facet chart from the current chart. @@ -3644,12 +3663,12 @@ def __init__( mark: Optional[str | AnyMark] = Undefined, width: Optional[int | str | dict | Step] = Undefined, height: Optional[int | str | dict | Step] = Undefined, - **kwargs, + **kwargs: Any, ) -> None: + # Data type hints won't match with what TopLevelUnitSpec expects + # as there is some data processing happening when converting to + # a VL spec super().__init__( - # Data type hints won't match with what TopLevelUnitSpec expects - # as there is some data processing happening when converting to - # a VL spec data=data, # type: ignore[arg-type] encoding=encoding, mark=mark, @@ -3707,7 +3726,7 @@ def to_dict( format: str = "vega-lite", ignore: list[str] | None = None, context: dict[str, Any] | None = None, - ) -> dict: + ) -> dict[str, Any]: """ Convert the chart to a dictionary suitable for JSON export. @@ -3788,7 +3807,7 @@ def add_params(self, *params: Parameter) -> Self: return copy @utils.deprecated(version="5.0.0", alternative="add_params") - def add_selection(self, *params) -> Self: + def add_selection(self, *params) -> Self: # noqa: ANN002 """'add_selection' is deprecated. Use 'add_params' instead.""" return self.add_params(*params) @@ -3823,7 +3842,7 @@ def interactive( def _check_if_valid_subspec( - spec: Optional[SchemaBase | dict], + spec: ConcatType | LayerType | dict[str, Any], classname: Literal[ "ConcatChart", "FacetChart", @@ -3855,16 +3874,16 @@ def _check_if_valid_subspec( raise ValueError(err.format(attr, classname)) -def _check_if_can_be_layered(spec: dict | SchemaBase) -> None: +def _check_if_can_be_layered(spec: LayerType | dict[str, Any]) -> None: """Check if the spec can be layered.""" - def _get(spec, attr): + def _get(spec: LayerType | dict[str, Any], attr: str) -> Any: if isinstance(spec, core.SchemaBase): return spec._get(attr) else: return spec.get(attr, Undefined) - def _get_any(spec: dict | SchemaBase, *attrs: str) -> bool: + def _get_any(spec: LayerType | dict[str, Any], *attrs: str) -> bool: return any(_get(spec, attr) is not Undefined for attr in attrs) base_msg = "charts cannot be layered. Instead, layer the charts before" @@ -3896,39 +3915,44 @@ def _get_any(spec: dict | SchemaBase, *attrs: str) -> bool: class RepeatChart(TopLevelMixin, core.TopLevelRepeatSpec): """A chart repeated across rows and columns with small changes.""" - # Because TopLevelRepeatSpec is defined as a union as of Vega-Lite schema 4.9, - # we set the arguments explicitly here. - # TODO: Should we instead use tools/schemapi/codegen.get_args? - @utils.use_signature(core.TopLevelRepeatSpec) def __init__( self, - repeat=Undefined, - spec=Undefined, - align=Undefined, - autosize=Undefined, - background=Undefined, - bounds=Undefined, - center=Undefined, - columns=Undefined, - config=Undefined, - data=Undefined, - datasets=Undefined, - description=Undefined, - name=Undefined, - padding=Undefined, - params=Undefined, - resolve=Undefined, - spacing=Undefined, - title=Undefined, - transform=Undefined, - usermeta=Undefined, - **kwds, - ): + repeat: Optional[list[str] | LayerRepeatMapping | RepeatMapping] = Undefined, + spec: Optional[ChartType] = Undefined, + align: Optional[dict | SchemaBase | LayoutAlign_T] = Undefined, + autosize: Optional[dict | SchemaBase | AutosizeType_T] = Undefined, + background: Optional[ + str | dict | Parameter | SchemaBase | ColorName_T + ] = Undefined, + bounds: Optional[Literal["full", "flush"]] = Undefined, + center: Optional[bool | dict | SchemaBase] = Undefined, + columns: Optional[int] = Undefined, + config: Optional[dict | SchemaBase] = Undefined, + data: Optional[ChartDataType] = Undefined, + datasets: Optional[dict | SchemaBase] = Undefined, + description: Optional[str] = Undefined, + name: Optional[str] = Undefined, + padding: Optional[dict | float | Parameter | SchemaBase] = Undefined, + params: Optional[Sequence[_Parameter]] = Undefined, + resolve: Optional[dict | SchemaBase] = Undefined, + spacing: Optional[dict | float | SchemaBase] = Undefined, + title: Optional[str | dict | SchemaBase | Sequence[str]] = Undefined, + transform: Optional[Sequence[dict | SchemaBase]] = Undefined, + usermeta: Optional[dict | SchemaBase] = Undefined, + **kwds: Any, + ) -> None: + tp_name = type(self).__name__ + if utils.is_undefined(spec): + msg = f"{tp_name!r} requires a `spec`, but got: {spec!r}" + raise TypeError(msg) _check_if_valid_subspec(spec, "RepeatChart") _spec_as_list = [spec] params, _spec_as_list = _combine_subchart_params(params, _spec_as_list) spec = _spec_as_list[0] if isinstance(spec, (Chart, LayerChart)): + if utils.is_undefined(repeat): + msg = f"{tp_name!r} requires a `repeat`, but got: {repeat!r}" + raise TypeError(msg) params = _repeat_names(params, repeat, spec) super().__init__( repeat=repeat, @@ -4013,7 +4037,7 @@ def add_params(self, *params: Parameter) -> Self: return copy.copy() @utils.deprecated(version="5.0.0", alternative="add_params") - def add_selection(self, *selections) -> Self: + def add_selection(self, *selections) -> Self: # noqa: ANN002 """'add_selection' is deprecated. Use 'add_params' instead.""" return self.add_params(*selections) @@ -4046,22 +4070,30 @@ class ConcatChart(TopLevelMixin, core.TopLevelConcatSpec): """A chart with horizontally-concatenated facets.""" @utils.use_signature(core.TopLevelConcatSpec) - def __init__(self, data=Undefined, concat=(), columns=Undefined, **kwargs): - # TODO: move common data to top level? + def __init__( + self, + data: Optional[ChartDataType] = Undefined, + concat: Sequence[ConcatType] = (), + columns: Optional[float] = Undefined, + **kwargs: Any, + ) -> None: for spec in concat: _check_if_valid_subspec(spec, "ConcatChart") - super().__init__(data=data, concat=list(concat), columns=columns, **kwargs) + super().__init__(data=data, concat=list(concat), columns=columns, **kwargs) # type: ignore[arg-type] + self.concat: list[ChartType] + self.params: Optional[Sequence[_Parameter]] + self.data: Optional[ChartDataType] self.data, self.concat = _combine_subchart_data(self.data, self.concat) self.params, self.concat = _combine_subchart_params(self.params, self.concat) - def __ior__(self, other) -> Self: + def __ior__(self, other: ChartType) -> Self: _check_if_valid_subspec(other, "ConcatChart") self.concat.append(other) self.data, self.concat = _combine_subchart_data(self.data, self.concat) self.params, self.concat = _combine_subchart_params(self.params, self.concat) return self - def __or__(self, other) -> Self: + def __or__(self, other: ChartType) -> Self: copy = self.copy(deep=["concat"]) copy |= other return copy @@ -4129,12 +4161,12 @@ def add_params(self, *params: Parameter) -> Self: return copy @utils.deprecated(version="5.0.0", alternative="add_params") - def add_selection(self, *selections) -> Self: + def add_selection(self, *selections) -> Self: # noqa: ANN002 """'add_selection' is deprecated. Use 'add_params' instead.""" return self.add_params(*selections) -def concat(*charts, **kwargs) -> ConcatChart: +def concat(*charts: ConcatType, **kwargs: Any) -> ConcatChart: """Concatenate charts horizontally.""" return ConcatChart(concat=charts, **kwargs) # pyright: ignore @@ -4143,22 +4175,29 @@ class HConcatChart(TopLevelMixin, core.TopLevelHConcatSpec): """A chart with horizontally-concatenated facets.""" @utils.use_signature(core.TopLevelHConcatSpec) - def __init__(self, data=Undefined, hconcat=(), **kwargs): - # TODO: move common data to top level? + def __init__( + self, + data: Optional[ChartDataType] = Undefined, + hconcat: Sequence[ConcatType] = (), + **kwargs: Any, + ) -> None: for spec in hconcat: _check_if_valid_subspec(spec, "HConcatChart") - super().__init__(data=data, hconcat=list(hconcat), **kwargs) + super().__init__(data=data, hconcat=list(hconcat), **kwargs) # type: ignore[arg-type] + self.hconcat: list[ChartType] + self.params: Optional[Sequence[_Parameter]] + self.data: Optional[ChartDataType] self.data, self.hconcat = _combine_subchart_data(self.data, self.hconcat) self.params, self.hconcat = _combine_subchart_params(self.params, self.hconcat) - def __ior__(self, other) -> Self: + def __ior__(self, other: ChartType) -> Self: _check_if_valid_subspec(other, "HConcatChart") self.hconcat.append(other) self.data, self.hconcat = _combine_subchart_data(self.data, self.hconcat) self.params, self.hconcat = _combine_subchart_params(self.params, self.hconcat) return self - def __or__(self, other) -> Self: + def __or__(self, other: ChartType) -> Self: copy = self.copy(deep=["hconcat"]) copy |= other return copy @@ -4226,12 +4265,12 @@ def add_params(self, *params: Parameter) -> Self: return copy @utils.deprecated(version="5.0.0", alternative="add_params") - def add_selection(self, *selections) -> Self: + def add_selection(self, *selections) -> Self: # noqa: ANN002 """'add_selection' is deprecated. Use 'add_params' instead.""" return self.add_params(*selections) -def hconcat(*charts, **kwargs) -> HConcatChart: +def hconcat(*charts: ConcatType, **kwargs: Any) -> HConcatChart: """Concatenate charts horizontally.""" return HConcatChart(hconcat=charts, **kwargs) # pyright: ignore @@ -4240,22 +4279,29 @@ class VConcatChart(TopLevelMixin, core.TopLevelVConcatSpec): """A chart with vertically-concatenated facets.""" @utils.use_signature(core.TopLevelVConcatSpec) - def __init__(self, data=Undefined, vconcat=(), **kwargs): - # TODO: move common data to top level? + def __init__( + self, + data: Optional[ChartDataType] = Undefined, + vconcat: Sequence[ConcatType] = (), + **kwargs: Any, + ) -> None: for spec in vconcat: _check_if_valid_subspec(spec, "VConcatChart") - super().__init__(data=data, vconcat=list(vconcat), **kwargs) + super().__init__(data=data, vconcat=list(vconcat), **kwargs) # type: ignore[arg-type] + self.vconcat: list[ChartType] + self.params: Optional[Sequence[_Parameter]] + self.data: Optional[ChartDataType] self.data, self.vconcat = _combine_subchart_data(self.data, self.vconcat) self.params, self.vconcat = _combine_subchart_params(self.params, self.vconcat) - def __iand__(self, other) -> Self: + def __iand__(self, other: ChartType) -> Self: _check_if_valid_subspec(other, "VConcatChart") self.vconcat.append(other) self.data, self.vconcat = _combine_subchart_data(self.data, self.vconcat) self.params, self.vconcat = _combine_subchart_params(self.params, self.vconcat) return self - def __and__(self, other) -> Self: + def __and__(self, other: ChartType) -> Self: copy = self.copy(deep=["vconcat"]) copy &= other return copy @@ -4325,12 +4371,12 @@ def add_params(self, *params: Parameter) -> Self: return copy @utils.deprecated(version="5.0.0", alternative="add_params") - def add_selection(self, *selections) -> Self: + def add_selection(self, *selections) -> Self: # noqa: ANN002 """'add_selection' is deprecated. Use 'add_params' instead.""" return self.add_params(*selections) -def vconcat(*charts, **kwargs) -> VConcatChart: +def vconcat(*charts: ConcatType, **kwargs: Any) -> VConcatChart: """Concatenate charts vertically.""" return VConcatChart(vconcat=charts, **kwargs) # pyright: ignore @@ -4339,13 +4385,20 @@ class LayerChart(TopLevelMixin, _EncodingMixin, core.TopLevelLayerSpec): """A Chart with layers within a single panel.""" @utils.use_signature(core.TopLevelLayerSpec) - def __init__(self, data=Undefined, layer=(), **kwargs): - # TODO: move common data to top level? + def __init__( + self, + data: Optional[ChartDataType] = Undefined, + layer: Sequence[LayerType] = (), + **kwargs: Any, + ) -> None: # TODO: check for conflicting interaction for spec in layer: _check_if_valid_subspec(spec, "LayerChart") _check_if_can_be_layered(spec) - super().__init__(data=data, layer=list(layer), **kwargs) + super().__init__(data=data, layer=list(layer), **kwargs) # type: ignore[arg-type] + self.layer: list[ChartType] + self.params: Optional[Sequence[_Parameter]] + self.data: Optional[ChartDataType] self.data, self.layer = _combine_subchart_data(self.data, self.layer) # Currently (Vega-Lite 5.5) the same param can't occur on two layers self.layer = _remove_duplicate_params(self.layer) @@ -4385,7 +4438,7 @@ def transformed_data( return transformed_data(self, row_limit=row_limit, exclude=exclude) - def __iadd__(self, other: LayerChart | Chart) -> Self: + def __iadd__(self, other: ChartType) -> Self: _check_if_valid_subspec(other, "LayerChart") _check_if_can_be_layered(other) self.layer.append(other) @@ -4393,7 +4446,7 @@ def __iadd__(self, other: LayerChart | Chart) -> Self: self.params, self.layer = _combine_subchart_params(self.params, self.layer) return self - def __add__(self, other: LayerChart | Chart) -> Self: + def __add__(self, other: ChartType) -> Self: copy = self.copy(deep=["layer"]) copy += other return copy @@ -4444,12 +4497,12 @@ def add_params(self, *params: Parameter) -> Self: return copy.copy() @utils.deprecated(version="5.0.0", alternative="add_params") - def add_selection(self, *selections) -> Self: + def add_selection(self, *selections) -> Self: # noqa: ANN002 """'add_selection' is deprecated. Use 'add_params' instead.""" return self.add_params(*selections) -def layer(*charts, **kwargs) -> LayerChart: +def layer(*charts: LayerType, **kwargs: Any) -> LayerChart: """Layer multiple charts.""" return LayerChart(layer=charts, **kwargs) # pyright: ignore @@ -4460,17 +4513,23 @@ class FacetChart(TopLevelMixin, core.TopLevelFacetSpec): @utils.use_signature(core.TopLevelFacetSpec) def __init__( self, - data=Undefined, - spec=Undefined, - facet=Undefined, - params=Undefined, - **kwargs, - ): + data: Optional[ChartDataType] = Undefined, + spec: Optional[ChartType] = Undefined, + facet: Optional[dict | SchemaBase] = Undefined, + params: Optional[Sequence[_Parameter]] = Undefined, + **kwargs: Any, + ) -> None: + if utils.is_undefined(spec): + msg = f"{type(self).__name__!r} requires a `spec`, but got: {spec!r}" + raise TypeError(msg) _check_if_valid_subspec(spec, "FacetChart") _spec_as_list = [spec] params, _spec_as_list = _combine_subchart_params(params, _spec_as_list) spec = _spec_as_list[0] - super().__init__(data=data, spec=spec, facet=facet, params=params, **kwargs) + super().__init__(data=data, spec=spec, facet=facet, params=params, **kwargs) # type: ignore[arg-type] + self.data: Optional[ChartDataType] + self.spec: ChartType + self.params: Optional[Sequence[_Parameter]] def transformed_data( self, row_limit: int | None = None, exclude: Iterable[str] | None = None @@ -4532,12 +4591,12 @@ def add_params(self, *params: Parameter) -> Self: return copy.copy() @utils.deprecated(version="5.0.0", alternative="add_params") - def add_selection(self, *selections) -> Self: + def add_selection(self, *selections) -> Self: # noqa: ANN002 """'add_selection' is deprecated. Use 'add_params' instead.""" return self.add_params(*selections) -def topo_feature(url: str, feature: str, **kwargs) -> UrlData: +def topo_feature(url: str, feature: str, **kwargs: Any) -> UrlData: """ A convenience function for extracting features from a topojson url. @@ -4560,8 +4619,10 @@ def topo_feature(url: str, feature: str, **kwargs) -> UrlData: ) -def _combine_subchart_data(data, subcharts): - def remove_data(subchart): +def _combine_subchart_data( + data: Optional[ChartDataType], subcharts: list[ChartType] +) -> tuple[Optional[ChartDataType], list[ChartType]]: + def remove_data(subchart: _TSchemaBase) -> _TSchemaBase: if subchart.data is not Undefined: subchart = subchart.copy() subchart.data = Undefined @@ -4585,13 +4646,18 @@ def remove_data(subchart): return data, subcharts -def _viewless_dict(param: Parameter) -> dict: +_Parameter: TypeAlias = Union[ + core.VariableParameter, core.TopLevelSelectionParameter, core.SelectionParameter +] + + +def _viewless_dict(param: _Parameter) -> dict[str, Any]: d = param.to_dict() d.pop("views", None) return d -def _needs_name(subchart): +def _needs_name(subchart: ChartType) -> bool: # Only `Chart` objects need a name if (subchart.name is not Undefined) or (not isinstance(subchart, Chart)): return False @@ -4601,7 +4667,7 @@ def _needs_name(subchart): # Convert SelectionParameters to TopLevelSelectionParameters with a views property. -def _prepare_to_lift(param: Any) -> Any: +def _prepare_to_lift(param: _Parameter) -> _Parameter: param = param.copy() if isinstance(param, core.VariableParameter): @@ -4616,15 +4682,15 @@ def _prepare_to_lift(param: Any) -> Any: return param -def _remove_duplicate_params(layer): +def _remove_duplicate_params(layer: list[ChartType]) -> list[ChartType]: subcharts = [subchart.copy() for subchart in layer] found_params = [] for subchart in subcharts: - if (not hasattr(subchart, "params")) or (subchart.params is Undefined): + if (not hasattr(subchart, "params")) or (utils.is_undefined(subchart.params)): continue - params = [] + params: list[_Parameter] = [] # Ensure the same selection parameter doesn't appear twice for param in subchart.params: @@ -4647,12 +4713,14 @@ def _remove_duplicate_params(layer): return subcharts -def _combine_subchart_params(params, subcharts): - if params is Undefined: +def _combine_subchart_params( + params: Optional[Sequence[_Parameter]], subcharts: list[ChartType] +) -> tuple[Optional[Sequence[_Parameter]], list[ChartType]]: + if utils.is_undefined(params): params = [] # List of triples related to params, (param, dictionary minus views, views) - param_info = [] + param_info: list[tuple[_Parameter, dict[str, Any], list[str]]] = [] # Put parameters already found into `param_info` list. for param in params: @@ -4668,7 +4736,7 @@ def _combine_subchart_params(params, subcharts): subcharts = [subchart.copy() for subchart in subcharts] for subchart in subcharts: - if (not hasattr(subchart, "params")) or (subchart.params is Undefined): + if (not hasattr(subchart, "params")) or (utils.is_undefined(subchart.params)): continue if _needs_name(subchart): @@ -4707,7 +4775,7 @@ def _combine_subchart_params(params, subcharts): if len(v) > 0: p.views = v - subparams = [p for p, _, _ in param_info] + subparams: Any = [p for p, _, _ in param_info] if len(subparams) == 0: subparams = Undefined @@ -4715,7 +4783,9 @@ def _combine_subchart_params(params, subcharts): return subparams, subcharts -def _get_repeat_strings(repeat): +def _get_repeat_strings( + repeat: list[str] | LayerRepeatMapping | RepeatMapping, +) -> list[str]: if isinstance(repeat, list): return repeat elif isinstance(repeat, core.LayerRepeatMapping): @@ -4727,7 +4797,7 @@ def _get_repeat_strings(repeat): return ["".join(s) for s in itertools.product(*rcstrings)] -def _extend_view_name(v, r, spec): +def _extend_view_name(v: str, r: str, spec: Chart | LayerChart) -> str: # prevent the same extension from happening more than once if isinstance(spec, Chart): if v.endswith("child__" + r): @@ -4739,14 +4809,21 @@ def _extend_view_name(v, r, spec): return v else: return f"child__{r}_{v}" + else: + msg = f"Expected 'Chart | LayerChart', but got: {type(spec).__name__!r}" + raise TypeError(msg) -def _repeat_names(params, repeat, spec): - if params is Undefined: +def _repeat_names( + params: Optional[Sequence[_Parameter]], + repeat: list[str] | LayerRepeatMapping | RepeatMapping, + spec: Chart | LayerChart, +) -> Optional[Sequence[_Parameter]]: + if utils.is_undefined(params): return params repeat = _get_repeat_strings(repeat) - params_named = [] + params_named: list[_Parameter] = [] for param in params: if not isinstance(param, core.TopLevelSelectionParameter): @@ -4773,8 +4850,10 @@ def _repeat_names(params, repeat, spec): return params_named -def _remove_layer_props(chart, subcharts, layer_props): - def remove_prop(subchart, prop): +def _remove_layer_props( + chart: LayerChart, subcharts: list[ChartType], layer_props: Iterable[str] +) -> tuple[dict[str, Any], list[ChartType]]: + def remove_prop(subchart: ChartType, prop: str) -> ChartType: # If subchart is a UnitSpec, then subchart["height"] raises a KeyError try: if subchart[prop] is not Undefined: @@ -4784,7 +4863,7 @@ def remove_prop(subchart, prop): pass return subchart - output_dict = {} + output_dict: dict[str, Any] = {} if not subcharts: # No subcharts = nothing to do. @@ -4827,7 +4906,11 @@ def remove_prop(subchart, prop): @utils.use_signature(core.SequenceParams) def sequence( - start, stop=None, step=Undefined, as_=Undefined, **kwds + start: Optional[float], + stop: Optional[float | None] = None, + step: Optional[float] = Undefined, + as_: Optional[str] = Undefined, + **kwds: Any, ) -> SequenceGenerator: """Sequence generator.""" if stop is None: @@ -4837,7 +4920,7 @@ def sequence( @utils.use_signature(core.GraticuleParams) -def graticule(**kwds) -> GraticuleGenerator: +def graticule(**kwds: Any) -> GraticuleGenerator: """Graticule generator.""" # graticule: True indicates default parameters graticule: Any = core.GraticuleParams(**kwds) if kwds else True @@ -4849,9 +4932,24 @@ def sphere() -> SphereGenerator: return core.SphereGenerator(sphere=True) -ChartType = Union[ +ChartType: TypeAlias = Union[ Chart, RepeatChart, ConcatChart, HConcatChart, VConcatChart, FacetChart, LayerChart ] +ConcatType: TypeAlias = Union[ + ChartType, + core.FacetSpec, + core.LayerSpec, + core.RepeatSpec, + core.FacetedUnitSpec, + core.LayerRepeatSpec, + core.NonNormalizedSpec, + core.NonLayerRepeatSpec, + core.ConcatSpecGenericSpec, + core.ConcatSpecGenericSpec, + core.HConcatSpecGenericSpec, + core.VConcatSpecGenericSpec, +] +LayerType: TypeAlias = Union[ChartType, core.UnitSpec, core.LayerSpec] def is_chart_type(obj: Any) -> TypeIs[ChartType]: diff --git a/pyproject.toml b/pyproject.toml index 7af04f2d3..d55b0e5ba 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -229,7 +229,9 @@ extend-safe-fixes=[ # escape-sequence-in-docstring "D301", # ends-in-period - "D400" + "D400", + # missing-return-type-special-method + "ANN204" ] # https://docs.astral.sh/ruff/preview/#using-rules-that-are-in-preview @@ -307,6 +309,8 @@ select = [ "D213", # numpy-specific-rules "NPY", + # flake8-annotations + "ANN", ] ignore = [ # Whitespace before ':' @@ -351,6 +355,8 @@ ignore = [ "D413", # doc-line-too-long "W505", + # Any as annotation + "ANN401" ] # https://docs.astral.sh/ruff/settings/#lintpydocstyle pydocstyle={ convention="numpy" } @@ -361,6 +367,10 @@ mccabe={ max-complexity=18 } # https://docs.astral.sh/ruff/settings/#lint_flake8-tidy-imports_banned-api "typing.Optional".msg = "Use `Union[T, None]` instead.\n`typing.Optional` is likely to be confused with `altair.Optional`, which have a similar but different semantic meaning.\nSee https://github.com/vega/altair/pull/3449" +[tool.ruff.lint.per-file-ignores] +# Only enforce type annotation rules on public api +"!altair/vegalite/v5/api.py" = ["ANN"] + [tool.ruff.format] quote-style = "double"