Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Adds ChartType type and type guard is_chart_type. Change PurePath to Path type hints #3420

Merged
merged 17 commits into from
May 23, 2024
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions altair/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
"Categorical",
"Chart",
"ChartDataType",
"ChartType",
"Color",
"ColorDatum",
"ColorDef",
Expand Down Expand Up @@ -579,6 +580,7 @@
"expr",
"graticule",
"hconcat",
"is_chart_type",
"jupyter",
"layer",
"limit_rows",
Expand Down
26 changes: 10 additions & 16 deletions altair/utils/save.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,28 +9,26 @@


def write_file_or_filename(
fp: Union[str, pathlib.PurePath, IO],
fp: Union[pathlib.Path, IO],
dangotbanned marked this conversation as resolved.
Show resolved Hide resolved
dangotbanned marked this conversation as resolved.
Show resolved Hide resolved
content: Union[str, bytes],
mode: str = "w",
encoding: Optional[str] = None,
) -> None:
"""Write content to fp, whether fp is a string, a pathlib Path or a
file-like object"""
if isinstance(fp, str) or isinstance(fp, pathlib.PurePath):
if isinstance(fp, pathlib.Path):
dangotbanned marked this conversation as resolved.
Show resolved Hide resolved
with open(file=fp, mode=mode, encoding=encoding) as f:
f.write(content)
else:
fp.write(content)


def set_inspect_format_argument(
format: Optional[str], fp: Union[str, pathlib.PurePath, IO], inline: bool
format: Optional[str], fp: Union[pathlib.Path, IO], inline: bool
dangotbanned marked this conversation as resolved.
Show resolved Hide resolved
) -> str:
"""Inspect the format argument in the save function"""
if format is None:
if isinstance(fp, str):
format = fp.split(".")[-1]
elif isinstance(fp, pathlib.PurePath):
if isinstance(fp, pathlib.Path):
format = fp.suffix.lstrip(".")
else:
raise ValueError(
Expand Down Expand Up @@ -70,7 +68,7 @@ def set_inspect_mode_argument(

def save(
chart,
fp: Union[str, pathlib.PurePath, IO],
fp: Union[str, pathlib.Path, IO],
dangotbanned marked this conversation as resolved.
Show resolved Hide resolved
vega_version: Optional[str],
vegaembed_version: Optional[str],
format: Optional[Literal["json", "html", "png", "svg", "pdf"]] = None,
Expand Down Expand Up @@ -130,6 +128,8 @@ def save(
"""
if json_kwds is None:
json_kwds = {}
fp = pathlib.Path(fp) if isinstance(fp, str) else fp
dangotbanned marked this conversation as resolved.
Show resolved Hide resolved
encoding = kwargs.get("encoding", "utf-8")

format = set_inspect_format_argument(format, fp, inline) # type: ignore[assignment]

Expand All @@ -142,9 +142,7 @@ def perform_save():

if format == "json":
json_spec = json.dumps(spec, **json_kwds)
write_file_or_filename(
fp, json_spec, mode="w", encoding=kwargs.get("encoding", "utf-8")
)
write_file_or_filename(fp, json_spec, mode="w", encoding=encoding)
elif format == "html":
if inline:
kwargs["template"] = "inline"
Expand All @@ -160,12 +158,9 @@ def perform_save():
**kwargs,
)
write_file_or_filename(
fp,
mimebundle["text/html"],
mode="w",
encoding=kwargs.get("encoding", "utf-8"),
fp, mimebundle["text/html"], mode="w", encoding=encoding
)
elif format in ["png", "svg", "pdf", "vega"]:
elif format in {"png", "svg", "pdf", "vega"}:
dangotbanned marked this conversation as resolved.
Show resolved Hide resolved
mimebundle = spec_to_mimebundle(
spec=spec,
format=format,
Expand All @@ -184,7 +179,6 @@ def perform_save():
elif format == "pdf":
write_file_or_filename(fp, mimebundle["application/pdf"], mode="wb")
else:
encoding = kwargs.get("encoding", "utf-8")
write_file_or_filename(
fp, mimebundle["image/svg+xml"], mode="w", encoding=encoding
)
Expand Down
29 changes: 25 additions & 4 deletions altair/vegalite/v5/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
from toolz.curried import pipe as _pipe
import itertools
import sys
import pathlib
import typing
from typing import cast, List, Optional, Any, Iterable, Union, Literal, IO

# Have to rename it here as else it overlaps with schema.core.Type and schema.core.Dict
Expand All @@ -29,6 +31,10 @@
from ...utils.core import DataFrameLike
from ...utils.data import DataType

if sys.version_info >= (3, 13):
from typing import TypeIs
else:
from typing_extensions import TypeIs
if sys.version_info >= (3, 11):
from typing import Self
else:
Expand Down Expand Up @@ -231,9 +237,9 @@ def to_dict(self) -> TypingDict[str, Union[str, dict]]:
return {"expr": self.name}
elif self.param_type == "selection":
return {
"param": self.name.to_dict()
if hasattr(self.name, "to_dict")
else self.name
"param": (
self.name.to_dict() if hasattr(self.name, "to_dict") else self.name
)
}
else:
raise ValueError(f"Unrecognized parameter type: {self.param_type}")
Expand Down Expand Up @@ -1137,7 +1143,7 @@ def open_editor(self, *, fullscreen: bool = False) -> None:

def save(
self,
fp: Union[str, IO],
fp: Union[str, pathlib.Path, IO],
format: Optional[Literal["json", "html", "png", "svg", "pdf"]] = None,
override_data_transformer: bool = True,
scale_factor: float = 1.0,
Expand Down Expand Up @@ -4089,3 +4095,18 @@ def graticule(**kwds):
def sphere() -> core.SphereGenerator:
"""Sphere generator."""
return core.SphereGenerator(sphere=True)


ChartType = Union[
Chart, RepeatChart, ConcatChart, HConcatChart, VConcatChart, FacetChart, LayerChart
]


def is_chart_type(obj: Any) -> TypeIs[ChartType]:
"""Return `True` if the object is an Altair chart. This can be a basic chart
but also a repeat, concat, or facet chart.
"""
return isinstance(
obj,
typing.get_args(ChartType),
)
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ build-backend = "hatchling.build"
name = "altair"
authors = [{ name = "Vega-Altair Contributors" }]
dependencies = [
"typing_extensions>=4.0.1; python_version<\"3.11\"",
"typing_extensions>=4.10.0; python_version<\"3.13\"",
"jinja2",
# If you update the minimum required jsonschema version, also update it in build.yml
"jsonschema>=3.0",
Expand Down
5 changes: 5 additions & 0 deletions tools/update_init_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,10 @@
cast,
)

if sys.version_info >= (3, 13):
from typing import TypeIs
else:
from typing_extensions import TypeIs
if sys.version_info >= (3, 11):
from typing import Self
else:
Expand Down Expand Up @@ -97,6 +101,7 @@ def _is_relevant_attribute(attr_name: str) -> bool:
or attr is Protocol
or attr is Sequence
or attr is IO
or attr is TypeIs
or attr_name == "TypingDict"
or attr_name == "TypingGenerator"
or attr_name == "ValueOrDatum"
Expand Down
Loading