Using new parquet in train #104
base: main
Conversation
…nd valid_date, with the possibility to augment the latter
…dded looping through individual parquets for performance reasons
Good work! I already added some comments. The `process_parquet` function is hard for me to read just as code. Let's start by testing it thoroughly.
presto/dataset.py
Outdated
if (valid_position < cls.NUM_TIMESTEPS // 2) or (
    valid_position > (available_timesteps - cls.NUM_TIMESTEPS // 2)
):
    augment = True
In this case we enter augmentation implicitly, which might not be what we desire. Can't we put this logic inside `if not augment`, and in that case not choose the `valid_position` as the `center_point`, but rather the point that keeps `valid_position` as close as possible to the center? Then it's always deterministic and we don't have to force going through the augmentation part.
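A minimal sketch of that deterministic alternative (the function name `get_window_start` and `NUM_TIMESTEPS = 12` are assumptions for illustration, not the actual repo code):

```python
import numpy as np

NUM_TIMESTEPS = 12  # assumed: monthly setup

def get_window_start(valid_position: int, available_timesteps: int,
                     num_timesteps: int = NUM_TIMESTEPS) -> int:
    """Pick the window start that keeps valid_position as close as possible
    to the window center, without leaving the available range."""
    # this start would put valid_position exactly at the window center
    ideal_start = valid_position - num_timesteps // 2
    # clamp so the window stays within [0, available_timesteps)
    return int(np.clip(ideal_start, 0, available_timesteps - num_timesteps))
```

With 24 available timesteps, `get_window_start(2, 24)` returns 0 (window pinned to the left edge) and `get_window_start(12, 24)` returns 6 (valid position exactly centered): no randomness involved, so the selection stays deterministic.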
good point. tried to rewrite it here b2f1aa3
presto/dataset.py
Outdated
# make sure that month for encoding gets shifted according to
# the selected timestep positions
month = (
    pd.to_datetime(row_d["start_date"]) + pd.DateOffset(months=timestep_positions[0])
You could also just add `timestep_positions[0]` and avoid the overhead of the `pd.DateOffset` method. I don't think we're resilient in any case to situations where we're working with dekadal data. This is a potential risk if someone (including ourselves) goes down the dekadal track. Can we make it more universal here? Or, if needed for now, do a check and raise something, so nobody blindly takes this while it won't work as expected.
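One hedged way to make the shift resolution-aware and raise on unsupported cases (the `timestep_freq` parameter and its values are assumptions, not the repo's actual API):

```python
import pandas as pd

def shifted_start_month(start_date, first_timestep: int,
                        timestep_freq: str = "month") -> int:
    """Calendar month after shifting start_date forward by first_timestep
    steps, where the step size depends on the temporal resolution."""
    start = pd.to_datetime(start_date)
    if timestep_freq == "month":
        shifted = start + pd.DateOffset(months=first_timestep)
    elif timestep_freq == "dekad":
        # dekadal data: each timestep covers a 10-day interval
        shifted = start + pd.Timedelta(days=10 * first_timestep)
    else:
        raise ValueError(f"Unsupported timestep frequency: {timestep_freq}")
    return shifted.month
```

For example, an October start shifted 4 monthly steps lands in February (month 2), while 4 dekadal steps from 1 October land on 10 November (month 11); anything else raises instead of silently producing a wrong month.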
good point. tried to account for that here: ad05b3d. We will also need to think about renaming this `month` thing (since it can be something else too), and about making sure that other relative timestep positions (`valid_position` and `timestep_ind`) are computed not as months, but in a more generic fashion.
You mean this has to be tackled in `process_parquet`?
I am not sure I understand what `initial_start_date_position` actually means and why it cannot just be the month inferred from `start_date`.
well, maybe I'm overthinking it... here's an example:
- We are working on a monthly basis, with start_date in October (hence, `month = 10`), and we want to shift it 4 timesteps forward. We can just take `(10 + 4) % cls.NUM_TIMESTEPS`, with the modulo part there so we don't get a bad month value.
- We are working on a dekadal basis, with the same start date. I assume that our timestep indices are then not in month chunks, but in 10-day intervals, so we have observations for ts0, ts1, ..., ts45 (for example). Our `NUM_TIMESTEPS` variable should be set to 36 instead of 12 (like Giorgia did). The `valid_date` should also translate into a 10-day chunk instead of a month, so that we can select 36 timesteps around the `valid_position`. Now `timestep_positions` is returned in dekadal steps, so adding 4 dekadal steps to a month value doesn't make sense. We need to add apples to apples to get the date that accounts for the shift, and only then take its month and pass it to Presto. I realize now that this particular step is missing in my implementation.

Am I missing something?
Following most of it. But still, `start_date` is a real date, from which we can infer what would normally be the start month being fed to Presto. Why do we need to translate it to a position? I agree that what we have to add to it due to the shift should account for the time "resolution". `valid_position` should be the translation of `valid_date` (which is irrespective of time resolution) to the position in the timesteps (which does depend on the time resolution). I might be overthinking it just as much.
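That translation could look like the following sketch (the function name and the `timestep_freq` switch are hypothetical, used only to make the resolution dependence explicit):

```python
import pandas as pd

def valid_position_from_date(start_date, valid_date,
                             timestep_freq: str = "month") -> int:
    """Translate the resolution-independent valid_date into a timestep index
    relative to start_date; the index depends on the temporal resolution."""
    start = pd.to_datetime(start_date)
    valid = pd.to_datetime(valid_date)
    if timestep_freq == "month":
        # whole calendar months between the two dates
        return (valid.year - start.year) * 12 + (valid.month - start.month)
    if timestep_freq == "dekad":
        # whole 10-day intervals between the two dates
        return (valid - start).days // 10
    raise ValueError(f"Unsupported timestep frequency: {timestep_freq}")
```

The same `valid_date` maps to different positions under different resolutions, which is exactly why the shift added to `start_date` has to use the matching units.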
okay, maybe something like this can work:
step_converted_to_month = np.ceil(timestep_positions[0] * (365 // NUM_TIMESTEPS) / 30)
month = (pd.to_datetime(start_date).month + step_converted_to_month) % 12 - 1
It can be a little imprecise in some cases, with at most a 1-month error. But it's just a one-liner, so probably we can sacrifice this little bit of precision.
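One caveat: the modulo arithmetic is easy to get wrong; for a December result, `(...) % 12 - 1` yields -1. A sketch that keeps the month 0-indexed throughout (the `365 // NUM_TIMESTEPS` factor comes from the snippet above; the function name and `NUM_TIMESTEPS = 36` are assumptions):

```python
import numpy as np
import pandas as pd

NUM_TIMESTEPS = 36  # assumed dekadal setup

def approx_shifted_month0(start_date, first_timestep: int) -> int:
    """Approximate 0-indexed month after shifting start_date forward by
    first_timestep steps of roughly 365 / NUM_TIMESTEPS days each."""
    step_months = int(np.ceil(first_timestep * (365 // NUM_TIMESTEPS) / 30))
    # convert to 0-indexed month before the modulo, so December maps to 11
    return (pd.to_datetime(start_date).month - 1 + step_months) % 12
```

For an October start shifted by 4 dekads this returns 11 (December, 0-indexed), while the exact shifted date, 10 November, falls in month index 10: the `ceil` rounds the shift up, so the result stays within the 1-month error discussed here.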
yeah let's make a note that Giorgia can test this out properly when doing dekadal runs.
@@ -275,6 +326,7 @@ def __init__(
dataframe = dataframe[(~dataframe.end_date.dt.year.isin(years_to_remove))]
self.target_function = target_function if target_function is not None else self.target_crop
self._class_weights: Optional[np.ndarray] = None
self.augment = augment
should we add logging somewhere that shows that augmentation is enabled when initializing a dataset like that?
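A lightweight way to do that, as an illustrative skeleton only (the class name and constructor signature here are simplified stand-ins, not the actual dataset code):

```python
import logging

logger = logging.getLogger(__name__)

class WorldCerealLabelledDataset:
    """Illustrative skeleton; only the augment handling is shown."""

    def __init__(self, dataframe, augment: bool = False):
        self.df = dataframe
        self.augment = augment
        if augment:
            # surface the setting at init time so it shows up in run logs
            logger.info("Temporal augmentation enabled for %s",
                        type(self).__name__)
```

Logging once at initialization avoids per-sample noise while still making it obvious from the run logs which datasets were built with augmentation on.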
good point. addressed that here ae27e25
# add dummy value + rename stuff for compatibility with existing functions
df["OPTICAL-B8A"] = NODATAVALUE

# TODO: this needs to go away once the transition to new data is complete
yes, and can we immediately go to presto naming convention then as well?
…nyway False by default
…sing with augment parameter
…imports and failing tests
@cbutsko is this branch already part of croptype? You were training on the new parquet format already? But maybe only locally?
@kvantricht this branch is not yet part of croptype. My plan is to sync it now with main, and then merge it directly into croptype, as you suggested.
@cbutsko in the `croptype` branch this file is removed. What does that mean for the changes here?
I merged the functionality of `paper_eval` into `train.py`. I think in this branch it doesn't really matter; I just wanted to align this branch with main and then merge into croptype, resolving these conflicts there. Also, after discussing ss training with Giorgia, it now seems more feasible to create two separate files with clearer functions: something like `train_self_supervised.py` and `train_finetuned.py`.
That sounds like a good plan!
@@ -32,7 +34,7 @@

class WorldCerealBase(Dataset):
    _NODATAVALUE = 65535
    # _NODATAVALUE = 65535
@cbutsko why is this commented? And are changes below, like the addition of `get_timestep_positions`, not part of the other branch?
_NODATAVALUE = 65535
# _NODATAVALUE = 65535
Y = "worldcereal_cropland"
BAND_MAPPING = {
these changes will all conflict with croptype branch I think?
…resto-worldcereal into using-new-parquet-in-train
@classmethod
def get_timestep_positions(cls, row_d: Dict, augment: bool = False) -> List[int]:
    available_timesteps = int(row_d["available_timesteps"])
    valid_position = int(row_d["valid_position"])
Could you add a comment describing what `valid_position` represents?
I see it's created in `utils.py`, but it's still not obvious to me what it represents; could you add a comment there instead?
available_timesteps = int(row_d["available_timesteps"])
valid_position = int(row_d["valid_position"])

if not augment:
Could you also make a note that `augment` here means temporal jittering? Am I right in understanding that `MIN_EDGE_BUFFER` is just used to determine how much temporal jittering is allowed? In that case maybe it makes sense as a class attribute (or even as a default argument to this function).
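Under that reading, the temporal jittering could be sketched as follows (the function name, `MIN_EDGE_BUFFER = 2`, and `NUM_TIMESTEPS = 12` are all assumptions for illustration):

```python
import random

NUM_TIMESTEPS = 12
MIN_EDGE_BUFFER = 2  # assumed: min distance of valid_position from window edges

def jittered_window_start(valid_position: int, available_timesteps: int,
                          num_timesteps: int = NUM_TIMESTEPS,
                          min_edge_buffer: int = MIN_EDGE_BUFFER) -> int:
    """Randomly choose a window start such that valid_position stays at
    least min_edge_buffer steps away from both edges of the window."""
    # start s must satisfy:
    #   s + min_edge_buffer <= valid_position <= s + num_timesteps - 1 - min_edge_buffer
    #   0 <= s <= available_timesteps - num_timesteps
    lo = max(0, valid_position - (num_timesteps - 1 - min_edge_buffer))
    hi = min(available_timesteps - num_timesteps,
             valid_position - min_edge_buffer)
    if lo > hi:
        raise ValueError("Not enough timesteps to respect the edge buffer")
    return random.randint(lo, hi)
```

So `MIN_EDGE_BUFFER` only narrows the range of admissible window starts; setting it to `num_timesteps // 2` would collapse the jitter back to the deterministic centered window.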
@@ -306,7 +311,7 @@ def evaluate(
test_preds_np = test_preds_np >= self.threshold
prefix = f"{self.name}_{finetuned_model.__class__.__name__}"

catboost_preds = test_ds.df.worldcereal_prediction
catboost_preds = test_ds.df.worldcereal_prediction == 11
What is the motivation for this change?
@@ -0,0 +1,1557 @@
{
I assume this was an accidental commit?
Nice job @cbutsko - threw in my own comments too. +1 to @kvantricht's point that `process_parquet` is doing a lot, but the docstring is really helpful.
No description provided.