-
Notifications
You must be signed in to change notification settings - Fork 0
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Using new parquet in train #104
base: main
Are you sure you want to change the base?
Changes from 5 commits
523014f
5c74855
cb197a3
9e9ac9d
b3e0284
1d37536
ae27e25
8d7e4c1
b2f1aa3
ff25509
e338c09
c9ffa01
f6e1a9e
407cec9
ad05b3d
2945e9d
ef06f94
30ab19c
96dbc0d
55dbbbe
a7eedd8
6f7646f
203c4ac
465d65a
54cc2be
e44445a
1a957a2
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -23,7 +23,7 @@ | |
DynamicWorld2020_2021, | ||
) | ||
from .masking import BAND_EXPANSION, MaskedExample, MaskParamsNoDw | ||
from .utils import DEFAULT_SEED, data_dir, load_world_df | ||
from .utils import DEFAULT_SEED, MIN_EDGE_BUFFER, data_dir, load_world_df | ||
|
||
logger = logging.getLogger("__main__") | ||
|
||
|
@@ -65,23 +65,73 @@ def target_crop(row_d: Dict) -> int: | |
# by default, we predict crop vs non crop | ||
return int(row_d["LANDCOVER_LABEL"] == 11) | ||
|
||
@classmethod | ||
def get_timestep_positions(cls, row_d: Dict, augment: bool = False) -> List[int]: | ||
available_timesteps = int(row_d["available_timesteps"]) | ||
valid_position = int(row_d["valid_position"]) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could you add a comment describing what There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I see its created in |
||
|
||
# force moving the center point if it is too close to the edges | ||
if (valid_position < cls.NUM_TIMESTEPS // 2) or ( | ||
valid_position > (available_timesteps - cls.NUM_TIMESTEPS // 2) | ||
): | ||
augment = True | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In this case we enter unconsciously in augmentation which might not be what we desire. Can't we put this logic inside There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. good point. tried to rewrite it here b2f1aa3 |
||
|
||
if not augment: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could you also make a note that augment here means temporal jittering? Am I right in understanding |
||
# Center the timesteps around the valid position | ||
center_point = valid_position | ||
else: | ||
# Shift the center point but make sure the resulting range | ||
# well includes the valid position | ||
|
||
min_center_point = max( | ||
cls.NUM_TIMESTEPS // 2, valid_position + MIN_EDGE_BUFFER - cls.NUM_TIMESTEPS // 2 | ||
) | ||
max_center_point = min( | ||
available_timesteps - cls.NUM_TIMESTEPS // 2, | ||
valid_position - MIN_EDGE_BUFFER + cls.NUM_TIMESTEPS // 2, | ||
) | ||
|
||
center_point = np.random.randint( | ||
min_center_point, max_center_point + 1 | ||
) # max_center_point included | ||
|
||
last_timestep = min(available_timesteps, center_point + cls.NUM_TIMESTEPS // 2) | ||
first_timestep = max(0, last_timestep - cls.NUM_TIMESTEPS) | ||
timestep_positions = list(range(first_timestep, last_timestep)) | ||
|
||
if len(timestep_positions) != cls.NUM_TIMESTEPS: | ||
raise ValueError( | ||
f"Acquired timestep positions do not have correct length: \ | ||
required {cls.NUM_TIMESTEPS}, got {len(timestep_positions)}" | ||
) | ||
assert ( | ||
valid_position in timestep_positions | ||
), f"Valid position {valid_position} not in timestep positions {timestep_positions}" | ||
return timestep_positions | ||
|
||
@classmethod | ||
def row_to_arrays( | ||
cls, row: pd.Series, target_function: Callable[[Dict], int] | ||
cls, row: pd.Series, target_function: Callable[[Dict], int], augment: bool = False | ||
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, float, int]: | ||
# https://stackoverflow.com/questions/45783891/is-there-a-way-to-speed-up-the-pandas-getitem-getitem-axis-and-get-label | ||
# This is faster than indexing the series every time! | ||
row_d = pd.Series.to_dict(row) | ||
|
||
latlon = np.array([row_d["lat"], row_d["lon"]], dtype=np.float32) | ||
month = datetime.strptime(row_d["start_date"], "%Y-%m-%d").month - 1 | ||
|
||
timestep_positions = cls.get_timestep_positions(row_d, augment=augment) | ||
# make sure that month for encoding gets shifted according to | ||
# the selected timestep positions | ||
month = ( | ||
pd.to_datetime(row_d["start_date"]) + pd.DateOffset(months=timestep_positions[0]) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You could also just add There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. good point. tried to account for that here ad05b3d There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You mean this has to be tackled in There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I am not sure I understand what There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. well, maybe I'm overthinking it... here's an example:
am I missing something? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Following most of it. But still, There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. okay, maybe something like this can work:
it can be a little imprecise in some cases, not more than 1 month error. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yeah let's make a note that Giorgia can test this out properly when doing dekadal runs. |
||
).month - 1 | ||
|
||
eo_data = np.zeros((cls.NUM_TIMESTEPS, len(BANDS))) | ||
# an assumption we make here is that all timesteps for a token | ||
# have the same masking | ||
mask = np.zeros((cls.NUM_TIMESTEPS, len(BANDS_GROUPS_IDX))) | ||
for df_val, presto_val in cls.BAND_MAPPING.items(): | ||
values = np.array([float(row_d[df_val.format(t)]) for t in range(cls.NUM_TIMESTEPS)]) | ||
values = np.array([float(row_d[df_val.format(t)]) for t in timestep_positions]) | ||
# this occurs for the DEM values in one point in Fiji | ||
values = np.nan_to_num(values, nan=cls._NODATAVALUE) | ||
idx_valid = values != cls._NODATAVALUE | ||
|
@@ -260,6 +310,7 @@ def __init__( | |
years_to_remove: Optional[List[int]] = None, | ||
target_function: Optional[Callable[[Dict], int]] = None, | ||
balance: bool = False, | ||
augment: bool = False, | ||
): | ||
dataframe = dataframe.loc[~dataframe.LANDCOVER_LABEL.isin(self.FILTER_LABELS)] | ||
|
||
|
@@ -275,6 +326,7 @@ def __init__( | |
dataframe = dataframe[(~dataframe.end_date.dt.year.isin(years_to_remove))] | ||
self.target_function = target_function if target_function is not None else self.target_crop | ||
self._class_weights: Optional[np.ndarray] = None | ||
self.augment = augment | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. should we add logging somewhere that shows that augmentation is enabled when initializing a dataset like that? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. good point. addressed that here ae27e25 |
||
|
||
super().__init__(dataframe) | ||
if balance: | ||
|
@@ -307,7 +359,9 @@ def __getitem__(self, idx): | |
# Get the sample | ||
df_index = self.indices[idx] | ||
row = self.df.iloc[df_index, :] | ||
eo, mask_per_token, latlon, month, target = self.row_to_arrays(row, self.target_function) | ||
eo, mask_per_token, latlon, month, target = self.row_to_arrays( | ||
row, self.target_function, self.augment | ||
) | ||
mask_per_variable = np.repeat(mask_per_token, BAND_EXPANSION, axis=1) | ||
return ( | ||
self.normalize_and_mask(eo), | ||
|
@@ -354,7 +408,9 @@ def __getitem__(self, idx): | |
# Get the sample | ||
df_index = self.indices[idx] | ||
row = self.df.iloc[df_index, :] | ||
eo, mask_per_token, latlon, _, target = self.row_to_arrays(row, self.target_function) | ||
eo, mask_per_token, latlon, _, target = self.row_to_arrays( | ||
row, self.target_function, self.augment | ||
) | ||
mask_per_variable = np.repeat(mask_per_token, BAND_EXPANSION, axis=1) | ||
return ( | ||
self.normalize_and_mask(eo), | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@cbutsko in
croptype
branch this file is removed. What does that mean for the changes here? There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I merged the functionality of
paper_eval
into train.py
. I think in this branch it doesn't really matter. I just wanted to align this branch with main and then merge into croptype, resolving these conflicts there. Also, now after discussing self-supervised training with Giorgia, it seems more feasible to create two separate files with clearer functions: something like train_self_supervised.py
and train_finetuned.py
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That sounds like a good plan!