Using new parquet in train #104

Draft
wants to merge 27 commits into main from using-new-parquet-in-train
Commits (27)
523014f
removed process_parquet function to utils
Sep 10, 2024
5c74855
added augment parameter
Sep 10, 2024
cb197a3
added function for timeseries subsetting, so that it is centered arou…
Sep 10, 2024
9e9ac9d
added augment parameter; replaced default link to new parquet file; a…
Sep 10, 2024
b3e0284
major rework of process_parquet function; minimal viable functionality
Sep 10, 2024
1d37536
moved MIN_EDGE_BUFFER parameter from utils to dataset.py
Sep 11, 2024
ae27e25
added logger message about enabled augmentation
Sep 11, 2024
8d7e4c1
removed augment=False parameter from evaluate function, since it is a…
Sep 11, 2024
b2f1aa3
rephrased checking if valid_date is too close to the edge without mes…
Sep 11, 2024
ff25509
bugs and typos fixes
Sep 11, 2024
e338c09
moved NODATA and MIN_EDGE parameters to dataops.py to avoid circular …
Sep 11, 2024
c9ffa01
updated test dataset to use new long parquet format
Sep 11, 2024
f6e1a9e
updated tests
Sep 11, 2024
407cec9
created separate test file for process_parquet function
Sep 11, 2024
ad05b3d
an attempt to make time_token shift more general than just for months
Sep 11, 2024
2945e9d
merging changes from main
Sep 20, 2024
ef06f94
black fix
Sep 20, 2024
30ab19c
adding test long parquet file
Sep 23, 2024
96dbc0d
fixed test file path
Sep 23, 2024
55dbbbe
isort fix
Sep 23, 2024
a7eedd8
fixed test and commented lines that will not be needed after merge
Sep 23, 2024
6f7646f
Formatting
kvantricht Sep 23, 2024
203c4ac
Formatting
kvantricht Sep 23, 2024
465d65a
making GT values binary crop/nocrop
Sep 23, 2024
54cc2be
Test with 1 epoch finetuning
kvantricht Sep 23, 2024
e44445a
Merge branch 'using-new-parquet-in-train' of github.com:WorldCereal/p…
kvantricht Sep 23, 2024
1a957a2
Bump einops version
kvantricht Sep 23, 2024
1,557 changes: 1,557 additions & 0 deletions catboost_info/catboost_training.json

Large diffs are not rendered by default.

Binary file added catboost_info/learn/events.out.tfevents
Binary file not shown.
1,554 changes: 1,554 additions & 0 deletions catboost_info/learn_error.tsv

Large diffs are not rendered by default.

Binary file added catboost_info/test/events.out.tfevents
Binary file not shown.
1,554 changes: 1,554 additions & 0 deletions catboost_info/test_error.tsv

Large diffs are not rendered by default.

1,554 changes: 1,554 additions & 0 deletions catboost_info/time_left.tsv

Large diffs are not rendered by default.

40 changes: 34 additions & 6 deletions paper_eval.py
Contributor:
@cbutsko in the croptype branch this file is removed. What does that mean for the changes here?

Author:
I merged the functionality of paper_eval into train.py. I think in this branch it doesn't really matter; I just wanted to align this branch with main and then merge into croptype, resolving these conflicts there. Also, after discussing self-supervised training with Giorgia, it now seems more feasible to create two separate files with clearer functions: something like train_self_supervised.py and train_finetuned.py.

Contributor:
That sounds like a good plan!

@@ -1,25 +1,30 @@
# presto_pretrain_finetune, but in a notebook
import argparse
import gc
import json
import logging
from glob import glob
from pathlib import Path
from typing import Optional, cast

import pandas as pd
import torch
import xarray as xr
from tqdm.auto import tqdm

from presto.dataset import WorldCerealBase
from presto.eval import WorldCerealEval
from presto.presto import Presto
from presto.utils import (
DEFAULT_SEED,
NODATAVALUE,
config_dir,
data_dir,
default_model_path,
device,
initialize_logging,
plot_spatial,
process_parquet,
seed_everything,
timestamp_dirname,
)
@@ -41,12 +46,20 @@
argparser.add_argument("--num_workers", type=int, default=4)
argparser.add_argument("--wandb", dest="wandb", action="store_true")
argparser.add_argument("--wandb_org", type=str, default="nasa-harvest")
argparser.add_argument("--parquet_file", type=str, default="rawts-monthly_calval.parquet")
# argparser.add_argument("--parquet_file", type=str, default="rawts-monthly_calval.parquet")
argparser.add_argument(
"--parquet_file",
type=str,
default="/vitodata/worldcereal/features/preprocessedinputs-monthly-nointerp/\
worldcereal_training_data.parquet",
)
argparser.add_argument("--val_samples_file", type=str, default="cropland_test_split_samples.csv")
argparser.add_argument("--train_only_samples_file", type=str, default="train_only_samples.csv")
argparser.add_argument("--warm_start", dest="warm_start", action="store_true")
argparser.add_argument("--augment", dest="augment", action="store_true")
argparser.set_defaults(wandb=False)
argparser.set_defaults(warm_start=True)
argparser.set_defaults(augment=False)
args = argparser.parse_args().__dict__

model_name = args["model_name"]
@@ -79,6 +92,7 @@
parquet_file: str = args["parquet_file"]
val_samples_file: str = args["val_samples_file"]
train_only_samples_file: str = args["train_only_samples_file"]
augment: bool = args["augment"]

dekadal = False
if "10d" in parquet_file:
@@ -87,9 +101,18 @@
path_to_config = config_dir / "default.json"
model_kwargs = json.load(Path(path_to_config).open("r"))

logger.info("Setting up dataloaders")

df = pd.read_parquet(data_dir / parquet_file)
logger.info("Reading dataset")
files = sorted(glob(f"{parquet_file}/**/*.parquet"))[:10]
df_list = []
for f in tqdm(files):
_data = pd.read_parquet(f, engine="fastparquet")
_data_pivot = process_parquet(_data)
_data_pivot.reset_index(inplace=True)
df_list.append(_data_pivot)
df = pd.concat(df_list)
df = df.fillna(NODATAVALUE)
del df_list
gc.collect()

logger.info("Setting up model")
if warm_start:
@@ -104,13 +127,18 @@
best_model_path = None
model.to(device)

model_modes = ["Random Forest", "Regression", "CatBoostClassifier"]
# model_modes = ["Random Forest", "Regression", "CatBoostClassifier"]
model_modes = ["CatBoostClassifier"]

# 1. Using the provided split
val_samples_df = pd.read_csv(data_dir / val_samples_file)
train_df, test_df = WorldCerealBase.split_df(df, val_sample_ids=val_samples_df.sample_id.tolist())
full_eval = WorldCerealEval(
train_df, test_df, spatial_inference_savedir=model_logging_dir, dekadal=dekadal
train_df,
test_df,
spatial_inference_savedir=model_logging_dir,
dekadal=dekadal,
augment=augment,
)
results, finetuned_model = full_eval.finetuning_results(model, sklearn_model_modes=model_modes)
logger.info(json.dumps(results, indent=2))
2 changes: 2 additions & 0 deletions presto/dataops.py
@@ -77,6 +77,8 @@
NUM_TIMESTEPS = 12
NUM_ORG_BANDS = len(BANDS)
TIMESTEPS_IDX = list(range(NUM_TIMESTEPS))
NODATAVALUE = 65535
MIN_EDGE_BUFFER = 2 # Min amount of timesteps to include before/after the valid position

NORMED_BANDS = [x for x in BANDS if x != "B9"]
NUM_BANDS = len(NORMED_BANDS)
124 changes: 110 additions & 14 deletions presto/dataset.py
@@ -16,6 +16,8 @@
from .dataops import (
BANDS,
BANDS_GROUPS_IDX,
MIN_EDGE_BUFFER,
NODATAVALUE,
NORMED_BANDS,
S1_S2_ERA5_SRTM,
DynamicWorld2020_2021,
@@ -32,7 +34,7 @@


class WorldCerealBase(Dataset):
_NODATAVALUE = 65535
# _NODATAVALUE = 65535
Contributor:
@cbutsko why is this commented out? And are the changes below, like the addition of get_timestep_positions, not part of the other branch?

NUM_TIMESTEPS = 12
BAND_MAPPING = {
"OPTICAL-B02-ts{}-10m": "B2",
@@ -63,26 +65,97 @@ def target_crop(row_d: Dict) -> int:
# by default, we predict crop vs non crop
return int(row_d["LANDCOVER_LABEL"] == 11)

@classmethod
def get_timestep_positions(cls, row_d: Dict, augment: bool = False) -> List[int]:
available_timesteps = int(row_d["available_timesteps"])
valid_position = int(row_d["valid_position"])
Collaborator:
Could you add a comment describing what valid_position represents?

Collaborator (@gabrieltseng, Sep 27, 2024):
I see it's created in utils.py, but it's still not obvious to me what it represents. Could you add a comment there instead?


if not augment:
Collaborator:
Could you also make a note that augment here means temporal jittering? Am I right in understanding that MIN_EDGE_BUFFER is just used to determine how much temporal jittering is allowed? In that case, maybe it makes sense as a class attribute (or even as a default argument to this function).
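Reading the diff, valid_position indexes the sample's valid_date within its available timesteps, and MIN_EDGE_BUFFER constrains how far the randomly drawn window center may drift from it. A minimal standalone sketch of the augmented center bounds (a hypothetical helper, not part of the PR; assumes NUM_TIMESTEPS = 12 and MIN_EDGE_BUFFER = 2 as defined in dataops.py):

NUM_TIMESTEPS = 12
MIN_EDGE_BUFFER = 2

def jitter_center_bounds(valid_position: int, available_timesteps: int):
    # Allowed range for the window center when augment=True: the full
    # 12-step window must fit within the available timesteps, and the
    # valid position must stay away from the window edges.
    half = NUM_TIMESTEPS // 2
    lo = max(half, valid_position + MIN_EDGE_BUFFER - half)
    hi = min(available_timesteps - half, valid_position - MIN_EDGE_BUFFER + half)
    return lo, hi

# e.g. 24 available timesteps, valid_date at index 10:
print(jitter_center_bounds(10, 24))  # (6, 14); the center is drawn uniformly from this range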

# check if the valid position is too close to the start_date and force shifting it
if valid_position < cls.NUM_TIMESTEPS // 2:
center_point = cls.NUM_TIMESTEPS // 2
# or too close to the end_date
elif valid_position > (available_timesteps - cls.NUM_TIMESTEPS // 2):
center_point = available_timesteps - cls.NUM_TIMESTEPS // 2
else:
# Center the timesteps around the valid position
center_point = valid_position
else:
# Shift the center point but make sure the resulting range
# well includes the valid position

min_center_point = max(
cls.NUM_TIMESTEPS // 2,
valid_position + MIN_EDGE_BUFFER - cls.NUM_TIMESTEPS // 2,
)
max_center_point = min(
available_timesteps - cls.NUM_TIMESTEPS // 2,
valid_position - MIN_EDGE_BUFFER + cls.NUM_TIMESTEPS // 2,
)

center_point = np.random.randint(
min_center_point, max_center_point + 1
) # max_center_point included

last_timestep = min(available_timesteps, center_point + cls.NUM_TIMESTEPS // 2)
first_timestep = max(0, last_timestep - cls.NUM_TIMESTEPS)
timestep_positions = list(range(first_timestep, last_timestep))

if len(timestep_positions) != cls.NUM_TIMESTEPS:
raise ValueError(
f"Acquired timestep positions do not have correct length: \
required {cls.NUM_TIMESTEPS}, got {len(timestep_positions)}"
)
assert (
valid_position in timestep_positions
), f"Valid position {valid_position} not in timestep positions {timestep_positions}"
return timestep_positions

@classmethod
def row_to_arrays(
cls, row: pd.Series, target_function: Callable[[Dict], int]
cls, row: pd.Series, target_function: Callable[[Dict], int], augment: bool = False
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, float, int]:
# https://stackoverflow.com/questions/45783891/is-there-a-way-to-speed-up-the-pandas-getitem-getitem-axis-and-get-label
# This is faster than indexing the series every time!
row_d = pd.Series.to_dict(row)

latlon = np.array([row_d["lat"], row_d["lon"]], dtype=np.float32)
month = datetime.strptime(row_d["start_date"], "%Y-%m-%d").month - 1

timestep_positions = cls.get_timestep_positions(row_d, augment=augment)

if cls.NUM_TIMESTEPS == 12:
initial_start_date_position = pd.to_datetime(row_d["start_date"]).month
elif cls.NUM_TIMESTEPS > 12:
# get the correct index of the start_date based on NUM_TIMESTEPS
# e.g. if NUM_TIMESTEPS is 36 (dekadal setup), we should take the correct
# 10-day interval that the start_date falls into
# TODO: 1) this needs to go into a separate function
# 2) definition of valid_position and timestep_ind
# should also be changed accordingly
year = pd.to_datetime(row_d["start_date"]).year
year_dates = pd.date_range(start=f"{year}-01-01", end=f"{year}-12-31")
bins = pd.cut(year_dates, bins=cls.NUM_TIMESTEPS, labels=False)
initial_start_date_position = bins[
np.where(year_dates == pd.to_datetime(row_d["start_date"]))[0][0]
]
else:
raise ValueError(
f"NUM_TIMESTEPS must be at least 12. Currently it is {cls.NUM_TIMESTEPS}"
)

# make sure that month for encoding gets shifted according to
# the selected timestep positions. Also ensure circular indexing
month = (initial_start_date_position - 1 + timestep_positions[0]) % cls.NUM_TIMESTEPS
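# Worked example (hypothetical values, not from the PR): a start_date of
# 2021-03-01 gives initial_start_date_position = 3; if the sampled window
# starts at timestep 4, then month = (3 - 1 + 4) % 12 = 6, i.e. July in
# 0-indexed months, which is March shifted forward by four monthly steps.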

eo_data = np.zeros((cls.NUM_TIMESTEPS, len(BANDS)))
# an assumption we make here is that all timesteps for a token
# have the same masking
mask = np.zeros((cls.NUM_TIMESTEPS, len(BANDS_GROUPS_IDX)))
for df_val, presto_val in cls.BAND_MAPPING.items():
values = np.array([float(row_d[df_val.format(t)]) for t in range(cls.NUM_TIMESTEPS)])
values = np.array([float(row_d[df_val.format(t)]) for t in timestep_positions])
# this occurs for the DEM values in one point in Fiji
values = np.nan_to_num(values, nan=cls._NODATAVALUE)
idx_valid = values != cls._NODATAVALUE
values = np.nan_to_num(values, nan=NODATAVALUE)
idx_valid = values != NODATAVALUE
if presto_val in ["VV", "VH"]:
# convert to dB
idx_valid = idx_valid & (values > 0)
@@ -97,8 +170,8 @@
eo_data[:, BANDS.index(presto_val)] = values * idx_valid
for df_val, presto_val in cls.STATIC_BAND_MAPPING.items():
# this occurs for the DEM values in one point in Fiji
values = np.nan_to_num(row_d[df_val], nan=cls._NODATAVALUE)
idx_valid = values != cls._NODATAVALUE
values = np.nan_to_num(row_d[df_val], nan=NODATAVALUE)
idx_valid = values != NODATAVALUE
eo_data[:, BANDS.index(presto_val)] = values * idx_valid
mask[:, IDX_TO_BAND_GROUPS[presto_val]] += ~idx_valid

@@ -129,7 +202,7 @@ def normalize_and_mask(cls, eo: np.ndarray):
keep_indices = [idx for idx, val in enumerate(BANDS) if val != "B9"]
normed_eo = S1_S2_ERA5_SRTM.normalize(eo)
# TODO: fix this. For now, we replicate the previous behaviour
normed_eo = np.where(eo[:, keep_indices] != cls._NODATAVALUE, normed_eo, 0)
normed_eo = np.where(eo[:, keep_indices] != NODATAVALUE, normed_eo, 0)
return normed_eo

@staticmethod
@@ -258,6 +331,7 @@ def __init__(
years_to_remove: Optional[List[int]] = None,
target_function: Optional[Callable[[Dict], int]] = None,
balance: bool = False,
augment: bool = False,
mask_ratio: float = 0.0,
):
dataframe = dataframe.loc[~dataframe.LANDCOVER_LABEL.isin(self.FILTER_LABELS)]
@@ -274,6 +348,9 @@
dataframe = dataframe[(~dataframe.end_date.dt.year.isin(years_to_remove))]
self.target_function = target_function if target_function is not None else self.target_crop
self._class_weights: Optional[np.ndarray] = None
self.augment = augment
Contributor:
Should we add logging somewhere that shows that augmentation is enabled when initializing a dataset like that?

Author:
Good point. Addressed that here: ae27e25

if augment:
logger.info("Augmentation is enabled. The valid_date position will be shifted.")
self.mask_ratio = mask_ratio
self.mask_params = MaskParamsNoDw(
(
Expand Down Expand Up @@ -367,7 +444,9 @@ def __getitem__(self, idx):
# Get the sample
df_index = self.indices[idx]
row = self.df.iloc[df_index, :]
eo, mask_per_token, latlon, _, target = self.row_to_arrays(row, self.target_function)
eo, mask_per_token, latlon, _, target = self.row_to_arrays(
row, self.target_function, self.augment
)
if self.mask_ratio > 0:
mask_per_token, eo, _, _ = self.mask_params.mask_data(eo, mask_per_token)
mask_per_variable = np.repeat(mask_per_token, BAND_EXPANSION, axis=1)
Expand All @@ -382,7 +461,24 @@ def __getitem__(self, idx):


class WorldCerealInferenceDataset(Dataset):
_NODATAVALUE = 65535
# _NODATAVALUE = 65535
Y = "worldcereal_cropland"
BAND_MAPPING = {
Contributor:
These changes will all conflict with the croptype branch, I think?

"B02": "B2",
"B03": "B3",
"B04": "B4",
"B05": "B5",
"B06": "B6",
"B07": "B7",
"B08": "B8",
# B8A is missing
"B11": "B11",
"B12": "B12",
"VH": "VH",
"VV": "VV",
"precipitation-flux": "total_precipitation",
"temperature-mean": "temperature_2m",
}
Y = "WORLDCEREAL_TEMPORARYCROPS_2021"

def __init__(self, path_to_files: Path = data_dir / "inference_areas"):
@@ -408,7 +504,7 @@ def _extract_eo_data(cls, inarr: xr.DataArray) -> Tuple[np.ndarray, np.ndarray]:

# Handle NaN values in Presto compatible way
inarr = inarr.astype(np.float32)
inarr = inarr.fillna(65535)
inarr = inarr.fillna(NODATAVALUE)

eo_data = np.zeros((num_pixels, num_timesteps, len(BANDS)))
mask = np.zeros((num_pixels, num_timesteps, len(BANDS_GROUPS_IDX)))
@@ -420,7 +516,7 @@
0,
1,
)
idx_valid = values != cls._NODATAVALUE
idx_valid = values != NODATAVALUE
values = cls._preprocess_band_values(values, presto_band)
eo_data[:, :, BANDS.index(presto_band)] = values * idx_valid
mask[:, :, IDX_TO_BAND_GROUPS[presto_band]] += ~idx_valid
@@ -552,7 +648,7 @@ def nc_to_arrays(
months = cls._extract_months(inarr)

if cls.Y not in ds:
target = np.ones_like(months) * cls._NODATAVALUE
target = np.ones_like(months) * NODATAVALUE
else:
target = rearrange(inarr.sel(bands=cls.Y).values, "t x y -> (x y) t")
