Skip to content

Commit

Permalink
Make explicit cast to np.float64 in nns
Browse files Browse the repository at this point in the history
  • Loading branch information
d.a.bunin committed Jan 19, 2023
1 parent bfbd4b1 commit 38385ef
Showing 1 changed file with 1 addition and 2 deletions.
3 changes: 1 addition & 2 deletions etna/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -632,8 +632,7 @@ def forecast(self, ts: "TSDataset", prediction_size: int) -> "TSDataset":
future_ts = ts.tsdataset_idx_slice(start_idx=self.encoder_length, end_idx=self.encoder_length + prediction_size)
for (segment, feature_nm), value in predictions.items():
# we don't want to change dtype after assignment, but there can happen cast to float32
dtype = future_ts.df.loc[:, pd.IndexSlice[segment, feature_nm]]
future_ts.df.loc[:, pd.IndexSlice[segment, feature_nm]] = value[:prediction_size, :].astype(dtype)
future_ts.df.loc[:, pd.IndexSlice[segment, feature_nm]] = value[:prediction_size, :].astype(np.float64)

future_ts.inverse_transform()

Expand Down

0 comments on commit 38385ef

Please sign in to comment.