Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix bug in trendline in the case of missing values #2357

Merged
merged 8 commits into from
Apr 27, 2020
2 changes: 1 addition & 1 deletion .circleci/create_conda_optional_env.sh
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ if [ ! -d $HOME/miniconda/envs/circle_optional ]; then
# Create environment
# PYTHON_VERSION=2.7 or 3.5
$HOME/miniconda/bin/conda create -n circle_optional --yes python=$PYTHON_VERSION \
requests nbformat six retrying psutil pandas decorator pytest mock nose poppler xarray scikit-image ipython jupyter ipykernel ipywidgets
requests nbformat six retrying psutil pandas decorator pytest mock nose poppler xarray scikit-image ipython jupyter ipykernel ipywidgets statsmodels

# Install orca into environment
$HOME/miniconda/bin/conda install --yes -n circle_optional -c plotly plotly-orca==1.3.1
Expand Down
13 changes: 10 additions & 3 deletions packages/python/plotly/plotly/express/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,18 +241,25 @@ def make_trace_kwargs(args, trace_spec, trace_data, mapping_labels, sizeref):
sorted_trace_data = trace_data.sort_values(by=args["x"])
y = sorted_trace_data[args["y"]]
x = sorted_trace_data[args["x"]]
trace_patch["x"] = x

if x.dtype.type == np.datetime64:
x = x.astype(int) / 10 ** 9 # convert to unix epoch seconds

if attr_value == "lowess":
trendline = sm.nonparametric.lowess(y, x)
# missing ='drop' is the default value for lowess but not for OLS (None)
# we force it here in case statsmodels change their defaults
trendline = sm.nonparametric.lowess(y, x, missing="drop")
trace_patch["x"] = trendline[:, 0]
nicolaskruchten marked this conversation as resolved.
Show resolved Hide resolved
trace_patch["y"] = trendline[:, 1]
hover_header = "<b>LOWESS trendline</b><br><br>"
elif attr_value == "ols":
fit_results = sm.OLS(y.values, sm.add_constant(x.values)).fit()
fit_results = sm.OLS(
y.values, sm.add_constant(x.values), missing="drop"
).fit()
trace_patch["y"] = fit_results.predict()
trace_patch["x"] = x[
np.logical_not(np.logical_or(np.isnan(y), np.isnan(x)))
]
hover_header = "<b>OLS trendline</b><br>"
hover_header += "%s = %g * %s + %g<br>" % (
args["y"],
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import plotly.express as px
import numpy as np


def test_trendline_nan_values():
nicolaskruchten marked this conversation as resolved.
Show resolved Hide resolved
df = px.data.gapminder().query("continent == 'Oceania'")
start_date = 1970
df["pop"][df["year"] < start_date] = np.nan
modes = ["ols", "lowess"]
for mode in modes:
fig = px.scatter(df, x="year", y="pop", color="country", trendline=mode)
for trendline in fig["data"][1::2]:
assert trendline.x[0] >= start_date
assert len(trendline.x) == len(trendline.y)
1 change: 1 addition & 0 deletions packages/python/plotly/tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ deps=
pytest==3.5.1
pandas==0.24.2
xarray==0.10.9
statsmodels==0.10.2
backports.tempfile==1.0
optional: --editable=file:///{toxinidir}/../plotly-geo
optional: numpy==1.16.5
Expand Down