-
Notifications
You must be signed in to change notification settings - Fork 80
Add native prediction intervals for TFTModel
#770
Conversation
🚀 Deployed on https://deploy-preview-770--etna-docs.netlify.app |
Codecov Report
@@ Coverage Diff @@
## master #770 +/- ##
===========================================
- Coverage 83.54% 49.80% -33.75%
===========================================
Files 122 122
Lines 6764 6815 +51
===========================================
- Hits 5651 3394 -2257
- Misses 1113 3421 +2308
📣 Codecov can now indicate which changes are the most critical in Pull Requests. Learn more |
etna/models/nn/tft.py
Outdated
@@ -43,7 +48,9 @@ def __init__( | |||
attention_head_size: int = 4, | |||
dropout: float = 0.1, | |||
hidden_continuous_size: int = 8, | |||
loss: MultiHorizonMetric = None, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This would raise exception for non-torch enviroments
tests/test_models/nn/test_tft.py
Outdated
for segment in forecast.segments: | ||
segment_slice = forecast[:, segment, :][segment] | ||
assert {"target_0.02", "target_0.98", "target"}.issubset(segment_slice.columns) | ||
assert {"target_0.3"}.isdisjoint(segment_slice.columns) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why not "target_0.4"
?
etna/models/nn/tft.py
Outdated
ts.loc[:, pd.IndexSlice[:, "target"]] = predicts.T[-len(ts.df) :] | ||
|
||
if prediction_interval: | ||
quantiles_predicts = self.model.predict( # type: ignore |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I guess if loss is not based on QuantileLoss
it will not work.
May be we should check if isinstance(self.loss, QuantileLoss):
this line and below.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You are right: fitting is working but during prediction here quantiles aren't generated: only one value is returned for each point instead of len(quantiles)
values.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It looks like we should print Warning.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
👍
IMPORTANT: Please do not create a Pull Request without creating an issue first.
Before submitting (must do checklist)
Type of Change
Proposed Changes
Look #739.
Related Issue
#739.
Closing issues
Closes #739.