-
Notifications
You must be signed in to change notification settings - Fork 0
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Using new parquet in train #104
base: main
Are you sure you want to change the base?
Changes from all commits
523014f
5c74855
cb197a3
9e9ac9d
b3e0284
1d37536
ae27e25
8d7e4c1
b2f1aa3
ff25509
e338c09
c9ffa01
f6e1a9e
407cec9
ad05b3d
2945e9d
ef06f94
30ab19c
96dbc0d
55dbbbe
a7eedd8
6f7646f
203c4ac
465d65a
54cc2be
e44445a
1a957a2
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Large diffs are not rendered by default.
Large diffs are not rendered by default.
Large diffs are not rendered by default.
Large diffs are not rendered by default.
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -16,6 +16,8 @@ | |
from .dataops import ( | ||
BANDS, | ||
BANDS_GROUPS_IDX, | ||
MIN_EDGE_BUFFER, | ||
NODATAVALUE, | ||
NORMED_BANDS, | ||
S1_S2_ERA5_SRTM, | ||
DynamicWorld2020_2021, | ||
|
@@ -32,7 +34,7 @@ | |
|
||
|
||
class WorldCerealBase(Dataset): | ||
_NODATAVALUE = 65535 | ||
# _NODATAVALUE = 65535 | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @cbutsko why is this commented? And are changes below like the addition of |
||
NUM_TIMESTEPS = 12 | ||
BAND_MAPPING = { | ||
"OPTICAL-B02-ts{}-10m": "B2", | ||
|
@@ -63,26 +65,97 @@ def target_crop(row_d: Dict) -> int: | |
# by default, we predict crop vs non crop | ||
return int(row_d["LANDCOVER_LABEL"] == 11) | ||
|
||
@classmethod | ||
def get_timestep_positions(cls, row_d: Dict, augment: bool = False) -> List[int]: | ||
available_timesteps = int(row_d["available_timesteps"]) | ||
valid_position = int(row_d["valid_position"]) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could you add a comment describing what There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I see its created in |
||
|
||
if not augment: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could you also make a note that augment here means temporal jittering? Am I right in understanding |
||
# check if the valid position is too close to the start_date and force shifting it | ||
if valid_position < cls.NUM_TIMESTEPS // 2: | ||
center_point = cls.NUM_TIMESTEPS // 2 | ||
# or too close to the end_date | ||
elif valid_position > (available_timesteps - cls.NUM_TIMESTEPS // 2): | ||
center_point = available_timesteps - cls.NUM_TIMESTEPS // 2 | ||
else: | ||
# Center the timesteps around the valid position | ||
center_point = valid_position | ||
else: | ||
# Shift the center point but make sure the resulting range | ||
# well includes the valid position | ||
|
||
min_center_point = max( | ||
cls.NUM_TIMESTEPS // 2, | ||
valid_position + MIN_EDGE_BUFFER - cls.NUM_TIMESTEPS // 2, | ||
) | ||
max_center_point = min( | ||
available_timesteps - cls.NUM_TIMESTEPS // 2, | ||
valid_position - MIN_EDGE_BUFFER + cls.NUM_TIMESTEPS // 2, | ||
) | ||
|
||
center_point = np.random.randint( | ||
min_center_point, max_center_point + 1 | ||
) # max_center_point included | ||
|
||
last_timestep = min(available_timesteps, center_point + cls.NUM_TIMESTEPS // 2) | ||
first_timestep = max(0, last_timestep - cls.NUM_TIMESTEPS) | ||
timestep_positions = list(range(first_timestep, last_timestep)) | ||
|
||
if len(timestep_positions) != cls.NUM_TIMESTEPS: | ||
raise ValueError( | ||
f"Acquired timestep positions do not have correct length: \ | ||
required {cls.NUM_TIMESTEPS}, got {len(timestep_positions)}" | ||
) | ||
assert ( | ||
valid_position in timestep_positions | ||
), f"Valid position {valid_position} not in timestep positions {timestep_positions}" | ||
return timestep_positions | ||
|
||
@classmethod | ||
def row_to_arrays( | ||
cls, row: pd.Series, target_function: Callable[[Dict], int] | ||
cls, row: pd.Series, target_function: Callable[[Dict], int], augment: bool = False | ||
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, float, int]: | ||
# https://stackoverflow.com/questions/45783891/is-there-a-way-to-speed-up-the-pandas-getitem-getitem-axis-and-get-label | ||
# This is faster than indexing the series every time! | ||
row_d = pd.Series.to_dict(row) | ||
|
||
latlon = np.array([row_d["lat"], row_d["lon"]], dtype=np.float32) | ||
month = datetime.strptime(row_d["start_date"], "%Y-%m-%d").month - 1 | ||
|
||
timestep_positions = cls.get_timestep_positions(row_d, augment=augment) | ||
|
||
if cls.NUM_TIMESTEPS == 12: | ||
initial_start_date_position = pd.to_datetime(row_d["start_date"]).month | ||
elif cls.NUM_TIMESTEPS > 12: | ||
# get the correct index of the start_date based on NUM_TIMESTEPS` | ||
# e.g. if NUM_TIMESTEPS is 36 (dekadal setup), we should take the correct | ||
# 10-day interval that the start_date falls into | ||
# TODO: 1) this needs to go into a separate function | ||
# 2) definition of valid_position and timestep_ind | ||
# should also be changed accordingly | ||
year = pd.to_datetime(row_d["start_date"]).year | ||
year_dates = pd.date_range(start=f"{year}-01-01", end=f"{year}-12-31") | ||
bins = pd.cut(year_dates, bins=cls.NUM_TIMESTEPS, labels=False) | ||
initial_start_date_position = bins[ | ||
np.where(year_dates == pd.to_datetime(row_d["start_date"]))[0][0] | ||
] | ||
else: | ||
raise ValueError( | ||
f"NUM_TIMESTEPS must be at least 12. Currently it is {cls.NUM_TIMESTEPS}" | ||
) | ||
|
||
# make sure that month for encoding gets shifted according to | ||
# the selected timestep positions. Also ensure circular indexing | ||
month = (initial_start_date_position - 1 + timestep_positions[0]) % cls.NUM_TIMESTEPS | ||
|
||
eo_data = np.zeros((cls.NUM_TIMESTEPS, len(BANDS))) | ||
# an assumption we make here is that all timesteps for a token | ||
# have the same masking | ||
mask = np.zeros((cls.NUM_TIMESTEPS, len(BANDS_GROUPS_IDX))) | ||
for df_val, presto_val in cls.BAND_MAPPING.items(): | ||
values = np.array([float(row_d[df_val.format(t)]) for t in range(cls.NUM_TIMESTEPS)]) | ||
values = np.array([float(row_d[df_val.format(t)]) for t in timestep_positions]) | ||
# this occurs for the DEM values in one point in Fiji | ||
values = np.nan_to_num(values, nan=cls._NODATAVALUE) | ||
idx_valid = values != cls._NODATAVALUE | ||
values = np.nan_to_num(values, nan=NODATAVALUE) | ||
idx_valid = values != NODATAVALUE | ||
if presto_val in ["VV", "VH"]: | ||
# convert to dB | ||
idx_valid = idx_valid & (values > 0) | ||
|
@@ -97,8 +170,8 @@ def row_to_arrays( | |
eo_data[:, BANDS.index(presto_val)] = values * idx_valid | ||
for df_val, presto_val in cls.STATIC_BAND_MAPPING.items(): | ||
# this occurs for the DEM values in one point in Fiji | ||
values = np.nan_to_num(row_d[df_val], nan=cls._NODATAVALUE) | ||
idx_valid = values != cls._NODATAVALUE | ||
values = np.nan_to_num(row_d[df_val], nan=NODATAVALUE) | ||
idx_valid = values != NODATAVALUE | ||
eo_data[:, BANDS.index(presto_val)] = values * idx_valid | ||
mask[:, IDX_TO_BAND_GROUPS[presto_val]] += ~idx_valid | ||
|
||
|
@@ -129,7 +202,7 @@ def normalize_and_mask(cls, eo: np.ndarray): | |
keep_indices = [idx for idx, val in enumerate(BANDS) if val != "B9"] | ||
normed_eo = S1_S2_ERA5_SRTM.normalize(eo) | ||
# TODO: fix this. For now, we replicate the previous behaviour | ||
normed_eo = np.where(eo[:, keep_indices] != cls._NODATAVALUE, normed_eo, 0) | ||
normed_eo = np.where(eo[:, keep_indices] != NODATAVALUE, normed_eo, 0) | ||
return normed_eo | ||
|
||
@staticmethod | ||
|
@@ -258,6 +331,7 @@ def __init__( | |
years_to_remove: Optional[List[int]] = None, | ||
target_function: Optional[Callable[[Dict], int]] = None, | ||
balance: bool = False, | ||
augment: bool = False, | ||
mask_ratio: float = 0.0, | ||
): | ||
dataframe = dataframe.loc[~dataframe.LANDCOVER_LABEL.isin(self.FILTER_LABELS)] | ||
|
@@ -274,6 +348,9 @@ def __init__( | |
dataframe = dataframe[(~dataframe.end_date.dt.year.isin(years_to_remove))] | ||
self.target_function = target_function if target_function is not None else self.target_crop | ||
self._class_weights: Optional[np.ndarray] = None | ||
self.augment = augment | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. should we add logging somewhere that shows that augmentation is enabled when initializing a dataset like that? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. good point. addressed that here ae27e25 |
||
if augment: | ||
logger.info("Augmentation is enabled. The valid_date position will be shifted.") | ||
self.mask_ratio = mask_ratio | ||
self.mask_params = MaskParamsNoDw( | ||
( | ||
|
@@ -367,7 +444,9 @@ def __getitem__(self, idx): | |
# Get the sample | ||
df_index = self.indices[idx] | ||
row = self.df.iloc[df_index, :] | ||
eo, mask_per_token, latlon, _, target = self.row_to_arrays(row, self.target_function) | ||
eo, mask_per_token, latlon, _, target = self.row_to_arrays( | ||
row, self.target_function, self.augment | ||
) | ||
if self.mask_ratio > 0: | ||
mask_per_token, eo, _, _ = self.mask_params.mask_data(eo, mask_per_token) | ||
mask_per_variable = np.repeat(mask_per_token, BAND_EXPANSION, axis=1) | ||
|
@@ -382,7 +461,24 @@ def __getitem__(self, idx): | |
|
||
|
||
class WorldCerealInferenceDataset(Dataset): | ||
_NODATAVALUE = 65535 | ||
# _NODATAVALUE = 65535 | ||
Y = "worldcereal_cropland" | ||
BAND_MAPPING = { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. these changes will all conflict with croptype branch I think? |
||
"B02": "B2", | ||
"B03": "B3", | ||
"B04": "B4", | ||
"B05": "B5", | ||
"B06": "B6", | ||
"B07": "B7", | ||
"B08": "B8", | ||
# B8A is missing | ||
"B11": "B11", | ||
"B12": "B12", | ||
"VH": "VH", | ||
"VV": "VV", | ||
"precipitation-flux": "total_precipitation", | ||
"temperature-mean": "temperature_2m", | ||
} | ||
Y = "WORLDCEREAL_TEMPORARYCROPS_2021" | ||
|
||
def __init__(self, path_to_files: Path = data_dir / "inference_areas"): | ||
|
@@ -408,7 +504,7 @@ def _extract_eo_data(cls, inarr: xr.DataArray) -> Tuple[np.ndarray, np.ndarray]: | |
|
||
# Handle NaN values in Presto compatible way | ||
inarr = inarr.astype(np.float32) | ||
inarr = inarr.fillna(65535) | ||
inarr = inarr.fillna(NODATAVALUE) | ||
|
||
eo_data = np.zeros((num_pixels, num_timesteps, len(BANDS))) | ||
mask = np.zeros((num_pixels, num_timesteps, len(BANDS_GROUPS_IDX))) | ||
|
@@ -420,7 +516,7 @@ def _extract_eo_data(cls, inarr: xr.DataArray) -> Tuple[np.ndarray, np.ndarray]: | |
0, | ||
1, | ||
) | ||
idx_valid = values != cls._NODATAVALUE | ||
idx_valid = values != NODATAVALUE | ||
values = cls._preprocess_band_values(values, presto_band) | ||
eo_data[:, :, BANDS.index(presto_band)] = values * idx_valid | ||
mask[:, :, IDX_TO_BAND_GROUPS[presto_band]] += ~idx_valid | ||
|
@@ -552,7 +648,7 @@ def nc_to_arrays( | |
months = cls._extract_months(inarr) | ||
|
||
if cls.Y not in ds: | ||
target = np.ones_like(months) * cls._NODATAVALUE | ||
target = np.ones_like(months) * NODATAVALUE | ||
else: | ||
target = rearrange(inarr.sel(bands=cls.Y).values, "t x y -> (x y) t") | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@cbutsko in
croptype
branch this file is removed. What does that mean for the changes here?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I merged the functionality of
paper_eval
intotrain.py
. I think in this branch it doesn't really matter. I just wanted to align this branch with main and then merge into croptype, resolving these conflict there. also, now after discussing ss training with Giorgia, it seems more feasible to create two separate files with more clear functions: something liketrain_self_supervised.py
andtrain_finetuned.py
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That sounds like a good plan!