perf: Fix issues with Chart|LayerChart.encode, 1.32x speedup to infer_encoding_types #3444

Merged 13 commits on Jun 28, 2024
Changes from 12 commits
203 changes: 140 additions & 63 deletions altair/utils/core.py
@@ -16,20 +16,22 @@
Callable,
TypeVar,
Any,
Sequence,
Iterator,
cast,
Literal,
Protocol,
TYPE_CHECKING,
runtime_checkable,
)
from itertools import groupby
from operator import itemgetter

import jsonschema
import pandas as pd
import numpy as np
from pandas.api.types import infer_dtype

from altair.utils.schemapi import SchemaBase
from altair.utils.schemapi import SchemaBase, Undefined
from altair.utils._dfi_types import Column, DtypeKind, DataFrame as DfiDataFrame

if sys.version_info >= (3, 10):
@@ -773,9 +775,133 @@ def display_traceback(in_ipython: bool = True):
traceback.print_exception(*exc_info)


_ChannelType = Literal["field", "datum", "value"]
_CHANNEL_CACHE: _ChannelCache
"""Singleton `_ChannelCache` instance.

Initialized on first use.
"""


class _ChannelCache:
channel_to_name: dict[type[SchemaBase], str]
name_to_channel: dict[str, dict[_ChannelType, type[SchemaBase]]]

@classmethod
def from_channels(cls, channels: ModuleType, /) -> _ChannelCache:
# - This branch is only kept for tests that depend on mocking `channels`.
# - No longer needs to pass around `channels` reference and rebuild every call.
c_to_n = {
c: c._encoding_name
for c in channels.__dict__.values()
if isinstance(c, type)
and issubclass(c, SchemaBase)
and hasattr(c, "_encoding_name")
}
self = cls.__new__(cls)
self.channel_to_name = c_to_n
self.name_to_channel = _invert_group_channels(c_to_n)
return self

@classmethod
def from_cache(cls) -> _ChannelCache:
global _CHANNEL_CACHE
try:
cached = _CHANNEL_CACHE
except NameError:
cached = cls.__new__(cls)
cached.channel_to_name = _init_channel_to_name()
cached.name_to_channel = _invert_group_channels(cached.channel_to_name)
_CHANNEL_CACHE = cached
return _CHANNEL_CACHE

def get_encoding(self, tp: type[Any], /) -> str:
if encoding := self.channel_to_name.get(tp):
return encoding
msg = f"positional of type {type(tp).__name__!r}"
raise NotImplementedError(msg)

def _wrap_in_channel(self, obj: Any, encoding: str, /):

Contributor (@mattijn) commented:

This is the first time I have seen a forward slash, /, as an argument within a function. Can you explain what that does?

Member Author @dangotbanned replied (Jun 27, 2024):

Sure @mattijn

/ is a marker for positional-only parameters. Any parameter to the left of / is positional-only and may not be passed by keyword.
In this case, calling self._wrap_in_channel(encoding="enc", obj=[1, 2, 3]) raises a TypeError. The Python tutorial docs may be helpful for an overview as well.

Personally, I like to use / in cases like:

  • The function/method is not part of the public API, which leaves more flexibility to rename parameters in the future without introducing a breaking change
  • There are only 1-3 parameters
  • The function is currently used in one specific way, and the function name and parameter order are clear
    • In this case, _wrap_in_channel could logically be thought of as having parameters _wrap_in_channel(wrappee, wrapper).

Looking back at PEP 570, I see that it was introduced in Python 3.8, which may explain the feature's absence from Altair until now.

Hope all of that was helpful
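
As a minimal illustrative sketch (not part of the PR; the helper name here is hypothetical), this is how / changes what a call site may look like:

    def _wrap(obj, encoding, /):
        # Both parameters are positional-only, so they cannot be passed by keyword.
        return {encoding: obj}

    _wrap([1, 2, 3], "x")               # OK -> {"x": [1, 2, 3]}
    _wrap(obj=[1, 2, 3], encoding="x")  # TypeError: positional-only arguments passed as keyword arguments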

if isinstance(obj, SchemaBase):
return obj
elif isinstance(obj, str):
obj = {"shorthand": obj}
elif isinstance(obj, (list, tuple)):
return [self._wrap_in_channel(el, encoding) for el in obj]
if channel := self.name_to_channel.get(encoding):
tp = channel["value" if "value" in obj else "field"]
try:
# Don't force validation here; some objects won't be valid until
# they're created in the context of a chart.
return tp.from_dict(obj, validate=False)
except jsonschema.ValidationError:
# our attempts at finding the correct class have failed
return obj
else:
warnings.warn(f"Unrecognized encoding channel {encoding!r}", stacklevel=1)
return obj

def infer_encoding_types(self, kwargs: dict[str, Any], /):
return {
encoding: self._wrap_in_channel(obj, encoding)
for encoding, obj in kwargs.items()
if obj is not Undefined
}


def _init_channel_to_name():
"""
Construct a dictionary of channel type to encoding name.

Note
----
The return type is not expressible using annotations, but is used
internally by `mypy`/`pyright` and avoids the need for type ignores.

Returns
-------
mapping: dict[type[`<subclass of FieldChannelMixin and SchemaBase>`] | type[`<subclass of ValueChannelMixin and SchemaBase>`] | type[`<subclass of DatumChannelMixin and SchemaBase>`], str]
"""
from altair.vegalite.v5.schema import channels as ch

mixins = ch.FieldChannelMixin, ch.ValueChannelMixin, ch.DatumChannelMixin

return {
c: c._encoding_name
for c in ch.__dict__.values()
if isinstance(c, type) and issubclass(c, mixins) and issubclass(c, SchemaBase)
}


def _invert_group_channels(
m: dict[type[SchemaBase], str], /
) -> dict[str, dict[_ChannelType, type[SchemaBase]]]:
"""Grouped inverted index for `_ChannelCache.channel_to_name`."""

def _reduce(it: Iterator[tuple[type[Any], str]]) -> Any:
"""Returns a 1-2 item dict, per channel.

Never includes `datum`, as it is never utilized in `wrap_in_channel`.
"""
item: dict[Any, type[SchemaBase]] = {}
for tp, _ in it:
name = tp.__name__
if name.endswith("Datum"):
continue
elif name.endswith("Value"):
sub_key = "value"
else:
sub_key = "field"
item[sub_key] = tp
return item

grouper = groupby(m.items(), itemgetter(1))
return {k: _reduce(chans) for k, chans in grouper}


def infer_encoding_types(
args: Sequence[Any], kwargs: t.MutableMapping[str, Any], channels: ModuleType
) -> dict[str, SchemaBase | list | dict[str, str] | Any]:
args: tuple[Any, ...], kwargs: dict[str, Any], channels: ModuleType | None = None
):
"""Infer typed keyword arguments for args and kwargs

Parameters
@@ -793,68 +919,19 @@ def infer_encoding_types(
All args and kwargs in a single dict, with keys and types
based on the channels mapping.
"""
# Construct a dictionary of channel type to encoding name
# TODO: cache this somehow?
channel_objs = (getattr(channels, name) for name in dir(channels))
channel_objs = (
c for c in channel_objs if isinstance(c, type) and issubclass(c, SchemaBase)
cache = (
_ChannelCache.from_channels(channels)
if channels
else _ChannelCache.from_cache()
)
channel_to_name: dict[type[SchemaBase], str] = {
c: c._encoding_name for c in channel_objs
}
name_to_channel: dict[str, dict[str, type[SchemaBase]]] = {}
for chan, name in channel_to_name.items():
chans = name_to_channel.setdefault(name, {})
if chan.__name__.endswith("Datum"):
key = "datum"
elif chan.__name__.endswith("Value"):
key = "value"
else:
key = "field"
chans[key] = chan

# First use the mapping to convert args to kwargs based on their types.
for arg in args:
if isinstance(arg, (list, tuple)) and len(arg) > 0:
type_ = type(arg[0])
el = next(iter(arg), None) if isinstance(arg, (list, tuple)) else arg
encoding = cache.get_encoding(type(el))
if encoding not in kwargs:
kwargs[encoding] = arg
else:
type_ = type(arg)

encoding = channel_to_name.get(type_)
if encoding is None:
msg = f"positional of type {type_}" ""
raise NotImplementedError(msg)
if encoding in kwargs:
msg = f"encoding {encoding} specified twice."
msg = f"encoding {encoding!r} specified twice."
raise ValueError(msg)
kwargs[encoding] = arg

def _wrap_in_channel_class(obj, encoding):
if isinstance(obj, SchemaBase):
return obj

if isinstance(obj, str):
obj = {"shorthand": obj}

if isinstance(obj, (list, tuple)):
return [_wrap_in_channel_class(subobj, encoding) for subobj in obj]

if encoding not in name_to_channel:
warnings.warn(f"Unrecognized encoding channel '{encoding}'", stacklevel=1)
return obj

classes = name_to_channel[encoding]
cls = classes["value"] if "value" in obj else classes["field"]

try:
# Don't force validation here; some objects won't be valid until
# they're created in the context of a chart.
return cls.from_dict(obj, validate=False)
except jsonschema.ValidationError:
# our attempts at finding the correct class have failed
return obj

return {
encoding: _wrap_in_channel_class(obj, encoding)
for encoding, obj in kwargs.items()
}
return cache.infer_encoding_types(kwargs)
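
A brief usage sketch (editorial illustration, not part of the diff) of the behaviour infer_encoding_types preserves after the refactor: positional channel objects are routed to their encoding slot by type, so the positional and keyword forms below are expected to produce the same spec.

    import altair as alt
    import pandas as pd

    source = pd.DataFrame({"a": ["A", "B", "C"], "b": [1, 2, 3]})

    # Positional channels are inferred into the "x"/"y" slots from their class,
    # matching the explicit keyword form.
    c1 = alt.Chart(source).mark_bar().encode(alt.X("a:N"), alt.Y("b:Q"))
    c2 = alt.Chart(source).mark_bar().encode(x=alt.X("a:N"), y=alt.Y("b:Q"))
    assert c1.to_dict() == c2.to_dict()
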
76 changes: 43 additions & 33 deletions altair/vegalite/v5/api.py
@@ -8,6 +8,7 @@
import itertools
from typing import Union, cast, Any, Iterable, Literal, IO, TYPE_CHECKING
from typing_extensions import TypeAlias
import typing

from .schema import core, channels, mixins, Undefined, SCHEMA_URL

@@ -74,8 +75,6 @@
Step,
RepeatRef,
NonNormalizedSpec,
LayerSpec,
UnitSpec,
UrlData,
SequenceGenerator,
GraticuleGenerator,
@@ -384,6 +383,19 @@ def check_fields_and_encodings(parameter: Parameter, field_name: str) -> bool:
return False


_TestPredicateType = Union[str, _expr_core.Expression, core.PredicateComposition]
_PredicateType = Union[
Parameter,
core.Expr,
typing.Dict[str, Any],
_TestPredicateType,
_expr_core.OperatorMixin,
]
_ConditionType = typing.Dict[str, Union[_TestPredicateType, Any]]
_DictOrStr = Union[typing.Dict[str, Any], str]
_DictOrSchema = Union[core.SchemaBase, typing.Dict[str, Any]]
_StatementType = Union[core.SchemaBase, _DictOrStr]

# ------------------------------------------------------------------------
# Top-Level Functions

@@ -829,18 +841,33 @@ def binding_range(**kwargs):
return core.BindRange(input="range", **kwargs)


_TSchemaBase = typing.TypeVar("_TSchemaBase", bound=core.SchemaBase)


@typing.overload
def condition(
predicate: _PredicateType, if_true: _StatementType, if_false: _TSchemaBase, **kwargs
) -> _TSchemaBase: ...
@typing.overload
def condition(
predicate: _PredicateType, if_true: str, if_false: str, **kwargs
) -> typing.NoReturn: ...
@typing.overload
def condition(
predicate: _PredicateType, if_true: _DictOrSchema, if_false: _DictOrStr, **kwargs
) -> dict[str, _ConditionType | Any]: ...
@typing.overload
def condition(
predicate: _PredicateType,
if_true: _DictOrStr,
if_false: dict[str, Any],
**kwargs,
) -> dict[str, _ConditionType | Any]: ...
# TODO: update the docstring
def condition(
predicate: Parameter
| str
| Expression
| Expr
| PredicateComposition
| dict[str, Any],
# Types of these depends on where the condition is used so we probably
# can't be more specific here.
if_true: Any,
if_false: Any,
predicate: _PredicateType,
if_true: _StatementType,
if_false: _StatementType,
**kwargs,
) -> dict[str, Any] | SchemaBase:
"""A conditional attribute or encoding
@@ -2732,24 +2759,7 @@ def resolve_scale(self, *args, **kwargs) -> Self:
return self._set_resolve(scale=core.ScaleResolveMap(*args, **kwargs))


class _EncodingMixin:
@utils.use_signature(channels._encode_signature)
def encode(self, *args, **kwargs) -> Self:
# Convert args to kwargs based on their types.
kwargs = utils.infer_encoding_types(args, kwargs, channels)

# get a copy of the dict representation of the previous encoding
# ignore type as copy method comes from SchemaBase
copy = self.copy(deep=["encoding"]) # type: ignore[attr-defined]
encoding = copy._get("encoding", {})
if isinstance(encoding, core.VegaLiteSchema):
encoding = {k: v for k, v in encoding._kwds.items() if v is not Undefined}

# update with the new encodings, and apply them to the copy
encoding.update(kwargs)
copy.encoding = core.FacetedEncoding(**encoding)
return copy

class _EncodingMixin(channels._EncodingMixin):
def facet(
self,
facet: Optional[str | Facet] = Undefined,
@@ -3617,20 +3627,20 @@ def transformed_data(

return transformed_data(self, row_limit=row_limit, exclude=exclude)

def __iadd__(self, other: LayerSpec | UnitSpec) -> Self:
def __iadd__(self, other: LayerChart | Chart) -> Self:
_check_if_valid_subspec(other, "LayerChart")
_check_if_can_be_layered(other)
self.layer.append(other)
self.data, self.layer = _combine_subchart_data(self.data, self.layer)
self.params, self.layer = _combine_subchart_params(self.params, self.layer)
return self

def __add__(self, other: LayerSpec | UnitSpec) -> Self:
def __add__(self, other: LayerChart | Chart) -> Self:
copy = self.copy(deep=["layer"])
copy += other
return copy

def add_layers(self, *layers: LayerSpec | UnitSpec) -> Self:
def add_layers(self, *layers: LayerChart | Chart) -> Self:
copy = self.copy(deep=["layer"])
for layer in layers:
copy += layer
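
A small usage sketch (editorial, not part of the diff) of the layer operators typed above: Chart + Chart builds a LayerChart, while __iadd__ and add_layers append further layers.

    import altair as alt
    import pandas as pd

    source = pd.DataFrame({"x": [1, 2, 3], "y": [4, 5, 6]})
    base = alt.Chart(source).encode(x="x:Q", y="y:Q")

    layered = base.mark_line() + base.mark_point()    # Chart + Chart -> LayerChart
    layered += base.mark_rule(opacity=0.3)            # LayerChart.__iadd__ appends another layer
    layered = layered.add_layers(base.mark_circle())  # add_layers returns a copy with the extra layer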