
Using new parquet in train #104 (Draft)

wants to merge 27 commits into main

Conversation

@cbutsko commented Sep 10, 2024

No description provided.

@kvantricht (Contributor) left a comment:

Good work! I've already added some comments. The process_parquet function is something that's hard for me to read just as code. Let's start by testing it thoroughly.

paper_eval.py (outdated, resolved)
paper_eval.py (outdated, resolved)
if (valid_position < cls.NUM_TIMESTEPS // 2) or (
valid_position > (available_timesteps - cls.NUM_TIMESTEPS // 2)
):
augment = True
Contributor:

In this case we silently fall into augmentation, which might not be what we want. Can't we put this logic inside `if not augment`, and in that case not choose valid_position as the center point but rather the point that keeps valid_position as close as possible to the center? Then it's always deterministic and we don't have to force going through the augmentation part.
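A minimal sketch of that deterministic alternative (a hypothetical helper, not the actual implementation; `NUM_TIMESTEPS = 12` is assumed here):

```python
from typing import List

NUM_TIMESTEPS = 12  # assumed window length, as on the dataset class


def deterministic_window(valid_position: int, available_timesteps: int) -> List[int]:
    """Pick a NUM_TIMESTEPS window that keeps valid_position as close as
    possible to the center, clamping at the series edges instead of
    falling through to augmentation."""
    half = NUM_TIMESTEPS // 2
    # ideal start centers valid_position; clamp it into the admissible range
    start = min(max(valid_position - half, 0), available_timesteps - NUM_TIMESTEPS)
    return list(range(start, start + NUM_TIMESTEPS))
```

When valid_position sits near an edge, the window simply slides until it fits, so the result is always deterministic.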

Author:

Good point. I tried to rewrite it here: b2f1aa3

# make sure that month for encoding gets shifted according to
# the selected timestep positions
month = (
pd.to_datetime(row_d["start_date"]) + pd.DateOffset(months=timestep_positions[0])
Contributor:

You could also just add timestep_positions[0] and avoid the overhead of the pd.DateOffset method. I don't think we're resilient in any case to situations where we're working with dekadal data. This is a potential risk if someone (including ourselves) goes down the dekadal track. Can we make it more universal here? Or, if needed for now, do a check and raise an error, so nobody blindly uses this while it won't work as expected.
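One way to make it more universal could look like this (a sketch; the `freq` parameter and its values are assumptions, not existing code):

```python
import pandas as pd


def shifted_month(start_date: str, first_timestep: int, freq: str = "month") -> int:
    """Shift start_date by first_timestep steps at the given temporal
    resolution and return the resulting 0-indexed calendar month."""
    start = pd.to_datetime(start_date)
    if freq == "month":
        shifted = start + pd.DateOffset(months=first_timestep)
    elif freq == "dekad":
        shifted = start + pd.Timedelta(days=10 * first_timestep)
    else:
        # fail loudly instead of silently mis-shifting unknown resolutions
        raise NotImplementedError(f"Unsupported temporal resolution: {freq}")
    return shifted.month - 1
```

For monthly data this reduces to adding `timestep_positions[0]` months; for dekadal data the shift is done in days and only then converted to a month.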

Author:

Good point. I tried to account for that here: ad05b3d
We will also need to think about renaming this month variable (since it can represent something else too), and about making sure that the other relative timestep positions (valid_position and timestep_ind) are computed not as months but in a more generic fashion.

Contributor:

You mean this has to be tackled in process_parquet?

Contributor:

I am not sure I understand what initial_start_date_position actually means and why it cannot just be the month inferred from start_date.

@cbutsko (Author) Sep 11, 2024:

Well, maybe I'm overthinking it...

Here's an example:

  1. We are working on a monthly basis, with start_date in October (hence, month = 10),
    and we want to shift it 4 timesteps forward. We can just compute (10 + 4), adding the modulo part % cls.NUM_TIMESTEPS so we don't get a bad month value.

  2. We are working on a dekadal basis with the same start date. I assume that our timestep indices are then not month chunks but 10-day intervals, so we have observations for ts0, ts1, ..., ts45 (for example). Our NUM_TIMESTEPS variable should be set to 36 instead of 12 (like Giorgia did). The valid_date should also translate into a 10-day chunk instead of a month, so that we can select 36 timesteps around the valid_position.
    Now our timestep_positions are returned in dekadal steps, so adding 4 dekadal steps to a month value doesn't make sense. We need to add apples to apples to get the date that accounts for the shift, and only then take its month and pass it to Presto. I realize now that this particular step is missing in my implementation.

Am I missing something?

Contributor:

Following most of it. But still, start_date is a real date, from which we can infer what would normally be the start month fed to Presto. Why do we need to translate it to a position? I agree that whatever we add to it due to the shift should account for the time "resolution". valid_position should be the translation of valid_date (which is irrespective of time resolution) into a position in the timesteps (which depends on the time resolution). I might be overthinking it just as much.

Author:

Okay, maybe something like this can work:

step_converted_to_month = np.ceil(timestep_positions[0] * (365 // NUM_TIMESTEPS) / 30)
month = (pd.to_datetime(start_date).month - 1 + step_converted_to_month) % 12

It can be a little imprecise in some cases, but never more than one month off.
Since it's just a one-liner, we can probably sacrifice that little bit of precision.
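A quick sanity check of that approximation against exact date arithmetic (a sketch assuming dekadal data with NUM_TIMESTEPS = 36 and 0-indexed months):

```python
import numpy as np
import pandas as pd

NUM_TIMESTEPS = 36  # dekadal assumption


def approx_month(start_date: str, first_step: int) -> int:
    # the proposed one-liner: convert a timestep shift to whole months
    step_in_months = np.ceil(first_step * (365 // NUM_TIMESTEPS) / 30)
    return int((pd.to_datetime(start_date).month - 1 + step_in_months) % 12)


def exact_month(start_date: str, first_step: int) -> int:
    # ground truth: shift by actual 10-day dekads, then read off the month
    shifted = pd.to_datetime(start_date) + pd.Timedelta(days=10 * first_step)
    return shifted.month - 1


def month_error(a: int, b: int) -> int:
    # circular distance, so December vs January counts as 1, not 11
    d = abs(a - b)
    return min(d, 12 - d)


errors = [
    month_error(approx_month("2020-01-01", s), exact_month("2020-01-01", s))
    for s in range(NUM_TIMESTEPS)
]
```

For this start date the error never exceeds one month, which matches the "little bit of precision" being sacrificed.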

Contributor:

Yeah, let's make a note that Giorgia can test this out properly when doing dekadal runs.

presto/eval.py (outdated, resolved)
@@ -275,6 +326,7 @@ def __init__(
dataframe = dataframe[(~dataframe.end_date.dt.year.isin(years_to_remove))]
self.target_function = target_function if target_function is not None else self.target_crop
self._class_weights: Optional[np.ndarray] = None
self.augment = augment
Contributor:

Should we add logging somewhere to show that augmentation is enabled when initializing a dataset like this?
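One possible place is the dataset constructor (a sketch; the class name and constructor signature here are assumptions):

```python
import logging

logger = logging.getLogger(__name__)


class WorldCerealLabelledDataset:
    def __init__(self, dataframe, augment: bool = False):
        self.df = dataframe
        self.augment = augment
        if augment:
            # make temporal jittering visible in the run logs
            logger.info(
                "Augmentation enabled: temporal jittering of timestep positions"
            )
```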

Author:

Good point. Addressed that here: ae27e25

presto/utils.py (outdated, resolved)
# add dummy value + rename stuff for compatibility with existing functions
df["OPTICAL-B8A"] = NODATAVALUE

# TODO: this needs to go away once the transition to new data is complete
Contributor:

Yes, and can we immediately switch to the Presto naming convention then as well?

@kvantricht (Contributor):
@cbutsko is this branch already part of croptype? Were you already training on the new parquet format, but maybe only locally?

@cbutsko (Author) commented Sep 23, 2024:

@kvantricht this branch is not yet part of croptype. My plan is to sync it with main now, and then merge it directly into croptype, as you suggested.

Contributor:

@cbutsko in croptype branch this file is removed. What does that mean for the changes here?

Author:

I merged the functionality of paper_eval into train.py. I think in this branch it doesn't really matter: I just wanted to align this branch with main and then merge it into croptype, resolving these conflicts there. Also, after discussing self-supervised training with Giorgia, it now seems more feasible to create two separate files with clearer functions, something like train_self_supervised.py and train_finetuned.py.

Contributor:

That sounds like a good plan!

@@ -32,7 +34,7 @@


class WorldCerealBase(Dataset):
_NODATAVALUE = 65535
# _NODATAVALUE = 65535
Contributor:

@cbutsko why is this commented out? And are the changes below, like the addition of get_timestep_positions, not part of the other branch?

_NODATAVALUE = 65535
# _NODATAVALUE = 65535
Y = "worldcereal_cropland"
BAND_MAPPING = {
Contributor:

These changes will all conflict with the croptype branch, I think?

@classmethod
def get_timestep_positions(cls, row_d: Dict, augment: bool = False) -> List[int]:
available_timesteps = int(row_d["available_timesteps"])
valid_position = int(row_d["valid_position"])
Collaborator:

Could you add a comment describing what valid_position represents?

@gabrieltseng (Collaborator) Sep 27, 2024:

I see it's created in utils.py, but it's still not obvious to me what it represents; could you add a comment there instead?

available_timesteps = int(row_d["available_timesteps"])
valid_position = int(row_d["valid_position"])

if not augment:
Collaborator:

Could you also make a note that augment here means temporal jittering? Am I right in understanding that MIN_EDGE_BUFFER is just used to determine how much temporal jittering is allowed? In that case, maybe it makes sense as a class attribute (or even as a default argument to this function).
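A sketch of that suggestion (the MIN_EDGE_BUFFER value and the jitter logic are assumptions, not the actual implementation; it assumes the series is long enough for a valid window):

```python
import random
from typing import Dict, List


class WorldCerealBase:
    NUM_TIMESTEPS = 12
    # Hypothetical class attribute: how far valid_position must stay from the
    # window edges. A subclass (e.g. a dekadal dataset with NUM_TIMESTEPS = 36)
    # can override it together with NUM_TIMESTEPS.
    MIN_EDGE_BUFFER = 2

    @classmethod
    def get_timestep_positions(cls, row_d: Dict, augment: bool = False) -> List[int]:
        available = int(row_d["available_timesteps"])
        valid_position = int(row_d["valid_position"])
        # admissible window starts keep valid_position at least MIN_EDGE_BUFFER
        # steps from both window edges, and the window inside the series
        lo = max(0, valid_position + cls.MIN_EDGE_BUFFER - cls.NUM_TIMESTEPS + 1)
        hi = min(available - cls.NUM_TIMESTEPS, valid_position - cls.MIN_EDGE_BUFFER)
        if augment:
            # augment == temporal jittering: random admissible window start
            start = random.randint(lo, hi)
        else:
            # deterministic: center the window on valid_position, then clamp
            start = min(max(valid_position - cls.NUM_TIMESTEPS // 2, lo), hi)
        return list(range(start, start + cls.NUM_TIMESTEPS))
```

With augment=False this is fully deterministic; with augment=True the window start is jittered, but valid_position can never get closer than MIN_EDGE_BUFFER steps to a window edge.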

@@ -306,7 +311,7 @@ def evaluate(
test_preds_np = test_preds_np >= self.threshold
prefix = f"{self.name}_{finetuned_model.__class__.__name__}"

catboost_preds = test_ds.df.worldcereal_prediction
catboost_preds = test_ds.df.worldcereal_prediction == 11
Collaborator:

What is the motivation for this change?

@@ -0,0 +1,1557 @@
{
Collaborator:

I assume this was an accidental commit?

@gabrieltseng (Collaborator):

Nice job @cbutsko! Threw in my own comments too. +1 to @kvantricht's point that process_parquet is doing a lot, but the docstring is really helpful.
