
Cannot use ax parameter in plot_series #114

Closed
MaickelHubner opened this issue Aug 5, 2024 · 5 comments · Fixed by #116

Comments

@MaickelHubner

I was advised to use the utilsforecast.plotting.plot_series function directly instead of the nixtla_client plot function.

However, when I pass a matplotlib Axes through the ax parameter, plot_series does not recognize it:

import matplotlib.pyplot as plt
import pandas as pd
import pytest
from utilsforecast.plotting import plot_series


def plot_df(df, forecast_df=None, period_column='period', qty_column='quantity', *args, **kwargs):
    id_col = 'unique_id'
    df = df.copy()
    # plot_series needs an id column; treat the data as a single series if it is missing
    if id_col not in df:
        df[id_col] = 'ts_0'
    df[period_column] = pd.to_datetime(df[period_column])
    # default to None so the call below also works when no forecasts are given
    forecasts_df = None
    if forecast_df is not None:
        forecasts_df = forecast_df.copy()
        if id_col not in forecasts_df:
            forecasts_df[id_col] = 'ts_0'
        forecasts_df[period_column] = pd.to_datetime(forecasts_df[period_column])
    # extra keyword arguments (e.g. level, ax) are forwarded to plot_series
    return plot_series(
        df=df,
        forecasts_df=forecasts_df,
        palette='tab20b',
        id_col=id_col,
        time_col=period_column,
        target_col=qty_column,
        **kwargs,
    )


# predict and the demand_data / splited_data fixtures are project-specific
@pytest.mark.mpl_image_compare
def test_plot_df(demand_data, splited_data):
    sku = '10001009'
    all_train = splited_data['train']
    train = all_train[all_train.SKU == sku]
    forecast = predict(train, 'MovingAverage', level=[80])
    filtered = demand_data[demand_data.SKU == sku]
    fig, ax = plt.subplots(figsize=(16, 3.5))
    fig = plot_df(filtered, forecast, level=[80], ax=ax)
    return fig

Execution error:

pytest tests/forecast.py::test_plot_df
================================================= test session starts ==================================================
platform darwin -- Python 3.11.7, pytest-8.3.1, pluggy-1.5.0
Matplotlib: 3.8.4
Freetype: 2.6.1
Fugue tests will be initialized with options:
rootdir: /Users/hubner/Repos/Trabalho/PartsCloud/inventory-planning
configfile: pytest.ini
plugins: anyio-4.3.0, mpl-0.17.0, fugue-0.9.0
collected 1 item                                                                                                       

tests/forecast.py F                                                                                              [100%]

======================================================= FAILURES =======================================================
_____________________________________________________ test_plot_df _____________________________________________________

args = ()
kwargs = {'demand_data':         period       SKU  quantity
0   2020-06-01  10001009       0.0
1   2020-06-01  10001010       0...-01  10001346      10.0
118 2023-12-01  10001368      85.0
119 2023-12-01  10001481       3.0

[120 rows x 3 columns]}}

    def wrapper(*args, **kwargs):
>       store.return_value[test_name] = obj(*args, **kwargs)

.venv/lib/python3.11/site-packages/pytest_mpl/plugin.py:125: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
tests/forecast.py:37: in test_plot_df
    fig = plot_df(filtered, forecast, level=[80], ax=ax)
engine/forecast.py:237: in plot_df
    return plot_series(
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

df =         period       SKU  quantity unique_id
0   2020-06-01  10001009       0.0      ts_0
10  2020-07-01  10001009    ... 10001009       3.0      ts_0
410 2023-11-01  10001009       4.0      ts_0
420 2023-12-01  10001009       0.0      ts_0
forecasts_df =        period   predict  predict-lo-80  predict-hi-80 unique_id
0  2023-01-01  0.333333      -0.405675       1.072342 ...11-01  0.500556      -0.238452       1.239564      ts_0
11 2023-12-01  0.500310      -0.238699       1.239318      ts_0
ids = None, plot_random = True, max_ids = 8, models = ['predict'], level = [80], max_insample_length = None
plot_anomalies = False, engine = 'matplotlib', palette = 'tab20b', id_col = 'unique_id', time_col = 'period'
target_col = 'quantity', seed = 0, resampler_kwargs = None, ax = <Axes: >

    def plot_series(
        df: Optional[DataFrame] = None,
        forecasts_df: Optional[DataFrame] = None,
        ids: Optional[List[str]] = None,
        plot_random: bool = True,
        max_ids: int = 8,
        models: Optional[List[str]] = None,
        level: Optional[List[float]] = None,
        max_insample_length: Optional[int] = None,
        plot_anomalies: bool = False,
        engine: str = "matplotlib",
        palette: Optional[str] = None,
        id_col: str = "unique_id",
        time_col: str = "ds",
        target_col: str = "y",
        seed: int = 0,
        resampler_kwargs: Optional[Dict] = None,
        ax: Optional[Union[plt.Axes, "plotly.graph_objects.Figure"]] = None,
    ):
        """Plot forecasts and insample values.
    
        Parameters
        ----------
        df : pandas or polars DataFrame, optional (default=None)
            DataFrame with columns [`id_col`, `time_col`, `target_col`].
        forecasts_df : pandas or polars DataFrame, optional (default=None)
            DataFrame with columns [`id_col`, `time_col`] and models.
        ids : list of str, optional (default=None)
            Time Series to plot.
            If None, time series are selected randomly.
        plot_random : bool (default=True)
            Select time series to plot randomly.
        max_ids : int (default=8)
            Maximum number of ids to plot.
        models : list of str, optional (default=None)
            Models to plot.
        level : list of float, optional (default=None)
            Prediction intervals to plot.
        max_insample_length : int, optional (default=None)
            Maximum number of train/insample observations to be plotted.
        plot_anomalies : bool (default=False)
            Plot anomalies for each prediction interval.
        engine : str (default='matplotlib')
            Library used to plot. 'plotly', 'plotly-resampler' or 'matplotlib'.
        palette : str (default=None)
            Name of the matplotlib colormap to use for the plots. If None, uses the current style.
        id_col : str (default='unique_id')
            Column that identifies each serie.
        time_col : str (default='ds')
            Column that identifies each timestep, its values can be timestamps or integers.
        target_col : str (default='y')
            Column that contains the target.
        seed : int (default=0)
            Seed used for the random number generator. Only used if plot_random is True.
        resampler_kwargs : dict
            Keyword arguments to be passed to plotly-resampler constructor.
            For further custumization ("show_dash") call the method,
            store the plotting object and add the extra arguments to
            its `show_dash` method.
        ax : matplotlib axes or plotly Figure, optional (default=None)
            Object where plots will be added.
    
        Returns
        -------
        fig : matplotlib or plotly figure
            Plot's figure
        """
        # checks
        supported_engines = ["matplotlib", "plotly", "plotly-resampler"]
        if engine not in supported_engines:
            raise ValueError(f"engine must be one of {supported_engines}, got '{engine}'.")
        if engine.startswith("plotly"):
            try:
                import plotly.graph_objects as go
                from plotly.subplots import make_subplots
            except ImportError:
                raise ImportError(
                    "plotly is not installed. Please install it and try again.\n"
                    "You can find detailed instructions at https://github.com/plotly/plotly.py#installation"
                )
        if plot_anomalies:
            if level is None:
                raise ValueError(
                    "In order to plot anomalies you have to specify the `level` argument"
                )
            elif forecasts_df is None or not any("lo" in c for c in forecasts_df.columns):
                raise ValueError(
                    "In order to plot anomalies you have to provide a `forecasts_df` with prediction intervals."
                )
        if level is not None and not isinstance(level, list):
            raise ValueError(
                "Please use a list for the `level` argument "
                "If you only have one level, use `level=[your_level]`"
            )
        elif level is None:
            level = []
        if df is None and forecasts_df is None:
            raise ValueError("At least one of `df` and `forecasts_df` must be provided.")
        elif df is not None:
            validate_format(df, id_col, time_col, target_col)
        elif forecasts_df is not None:
            validate_format(forecasts_df, id_col, time_col, None)
    
        # models to plot
        if models is None:
            if forecasts_df is None:
                models = []
            else:
                models = [
                    c
                    for c in forecasts_df.columns
                    if c not in [id_col, time_col, target_col]
                    and not re.search(r"-(?:lo|hi)-\d+", c)
                ]
    
        # ids
        if ids is None:
            if df is not None:
                uids: Union[np.ndarray, pl_Series, List] = df[id_col].unique()
            else:
                assert forecasts_df is not None
                uids = forecasts_df[id_col].unique()
        else:
            uids = ids
        if ax is not None:
            if isinstance(ax, np.ndarray) and isinstance(ax.flat[0], plt.Axes):
                gs = ax.flat[0].get_gridspec()
                n_rows, n_cols = gs.nrows, gs.ncols
                ax = ax.reshape(n_rows, n_cols)
            elif engine.startswith("plotly") and isinstance(ax, go.Figure):
                rows, cols = ax._get_subplot_rows_columns()
                # rows and cols are ranges
                n_rows = len(rows)
                n_cols = len(cols)
            else:
>               raise ValueError(f"Cannot process `ax` of type: {type(ax).__name__}.")
E               ValueError: Cannot process `ax` of type: Axes.

.venv/lib/python3.11/site-packages/utilsforecast/plotting.py:183: ValueError
=================================================== warnings summary ===================================================
tests/forecast.py::test_plot_df
  /Users/hubner/Repos/Trabalho/PartsCloud/inventory-planning/.venv/lib/python3.11/site-packages/pytest_mpl/plugin.py:322: MatplotlibDeprecationWarning: Auto-close()ing of figures upon backend switching is deprecated since 3.8 and will be removed two minor releases later.  To suppress this warning, explicitly call plt.close('all') first.
    plt.switch_backend(prev_backend)

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
=============================================== short test summary info ================================================
FAILED tests/forecast.py::test_plot_df - ValueError: Cannot process `ax` of type: Axes.
============================================= 1 failed, 1 warning in 0.28s =============================================
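The failing check is the last frame of the traceback: with engine='matplotlib', plot_series only accepts ax as a NumPy array of Axes (it reshapes that array using the first Axes' gridspec), so the single Axes returned by a plain plt.subplots() call falls through to the ValueError. The sketch below mirrors only that branch in a hypothetical helper, accepts_matplotlib_ax, which is not part of utilsforecast, so it runs with just matplotlib and NumPy. It also shows that wrapping the same Axes in a (1, 1) object array passes this check; whether the rest of plot_series handles such a hand-built array is an assumption, not something verified here.

```python
import matplotlib

matplotlib.use("Agg")  # headless backend so the sketch runs without a display

import matplotlib.pyplot as plt
import numpy as np


def accepts_matplotlib_ax(ax) -> bool:
    """Hypothetical helper (not part of utilsforecast) that mirrors the
    matplotlib branch of the `ax` check shown in the traceback."""
    if isinstance(ax, np.ndarray) and isinstance(ax.flat[0], plt.Axes):
        # the grid shape is read from the first Axes' gridspec and the
        # array is reshaped to it (NumPy raises if the sizes don't match)
        gs = ax.flat[0].get_gridspec()
        ax.reshape(gs.nrows, gs.ncols)
        return True
    # with the matplotlib engine, a bare Axes ends up here;
    # plot_series raises ValueError at this point
    return False


fig, single_ax = plt.subplots(figsize=(16, 3.5))  # default squeeze=True: one Axes
wrapped = np.array([[single_ax]], dtype=object)   # same Axes inside a (1, 1) array

print(accepts_matplotlib_ax(single_ax))  # False
print(accepts_matplotlib_ax(wrapped))    # True
```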

@MaickelHubner
Author

If possible, please also add the ax parameter to the nixtla_client plot function. That would save a lot of unnecessary code.
Thanks

@jmoralez
Member

jmoralez commented Aug 5, 2024

Can you set squeeze=False in the subplots call?
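For reference, this is what squeeze changes in plt.subplots: with the default squeeze=True a single subplot comes back as a bare Axes, whereas squeeze=False always returns a 2D NumPy array of Axes, which is the form the check in plot_series accepts. A small standalone sketch (figure sizes and grid shapes are arbitrary):

```python
import matplotlib

matplotlib.use("Agg")  # headless backend so the sketch runs without a display

import matplotlib.pyplot as plt
import numpy as np

# Single subplot: the default squeeze=True returns a bare Axes ...
fig1, ax_default = plt.subplots(figsize=(16, 3.5))
# ... while squeeze=False always returns a 2D ndarray of Axes.
fig2, ax_grid = plt.subplots(figsize=(16, 3.5), squeeze=False)

# One row of two subplots: squeeze=True gives a 1D array, squeeze=False a 2D one.
fig3, row_default = plt.subplots(1, 2)
fig4, row_grid = plt.subplots(1, 2, squeeze=False)

print(isinstance(ax_default, np.ndarray))  # False
print(ax_grid.shape)                       # (1, 1)
print(row_default.shape, row_grid.shape)   # (2,) (1, 2)

plt.close("all")
```

Applied to the test in the report, that means calling plt.subplots(figsize=(16, 3.5), squeeze=False) and passing the returned ax to plot_df unchanged. Going by the check in the traceback, grids with more than one subplot already come back as an ndarray by default, so only the single-subplot case needs squeeze=False.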

@MaickelHubner
Author

It works with this parameter 👍
Could you add the ax parameter to the nixtla_client plot function, too?
Thanks

@jmoralez
Member

jmoralez commented Aug 5, 2024

We'll add it, but releases of that library take much longer than releases here.

@MaickelHubner
Author

Ok. I'll keep an eye out.
Thank you very much.


2 participants