Skip to content

Commit

Permalink
fix bug in trendline in the case of missing values (#2357)
Browse files Browse the repository at this point in the history
* fix bug in trendline in the case of missing values

* paint it black

* added statsmodels to dependencies for CI

* version for py2

* Update packages/python/plotly/plotly/express/_core.py

Co-Authored-By: Nicolas Kruchten <[email protected]>

* extended test to lowess, and more precise check of attribute length

* Update packages/python/plotly/plotly/tests/test_core/test_px/test_trendline.py

Co-authored-by: Nicolas Kruchten <[email protected]>
  • Loading branch information
emmanuelle and nicolaskruchten authored Apr 27, 2020
1 parent ad0dd30 commit e778c6b
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 4 deletions.
2 changes: 1 addition & 1 deletion .circleci/create_conda_optional_env.sh
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ if [ ! -d $HOME/miniconda/envs/circle_optional ]; then
# Create environment
# PYTHON_VERSION=2.7 or 3.5
$HOME/miniconda/bin/conda create -n circle_optional --yes python=$PYTHON_VERSION \
requests nbformat six retrying psutil pandas decorator pytest mock nose poppler xarray scikit-image ipython jupyter ipykernel ipywidgets
requests nbformat six retrying psutil pandas decorator pytest mock nose poppler xarray scikit-image ipython jupyter ipykernel ipywidgets statsmodels

# Install orca into environment
$HOME/miniconda/bin/conda install --yes -n circle_optional -c plotly plotly-orca==1.3.1
Expand Down
13 changes: 10 additions & 3 deletions packages/python/plotly/plotly/express/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,18 +241,25 @@ def make_trace_kwargs(args, trace_spec, trace_data, mapping_labels, sizeref):
sorted_trace_data = trace_data.sort_values(by=args["x"])
y = sorted_trace_data[args["y"]]
x = sorted_trace_data[args["x"]]
trace_patch["x"] = x

if x.dtype.type == np.datetime64:
x = x.astype(int) / 10 ** 9 # convert to unix epoch seconds

if attr_value == "lowess":
trendline = sm.nonparametric.lowess(y, x)
# missing ='drop' is the default value for lowess but not for OLS (None)
# we force it here in case statsmodels change their defaults
trendline = sm.nonparametric.lowess(y, x, missing="drop")
trace_patch["x"] = trendline[:, 0]
trace_patch["y"] = trendline[:, 1]
hover_header = "<b>LOWESS trendline</b><br><br>"
elif attr_value == "ols":
fit_results = sm.OLS(y.values, sm.add_constant(x.values)).fit()
fit_results = sm.OLS(
y.values, sm.add_constant(x.values), missing="drop"
).fit()
trace_patch["y"] = fit_results.predict()
trace_patch["x"] = x[
np.logical_not(np.logical_or(np.isnan(y), np.isnan(x)))
]
hover_header = "<b>OLS trendline</b><br>"
hover_header += "%s = %g * %s + %g<br>" % (
args["y"],
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import plotly.express as px
import numpy as np


def test_trendline_nan_values():
df = px.data.gapminder().query("continent == 'Oceania'")
start_date = 1970
df["pop"][df["year"] < start_date] = np.nan
modes = ["ols", "lowess"]
for mode in modes:
fig = px.scatter(df, x="year", y="pop", color="country", trendline=mode)
for trendline in fig["data"][1::2]:
assert trendline.x[0] >= start_date
assert len(trendline.x) == len(trendline.y)
1 change: 1 addition & 0 deletions packages/python/plotly/tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ deps=
pytest==3.5.1
pandas==0.24.2
xarray==0.10.9
statsmodels==0.10.2
backports.tempfile==1.0
optional: --editable=file:///{toxinidir}/../plotly-geo
optional: numpy==1.16.5
Expand Down

0 comments on commit e778c6b

Please sign in to comment.