Skip to content

Add native prediction intervals for TFTModel #770

Merged
merged 5 commits into from
Jun 24, 2022
Merged

Add native prediction intervals for TFTModel #770

merged 5 commits into from
Jun 24, 2022

Conversation

Mr-Geekman
Copy link
Contributor

@Mr-Geekman Mr-Geekman commented Jun 21, 2022

IMPORTANT: Please do not create a Pull Request without creating an issue first.

Before submitting (must do checklist)

  • Did you read the contribution guide?
  • Did you update the docs? We use Numpy format for all the methods and classes.
  • Did you write any new necessary tests?
  • Did you update the CHANGELOG?

Type of Change

  • Examples / docs / tutorials / contributors update
  • Bug fix (non-breaking change which fixes an issue)
  • Improvement (non-breaking change which improves an existing feature)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to change)

Proposed Changes

Look #739.

Related Issue

#739.

Closing issues

Closes #739.

@Mr-Geekman Mr-Geekman self-assigned this Jun 21, 2022
@github-actions
Copy link

github-actions bot commented Jun 21, 2022

🚀 Deployed on https://deploy-preview-770--etna-docs.netlify.app

@github-actions github-actions bot temporarily deployed to pull request June 21, 2022 15:11 Inactive
@codecov-commenter
Copy link

Codecov Report

Merging #770 (d319dcc) into master (c35055e) will decrease coverage by 33.74%.
The diff coverage is 39.02%.

@@             Coverage Diff             @@
##           master     #770       +/-   ##
===========================================
- Coverage   83.54%   49.80%   -33.75%     
===========================================
  Files         122      122               
  Lines        6764     6815       +51     
===========================================
- Hits         5651     3394     -2257     
- Misses       1113     3421     +2308     
Impacted Files Coverage Δ
etna/models/nn/tft.py 72.16% <32.43%> (-27.84%) ⬇️
etna/models/nn/deepar.py 83.75% <100.00%> (-16.25%) ⬇️
etna/commands/__init__.py 0.00% <0.00%> (-100.00%) ⬇️
etna/commands/backtest_command.py 0.00% <0.00%> (-96.43%) ⬇️
etna/commands/forecast_command.py 0.00% <0.00%> (-93.94%) ⬇️
etna/commands/__main__.py 0.00% <0.00%> (-87.50%) ⬇️
etna/commands/resolvers.py 0.00% <0.00%> (-80.00%) ⬇️
etna/analysis/outliers/density_outliers.py 22.44% <0.00%> (-75.52%) ⬇️
etna/datasets/datasets_generation.py 27.02% <0.00%> (-72.98%) ⬇️
etna/transforms/timestamp/time_flags.py 27.02% <0.00%> (-72.98%) ⬇️
... and 73 more

📣 Codecov can now indicate which changes are the most critical in Pull Requests. Learn more

@martins0n martins0n self-requested a review June 22, 2022 08:52
@@ -43,7 +48,9 @@ def __init__(
attention_head_size: int = 4,
dropout: float = 0.1,
hidden_continuous_size: int = 8,
loss: MultiHorizonMetric = None,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This would raise exception for non-torch enviroments

for segment in forecast.segments:
segment_slice = forecast[:, segment, :][segment]
assert {"target_0.02", "target_0.98", "target"}.issubset(segment_slice.columns)
assert {"target_0.3"}.isdisjoint(segment_slice.columns)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not "target_0.4"?

ts.loc[:, pd.IndexSlice[:, "target"]] = predicts.T[-len(ts.df) :]

if prediction_interval:
quantiles_predicts = self.model.predict( # type: ignore
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess if loss is not based on QuantileLoss it will not work.
May be we should check if isinstance(self.loss, QuantileLoss): this line and below.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You are right: fitting is working but during prediction here quantiles aren't generated: only one value is returned for each point instead of len(quantiles) values.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks like we should print Warning.

@github-actions github-actions bot temporarily deployed to pull request June 24, 2022 15:38 Inactive
Copy link
Contributor

@martins0n martins0n left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍

@martins0n martins0n merged commit 5487dc5 into master Jun 24, 2022
@martins0n martins0n deleted the issue-739 branch June 24, 2022 16:58
Sign up for free to subscribe to this conversation on GitHub. Already have an account? Sign in.
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Native prediction intervals for TFT
3 participants