Using new parquet in train #104

Draft · wants to merge 27 commits into base: main

Changes from 5 commits

Commits (27)
523014f
removed process_parquet function to utils
Sep 10, 2024
5c74855
added augment parameter
Sep 10, 2024
cb197a3
added function for timeseries subsetting, so that it is centered arou…
Sep 10, 2024
9e9ac9d
added augment parameter; replaced default link to new parquet file; a…
Sep 10, 2024
b3e0284
major rework of process_parquet function; minimal viable functionality
Sep 10, 2024
1d37536
moved MIN_EDGE_BUFFER parameter from utils to dataset.py
Sep 11, 2024
ae27e25
added logger message about enabled augmentation
Sep 11, 2024
8d7e4c1
removed augment=False parameter from evaluate function, since it is a…
Sep 11, 2024
b2f1aa3
rephrased checking if valid_date is too close to the edge without mes…
Sep 11, 2024
ff25509
bugs and typos fixes
Sep 11, 2024
e338c09
moved NODATA and MIN_EDGE parameters to dataops.py to avoid circular …
Sep 11, 2024
c9ffa01
updated test dataset to use new long parquet format
Sep 11, 2024
f6e1a9e
updated tests
Sep 11, 2024
407cec9
created separate test file for process_parquet function
Sep 11, 2024
ad05b3d
an attempt to make time_token shift more general than just for months
Sep 11, 2024
2945e9d
merging changes from main
Sep 20, 2024
ef06f94
black fix
Sep 20, 2024
30ab19c
adding test long parquet file
Sep 23, 2024
96dbc0d
fixed test file path
Sep 23, 2024
55dbbbe
isort fix
Sep 23, 2024
a7eedd8
fixed test and commented lines that will not be needed after merge
Sep 23, 2024
6f7646f
Formatting
kvantricht Sep 23, 2024
203c4ac
Formatting
kvantricht Sep 23, 2024
465d65a
making GT values binary crop/nocrop
Sep 23, 2024
54cc2be
Test with 1 epoch finetuning
kvantricht Sep 23, 2024
e44445a
Merge branch 'using-new-parquet-in-train' of github.com:WorldCereal/p…
kvantricht Sep 23, 2024
1a957a2
Bump einops version
kvantricht Sep 23, 2024
38 changes: 33 additions & 5 deletions paper_eval.py
Contributor:
@cbutsko in the croptype branch this file is removed. What does that mean for the changes here?

Author:
I merged the functionality of paper_eval into train.py. I think in this branch it doesn't really matter. I just wanted to align this branch with main and then merge into croptype, resolving these conflicts there. Also, after discussing self-supervised training with Giorgia, it now seems more feasible to create two separate files with clearer functions: something like train_self_supervised.py and train_finetuned.py.

Contributor:
That sounds like a good plan!

@@ -1,28 +1,32 @@
# presto_pretrain_finetune, but in a notebook
import argparse
import gc
import json
import logging
from glob import glob
from pathlib import Path
from typing import Optional, cast

import pandas as pd
import torch
import xarray as xr

from presto.dataset import WorldCerealBase
from presto.eval import WorldCerealEval
from presto.presto import Presto
from presto.utils import (
DEFAULT_SEED,
NODATAVALUE,
config_dir,
data_dir,
default_model_path,
device,
initialize_logging,
plot_spatial,
process_parquet,
seed_everything,
timestamp_dirname,
)
from tqdm.auto import tqdm

logger = logging.getLogger("__main__")

@@ -41,12 +45,20 @@
argparser.add_argument("--num_workers", type=int, default=4)
argparser.add_argument("--wandb", dest="wandb", action="store_true")
argparser.add_argument("--wandb_org", type=str, default="nasa-harvest")
argparser.add_argument("--parquet_file", type=str, default="rawts-monthly_calval.parquet")
# argparser.add_argument("--parquet_file", type=str, default="rawts-monthly_calval.parquet")
argparser.add_argument(
"--parquet_file",
type=str,
default="/vitodata/worldcereal/features/preprocessedinputs-monthly-nointerp/\
worldcereal_training_data.parquet",
)
argparser.add_argument("--val_samples_file", type=str, default="cropland_test_split_samples.csv")
argparser.add_argument("--train_only_samples_file", type=str, default="train_only_samples.csv")
argparser.add_argument("--warm_start", dest="warm_start", action="store_true")
argparser.add_argument("--augment", dest="augment", action="store_true")
argparser.set_defaults(wandb=False)
argparser.set_defaults(warm_start=True)
argparser.set_defaults(augment=False)
args = argparser.parse_args().__dict__

model_name = args["model_name"]
@@ -79,6 +91,7 @@
parquet_file: str = args["parquet_file"]
val_samples_file: str = args["val_samples_file"]
train_only_samples_file: str = args["train_only_samples_file"]
augment: bool = args["augment"]

dekadal = False
if "10d" in parquet_file:
@@ -89,7 +102,21 @@

logger.info("Setting up dataloaders")

df = pd.read_parquet(data_dir / parquet_file)

logger.info("Reading dataset")
files = sorted(glob(f"{parquet_file}/**/*.parquet"))[:10]
df_list = []
for f in tqdm(files):
_data = pd.read_parquet(f, engine="fastparquet")
_data_pivot = process_parquet(_data)
_data_pivot.reset_index(inplace=True)
df_list.append(_data_pivot)
del _data, _data_pivot
gc.collect()
df = pd.concat(df_list)
df = df.fillna(NODATAVALUE)
del df_list
gc.collect()

logger.info("Setting up model")
if warm_start:
@@ -104,13 +131,14 @@
best_model_path = None
model.to(device)

model_modes = ["Random Forest", "Regression", "CatBoostClassifier"]
# model_modes = ["Random Forest", "Regression", "CatBoostClassifier"]
model_modes = ["CatBoostClassifier"]

# 1. Using the provided split
val_samples_df = pd.read_csv(data_dir / val_samples_file)
train_df, test_df = WorldCerealBase.split_df(df, val_sample_ids=val_samples_df.sample_id.tolist())
full_eval = WorldCerealEval(
train_df, test_df, spatial_inference_savedir=model_logging_dir, dekadal=dekadal
train_df, test_df, spatial_inference_savedir=model_logging_dir, dekadal=dekadal, augment=False
)
results, finetuned_model = full_eval.finetuning_results(model, sklearn_model_modes=model_modes)
logger.info(json.dumps(results, indent=2))
68 changes: 62 additions & 6 deletions presto/dataset.py
@@ -23,7 +23,7 @@
DynamicWorld2020_2021,
)
from .masking import BAND_EXPANSION, MaskedExample, MaskParamsNoDw
from .utils import DEFAULT_SEED, data_dir, load_world_df
from .utils import DEFAULT_SEED, MIN_EDGE_BUFFER, data_dir, load_world_df

logger = logging.getLogger("__main__")

@@ -65,23 +65,73 @@ def target_crop(row_d: Dict) -> int:
# by default, we predict crop vs non crop
return int(row_d["LANDCOVER_LABEL"] == 11)

@classmethod
def get_timestep_positions(cls, row_d: Dict, augment: bool = False) -> List[int]:
available_timesteps = int(row_d["available_timesteps"])
valid_position = int(row_d["valid_position"])
Collaborator:
Could you add a comment describing what valid_position represents?

Collaborator (@gabrieltseng, Sep 27, 2024):
I see it's created in utils.py but it's still not obvious to me what it represents - could you add a comment there instead?


# force moving the center point if it is too close to the edges
if (valid_position < cls.NUM_TIMESTEPS // 2) or (
valid_position > (available_timesteps - cls.NUM_TIMESTEPS // 2)
):
augment = True
Contributor:
In this case we silently fall into augmentation, which might not be what we want. Can't we put this logic inside if not augment and, in that case, not choose valid_position as the center_point but rather the point that keeps valid_position as close as possible to the center? Then it's always deterministic and we don't have to force going through the augmentation part.

Author:
good point. tried to rewrite it here b2f1aa3
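For illustration, a rough numeric check of the reworked clamping logic shown in the diff below; the constants here are hypothetical, not values taken from the repo:

```python
# hypothetical constants: a 12-step window, keeping the valid timestep
# at least 2 steps away from the window edges
NUM_TIMESTEPS, MIN_EDGE_BUFFER = 12, 2
available_timesteps, valid_position = 20, 3  # valid_position too close to the left edge

min_center = max(NUM_TIMESTEPS // 2,
                 valid_position + MIN_EDGE_BUFFER - NUM_TIMESTEPS // 2)  # -> 6
max_center = min(available_timesteps - NUM_TIMESTEPS // 2,
                 valid_position - MIN_EDGE_BUFFER + NUM_TIMESTEPS // 2)  # -> 7

for center_point in range(min_center, max_center + 1):
    last = min(available_timesteps, center_point + NUM_TIMESTEPS // 2)
    first = max(0, last - NUM_TIMESTEPS)
    positions = list(range(first, last))
    # every allowed center yields a full-length window that contains
    # valid_position with at least MIN_EDGE_BUFFER steps of margin
    assert len(positions) == NUM_TIMESTEPS
    assert valid_position in positions
    assert min(valid_position - first, last - 1 - valid_position) >= MIN_EDGE_BUFFER
```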


if not augment:
Collaborator:
Could you also make a note that augment here means temporal jittering? Am I right in understanding MIN_EDGE_BUFFER is just used to determine how much temporal jittering is allowed? In this case maybe it makes sense as a class attribute (or even as a default argument to this function)

# Center the timesteps around the valid position
center_point = valid_position
else:
# Shift the center point but make sure the resulting range
# well includes the valid position

min_center_point = max(
cls.NUM_TIMESTEPS // 2, valid_position + MIN_EDGE_BUFFER - cls.NUM_TIMESTEPS // 2
)
max_center_point = min(
available_timesteps - cls.NUM_TIMESTEPS // 2,
valid_position - MIN_EDGE_BUFFER + cls.NUM_TIMESTEPS // 2,
)

center_point = np.random.randint(
min_center_point, max_center_point + 1
) # max_center_point included

last_timestep = min(available_timesteps, center_point + cls.NUM_TIMESTEPS // 2)
first_timestep = max(0, last_timestep - cls.NUM_TIMESTEPS)
timestep_positions = list(range(first_timestep, last_timestep))

if len(timestep_positions) != cls.NUM_TIMESTEPS:
raise ValueError(
f"Acquired timestep positions do not have correct length: \
required {cls.NUM_TIMESTEPS}, got {len(timestep_positions)}"
)
assert (
valid_position in timestep_positions
), f"Valid position {valid_position} not in timestep positions {timestep_positions}"
return timestep_positions

@classmethod
def row_to_arrays(
cls, row: pd.Series, target_function: Callable[[Dict], int]
cls, row: pd.Series, target_function: Callable[[Dict], int], augment: bool = False
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, float, int]:
# https://stackoverflow.com/questions/45783891/is-there-a-way-to-speed-up-the-pandas-getitem-getitem-axis-and-get-label
# This is faster than indexing the series every time!
row_d = pd.Series.to_dict(row)

latlon = np.array([row_d["lat"], row_d["lon"]], dtype=np.float32)
month = datetime.strptime(row_d["start_date"], "%Y-%m-%d").month - 1

timestep_positions = cls.get_timestep_positions(row_d, augment=augment)
# make sure that month for encoding gets shifted according to
# the selected timestep positions
month = (
pd.to_datetime(row_d["start_date"]) + pd.DateOffset(months=timestep_positions[0])
Contributor:
You could also just add timestep_positions[0] and avoid the overhead of the pd.DateOffset method. I don't think we're resilient in any case to situations where we're working with dekadal data. This is a potential risk if someone (including ourselves) goes down the dekadal track. Can we make it more universal here? Or, if needed for now, do a check and raise something, so nobody blindly takes this while it won't work as expected.

Author:
good point. tried to account for that here ad05b3d
we will also need to think about renaming this month thing (since it can be something else too), and also making sure that other relative timestep positions (valid_position and timestep_ind) are computed not as months, but in a more generic fashion

Contributor:
You mean this has to be tackled in process_parquet?

Contributor:
I am not sure I understand what initial_start_date_position actually means and why it cannot just be the month inferred from start_date.

Author (@cbutsko, Sep 11, 2024):
well, maybe I'm overthinking it...

here's an example:

  1. we are working on a monthly basis, with start_date in October (hence, month = 10), and we want to shift it 4 timesteps forward. we can just compute (10 + 4) and add the modulo part (% cls.NUM_TIMESTEPS) so we don't get a bad month value.

  2. we are working on a dekadal basis, with the same start date. I assume that our timestep indices are not in month chunks, but in 10-day intervals, so we have observations for ts0, ts1, ..., ts45 (for example). our NUM_TIMESTEPS variable should be set to 36 instead of 12 (like Giorgia did). the valid_date should also translate into a 10-day chunk instead of a month, so that we can select 36 timesteps around the valid_position. now our timestep_positions are returned in dekadal steps, so adding 4 dekadal steps to a month value doesn't make sense. we need to add apples to apples to get the date that accounts for the shift, and only then take its month and pass it to Presto. I realize now that this particular step is missing in my implementation.

am I missing something?

Contributor:
Following most of it. But still, start_date is a real date, from which we can infer what would normally be the start month being fed to Presto. Why do we need to translate it to a position? I agree that what we have to add to it due to the shift should account for the time "resolution". valid_position should be the translation of valid_date (which is irrespective of time resolution) to the position in the timesteps (which depends on the time resolution). I might be overthinking it just as much.

Author:
okay, maybe something like this can work:

step_converted_to_month = np.ceil(timestep_positions[0] * (365 // NUM_TIMESTEPS) / 30)
month = (pd.to_datetime(start_date).month + step_converted_to_month) % 12 - 1

it can be a little imprecise in some cases, not more than 1 month error.
but it's just a one-liner, so probably we can sacrifice this little bit of precision
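A minimal, self-contained sketch of that idea (the helper name and the exact rounding are assumptions for illustration, not what this PR implements):

```python
import numpy as np
import pandas as pd

def shifted_start_month(start_date: str, first_timestep: int, num_timesteps: int) -> int:
    """Zero-based month of the (shifted) window start.

    Assumes the timestep indices span roughly one year: 12 steps for
    monthly data, 36 for dekadal data, i.e. ~365 / num_timesteps days per step.
    """
    days_per_step = 365 / num_timesteps
    shift_in_months = int(np.floor(first_timestep * days_per_step / 30))
    month0 = pd.to_datetime(start_date).month - 1  # zero-based, as Presto expects
    return (month0 + shift_in_months) % 12

# monthly data, start_date in October, window starts 4 steps later -> February (index 1)
assert shifted_start_month("2020-10-01", 4, 12) == 1
# dekadal data, same start, window starts 9 dekads (~3 months) later -> January (index 0)
assert shifted_start_month("2020-10-01", 9, 36) == 0
```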

Contributor:
yeah let's make a note that Giorgia can test this out properly when doing dekadal runs.

).month - 1

eo_data = np.zeros((cls.NUM_TIMESTEPS, len(BANDS)))
# an assumption we make here is that all timesteps for a token
# have the same masking
mask = np.zeros((cls.NUM_TIMESTEPS, len(BANDS_GROUPS_IDX)))
for df_val, presto_val in cls.BAND_MAPPING.items():
values = np.array([float(row_d[df_val.format(t)]) for t in range(cls.NUM_TIMESTEPS)])
values = np.array([float(row_d[df_val.format(t)]) for t in timestep_positions])
# this occurs for the DEM values in one point in Fiji
values = np.nan_to_num(values, nan=cls._NODATAVALUE)
idx_valid = values != cls._NODATAVALUE
@@ -260,6 +310,7 @@ def __init__(
years_to_remove: Optional[List[int]] = None,
target_function: Optional[Callable[[Dict], int]] = None,
balance: bool = False,
augment: bool = False,
):
dataframe = dataframe.loc[~dataframe.LANDCOVER_LABEL.isin(self.FILTER_LABELS)]

@@ -275,6 +326,7 @@
dataframe = dataframe[(~dataframe.end_date.dt.year.isin(years_to_remove))]
self.target_function = target_function if target_function is not None else self.target_crop
self._class_weights: Optional[np.ndarray] = None
self.augment = augment
Contributor:
should we add logging somewhere that shows that augmentation is enabled when initializing a dataset like that?

Author:
good point. addressed that here ae27e25
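A sketch of what such a log line could look like in the dataset constructor (wording and placement are assumptions; the actual change landed in commit ae27e25):

```python
# inside the dataset's __init__, right after the flag is stored (sketch)
self.augment = augment
if augment:
    logger.info("Augmentation (temporal jittering) is enabled for this dataset")
```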


super().__init__(dataframe)
if balance:
@@ -307,7 +359,9 @@ def __getitem__(self, idx):
# Get the sample
df_index = self.indices[idx]
row = self.df.iloc[df_index, :]
eo, mask_per_token, latlon, month, target = self.row_to_arrays(row, self.target_function)
eo, mask_per_token, latlon, month, target = self.row_to_arrays(
row, self.target_function, self.augment
)
mask_per_variable = np.repeat(mask_per_token, BAND_EXPANSION, axis=1)
return (
self.normalize_and_mask(eo),
@@ -354,7 +408,9 @@ def __getitem__(self, idx):
# Get the sample
df_index = self.indices[idx]
row = self.df.iloc[df_index, :]
eo, mask_per_token, latlon, _, target = self.row_to_arrays(row, self.target_function)
eo, mask_per_token, latlon, _, target = self.row_to_arrays(
row, self.target_function, self.augment
)
mask_per_variable = np.repeat(mask_per_token, BAND_EXPANSION, axis=1)
return (
self.normalize_and_mask(eo),
9 changes: 8 additions & 1 deletion presto/eval.py
@@ -65,6 +65,7 @@ def __init__(
name: Optional[str] = None,
val_size: float = 0.2,
dekadal: bool = False,
augment: bool = False,
):
self.seed = seed

@@ -90,6 +91,8 @@
self.dekadal = dekadal
self.ds_class = WorldCerealLabelled10DDataset if dekadal else WorldCerealLabelledDataset

self.augment = augment

def _construct_finetuning_model(self, pretrained_model: Presto) -> PrestoFineTuningModel:
model: PrestoFineTuningModel = cast(Callable, pretrained_model.construct_finetuning_model)(
num_outputs=self.num_outputs
@@ -278,7 +281,7 @@ def evaluate(
pretrained_model: Optional[PrestoFineTuningModel] = None,
) -> Dict:

test_ds = self.ds_class(self.test_df, target_function=self.target_function)
test_ds = self.ds_class(self.test_df, target_function=self.target_function, augment=False)
dl = DataLoader(
test_ds,
batch_size=512,
@@ -387,6 +390,7 @@ def finetune(self, pretrained_model) -> PrestoFineTuningModel:
years_to_remove=self.years_to_remove,
target_function=self.target_function,
balance=True,
augment=self.augment,
)

# should the val set be balanced too?
@@ -395,6 +399,7 @@
countries_to_remove=self.countries_to_remove,
years_to_remove=self.years_to_remove,
target_function=self.target_function,
augment=False, # don't augment the validation set
)

loss_fn = nn.BCEWithLogitsLoss()
@@ -511,6 +516,7 @@ def finetuning_results_sklearn(
countries_to_remove=self.countries_to_remove,
years_to_remove=self.years_to_remove,
target_function=self.target_function,
augment=self.augment,
),
batch_size=2048,
shuffle=False,
@@ -522,6 +528,7 @@
countries_to_remove=self.countries_to_remove,
years_to_remove=self.years_to_remove,
target_function=self.target_function,
augment=False, # don't augment the validation set
),
batch_size=2048,
shuffle=False,