Skip to content

Commit

Permalink
bug fix: when dask import, use dask.array.where instead of np.where
Browse files Browse the repository at this point in the history
  • Loading branch information
lee1043 committed Jan 3, 2024
1 parent a8ac39b commit de9558e
Showing 1 changed file with 8 additions and 2 deletions.
10 changes: 8 additions & 2 deletions lib/eofs/standard.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,10 @@ def __init__(self, dataset, weights=None, center=True, ddof=1):
if not self._valid_nan(self._data):
raise ValueError('missing values detected in different '
'locations at different times')
nonMissingIndex = np.where(np.logical_not(np.isnan(self._data[0])))[0]
if has_dask:
nonMissingIndex = dask.array.where(np.logical_not(np.isnan(self._data[0])))[0]
else:
nonMissingIndex = np.where(np.logical_not(np.isnan(self._data[0])))[0]
# Remove missing values from the design matrix.
dataNoMissing = self._data[:, nonMissingIndex]
if dataNoMissing.size == 0:
Expand Down Expand Up @@ -741,7 +744,10 @@ def projectField(self, field, neofs=None, eofscaling=0, weighted=True):
if not self._valid_nan(field_flat):
raise ValueError('missing values detected in different '
'locations at different times')
nonMissingIndex = np.where(np.logical_not(np.isnan(field_flat[0])))[0]
if has_dask:
nonMissingIndex = dask.array.where(np.logical_not(np.isnan(self._data[0])))[0] # lee1043 testing
else:
nonMissingIndex = np.where(np.logical_not(np.isnan(field_flat[0])))[0]
try:
# Compute chunk sizes if nonMissingIndex is a dask array, so its
# shape can be compared with eofsNonMissingIndex later.
Expand Down

0 comments on commit de9558e

Please sign in to comment.