Skip to content

Commit

Permalink
[Time-Series] fix past_observed_mask type (#22076)
Browse files Browse the repository at this point in the history
added > 0.5 to `past_observed_mask`
  • Loading branch information
elisim authored Apr 3, 2023
1 parent 559a45d commit 9eae4aa
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 2 deletions.
2 changes: 1 addition & 1 deletion tests/models/informer/test_modeling_informer.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def prepare_informer_inputs_dict(self, config):

past_time_features = floats_tensor([self.batch_size, _past_length, config.num_time_features])
past_values = floats_tensor([self.batch_size, _past_length])
past_observed_mask = floats_tensor([self.batch_size, _past_length])
past_observed_mask = floats_tensor([self.batch_size, _past_length]) > 0.5

# decoder inputs
future_time_features = floats_tensor([self.batch_size, config.prediction_length, config.num_time_features])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def prepare_time_series_transformer_inputs_dict(self, config):

past_time_features = floats_tensor([self.batch_size, _past_length, config.num_time_features])
past_values = floats_tensor([self.batch_size, _past_length])
past_observed_mask = floats_tensor([self.batch_size, _past_length])
past_observed_mask = floats_tensor([self.batch_size, _past_length]) > 0.5

# decoder inputs
future_time_features = floats_tensor([self.batch_size, config.prediction_length, config.num_time_features])
Expand Down

0 comments on commit 9eae4aa

Please sign in to comment.