Skip to content

Commit

Permalink
rephrased checking if valid_date is too close to the edge without mes…
Browse files Browse the repository at this point in the history
…sing with augment parameter
  • Loading branch information
Christina Butsko committed Sep 11, 2024
1 parent 8d7e4c1 commit b2f1aa3
Showing 1 changed file with 9 additions and 8 deletions.
17 changes: 9 additions & 8 deletions presto/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,15 +73,16 @@ def get_timestep_positions(cls, row_d: Dict, augment: bool = False) -> List[int]
available_timesteps = int(row_d["available_timesteps"])
valid_position = int(row_d["valid_position"])

# force moving the center point if it is too close to the edges
if (valid_position < cls.NUM_TIMESTEPS // 2) or (
valid_position > (available_timesteps - cls.NUM_TIMESTEPS // 2)
):
augment = True

if not augment:
# Center the timesteps around the valid position
center_point = valid_position
# check if the valid position is too close to the start_date and force shifting it
if valid_position < cls.NUM_TIMESTEPS // 2:
center_point = cls.NUM_TIMESTEPS // 2
# or too close to the end_date
elif valid_position > (available_timesteps - cls.NUM_TIMESTEPS // 2):
center_point = available_timesteps - cls.NUM_TIMESTEPS // 2
else:
# Center the timesteps around the valid position
center_point = valid_position
else:
# Shift the center point but make sure the resulting range
# well includes the valid position
Expand Down

0 comments on commit b2f1aa3

Please sign in to comment.