From e994fbdfacb810c7a1c664af895ff95280c14d17 Mon Sep 17 00:00:00 2001 From: Dan Redding <125183946+dangotbanned@users.noreply.github.com> Date: Tue, 9 Jul 2024 19:08:53 +0100 Subject: [PATCH] refactor(typing): Reuse generated `Literal` aliases in `api` (#3464) Following https://github.com/vega/altair/pull/3431 a number of these are now importable. Additionally, I spotted the `encodings` parameter was annotated with `str`, but should be restricted to the constraints of `SingleDefUnitChannel_T`. --- altair/vegalite/v5/api.py | 37 +++++++++++++++++++------------------ 1 file changed, 19 insertions(+), 18 deletions(-) diff --git a/altair/vegalite/v5/api.py b/altair/vegalite/v5/api.py index f25491865..843315fa7 100644 --- a/altair/vegalite/v5/api.py +++ b/altair/vegalite/v5/api.py @@ -85,6 +85,13 @@ InlineDataset, ) from altair.expr.core import Expression, GetAttrExpression + from .schema._typing import ( + ImputeMethod_T, + SelectionType_T, + SelectionResolution_T, + SingleDefUnitChannel_T, + StackOffset_T, + ) ChartDataType: TypeAlias = Optional[Union[DataType, core.Data, str, core.Generator]] @@ -505,9 +512,7 @@ def param( return parameter -def _selection( - type: Optional[Literal["interval", "point"]] = Undefined, **kwds -) -> Parameter: +def _selection(type: Optional[SelectionType_T] = Undefined, **kwds) -> Parameter: # We separate out the parameter keywords from the selection keywords select_kwds = {"name", "bind", "value", "empty", "init", "views"} @@ -537,9 +542,7 @@ def _selection( message="""'selection' is deprecated. Use 'selection_point()' or 'selection_interval()' instead; these functions also include more helpful docstrings.""" ) -def selection( - type: Optional[Literal["interval", "point"]] = Undefined, **kwds -) -> Parameter: +def selection(type: Optional[SelectionType_T] = Undefined, **kwds) -> Parameter: """ Users are recommended to use either 'selection_point' or 'selection_interval' instead, depending on the type of parameter they want to create. @@ -568,10 +571,10 @@ def selection_interval( bind: Optional[Binding | str] = Undefined, empty: Optional[bool] = Undefined, expr: Optional[str | Expr | Expression] = Undefined, - encodings: Optional[list[str]] = Undefined, + encodings: Optional[list[SingleDefUnitChannel_T]] = Undefined, on: Optional[str] = Undefined, clear: Optional[str | bool] = Undefined, - resolve: Optional[Literal["global", "union", "intersect"]] = Undefined, + resolve: Optional[SelectionResolution_T] = Undefined, mark: Optional[Mark] = Undefined, translate: Optional[str | bool] = Undefined, zoom: Optional[str | bool] = Undefined, @@ -680,11 +683,11 @@ def selection_point( bind: Optional[Binding | str] = Undefined, empty: Optional[bool] = Undefined, expr: Optional[Expr] = Undefined, - encodings: Optional[list[str]] = Undefined, + encodings: Optional[list[SingleDefUnitChannel_T]] = Undefined, fields: Optional[list[str]] = Undefined, on: Optional[str] = Undefined, clear: Optional[str | bool] = Undefined, - resolve: Optional[Literal["global", "union", "intersect"]] = Undefined, + resolve: Optional[SelectionResolution_T] = Undefined, toggle: Optional[str | bool] = Undefined, nearest: Optional[bool] = Undefined, **kwds, @@ -1853,9 +1856,7 @@ def transform_impute( frame: Optional[list[int | None]] = Undefined, groupby: Optional[list[str | FieldName]] = Undefined, keyvals: Optional[list[Any] | ImputeSequence] = Undefined, - method: Optional[ - Literal["value", "mean", "median", "max", "min"] | ImputeMethod - ] = Undefined, + method: Optional[ImputeMethod_T | ImputeMethod] = Undefined, value=Undefined, ) -> Self: """ @@ -2378,7 +2379,7 @@ def transform_stack( as_: str | FieldName | list[str], stack: str | FieldName, groupby: list[str | FieldName], - offset: Optional[Literal["zero", "center", "normalize"]] = Undefined, + offset: Optional[StackOffset_T] = Undefined, sort: Optional[list[SortField]] = Undefined, ) -> Self: """ @@ -3061,7 +3062,7 @@ def interactive( copy of self, with interactive axes added """ - encodings = [] + encodings: list[SingleDefUnitChannel_T] = [] if bind_x: encodings.append("x") if bind_y: @@ -3351,7 +3352,7 @@ def interactive( copy of self, with interactive axes added """ - encodings = [] + encodings: list[SingleDefUnitChannel_T] = [] if bind_x: encodings.append("x") if bind_y: @@ -3448,7 +3449,7 @@ def interactive( copy of self, with interactive axes added """ - encodings = [] + encodings: list[SingleDefUnitChannel_T] = [] if bind_x: encodings.append("x") if bind_y: @@ -3547,7 +3548,7 @@ def interactive( copy of self, with interactive axes added """ - encodings = [] + encodings: list[SingleDefUnitChannel_T] = [] if bind_x: encodings.append("x") if bind_y: