
Pre-training on a subset of the original input channels #26

Open · kvantricht opened this issue Oct 20, 2023 · 14 comments
@kvantricht (Contributor)

When using construct_single_presto_input, the code conveniently handles the normalization of the inputs and the construction of the mask. If certain inputs (bands) are missing, the respective mask values are automatically set to 1. However, there seems to be no way to deal with missing timesteps in the inputs. Imagine monthly compositing of Sentinel-2 resulting in no valid observations for some month. We could deal with it by linearly interpolating the missing values, but Presto should actually be able to handle missing timesteps natively by setting the respective mask values to 1.

At the moment, the only way to do this is to keep track of the missing-value positions in the original inputs and, after the call to construct_single_presto_input, set the mask at these positions to 1. Would there be a more convenient way of doing this? I'm thinking of certain no-data values that could be treated as missing by this method, with the mask set accordingly.
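
For concreteness, here is a minimal sketch of that manual workaround (all shapes, column indices and no-data conventions are hypothetical):

import numpy as np

# Illustrative shapes only: 12 monthly timesteps and 17 channels (matching the
# (12, 17) mask discussed later in this thread). In practice x and mask come
# from construct_single_presto_input.
num_timesteps, num_channels = 12, 17
x = np.random.rand(num_timesteps, num_channels).astype(np.float32)
mask = np.zeros((num_timesteps, num_channels), dtype=np.uint8)

# Pretend months 3 and 7 had no valid Sentinel-2 composite; the optical columns
# used here (2-11) are hypothetical.
x[[3, 7], 2:12] = np.nan

# Workaround: wherever x holds NaN (or a chosen no-data value), zero out the
# input and set the corresponding mask values to 1 so Presto ignores them.
missing = np.isnan(x)
x[missing] = 0.0
mask[missing] = 1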

Specific side note on the automatic computation of NDVI: we were testing with NaN inputs for S2 to see how the code behaves. Interestingly, this line actually produces a valid NDVI value of 0 in x even when the inputs are invalid.

@kvantricht (Contributor Author) commented Oct 24, 2023

@gabrieltseng coming back to this: it looks to me now that Presto expects the same elements to be masked across an entire batch? https://github.com/nasaharvest/presto/blob/main/presto/presto.py#L369

Or do I misunderstand this? I'm still thinking about the case where some samples are missing some values. This results in a batch where not all elements are equally masked. Is this something Presto cannot handle? FYI, batch_size=1 does work.

For context: I'm now detecting which timestep/band combinations are no-data (e.g. when we don't interpolate missing monthly values in the optical bands) and manually setting the corresponding mask values to 1. For any batch size greater than 1, there can be samples with different masked elements, and then the Presto code fails.

@gabrieltseng (Collaborator) commented Oct 24, 2023

Hi @kvantricht ,

You are right - we currently don't support masking interim timesteps in construct_single_presto_input (although the model itself will have no problem with this).

The reason for this is an artifact of how we export the data (e.g. since we select the least cloudy pixel for Sentinel-2, we should never have a missing pixel, just the least cloudy one). However, I can update this function to make this easier.

With respect to the consistent number of masked tokens within a batch, that is correct. The reason for this is that within the model, we end up passing a tensor of shape [batch_size, num_unmasked_tokens, token_dimension] around. This requires num_unmasked_tokens to be consistent for all items in the batch.
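
To illustrate (a toy sketch with made-up shapes, not Presto code):

import torch

# Toy illustration of why the number of unmasked tokens must be the same for
# every sample: the encoder keeps only the unmasked tokens, and stacking the
# per-sample results into one tensor requires equal counts.
batch_size, num_tokens, token_dim = 2, 6, 4
x = torch.randn(batch_size, num_tokens, token_dim)
mask = torch.tensor([[1, 1, 0, 0, 0, 0],   # sample 0 keeps 4 tokens
                     [1, 1, 1, 0, 0, 0]])  # sample 1 keeps 3 tokens

kept = [x[i][mask[i] == 0] for i in range(batch_size)]
# torch.stack(kept) would fail: kept[0] is (4, 4) while kept[1] is (3, 4),
# which is exactly the situation the assert in Encoder.mask_tokens guards against.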

@kvantricht (Contributor Author) commented Oct 25, 2023

@gabrieltseng

With respect to the consistent number of masked tokens within a batch, that is correct.

Following up on this, we're trying to create our own DataLoader in a first attempt to finetune pretrained Presto on our data. We start from a dataframe where every row is a sample, so we have to refactor a bit to make it compatible with the way Presto needs the data.

I'm a bit confused about this step. It seems to call the make_mask method, returning a mask with dimensionality (12, 17). So I'm wondering: is one and the same generated mask applied to an entire batch? Or is every sample in the batch masked differently?

In the former case, I'm not sure how the code currently applies this mask to the entire batch. And in the latter case, it seems that make_mask always generates slightly different numbers of masked elements, which violates the requirement of having the same number of masked tokens within a batch.

Would you be able to clarify?

@gabrieltseng (Collaborator)

is one and the same generated mask applied to an entire batch? Or is every sample in the batch masked differently?

Every sample in the batch is masked differently. make_mask is called from the MaskParams object, which has a ratio attribute - this is passed to make_mask, and defines how many tokens should be masked. We keep this ratio constant throughout training.

However, this is just the case for pre-training; our finetuning code is available here. This code assumes a static mask for the whole task (of shape [timesteps, dimensions]), which we then expand so that it has shape [batch_size, timesteps, dimensions].
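
A sketch of that expansion, with illustrative sizes:

import torch

# Sketch of the static-mask expansion described above (sizes are illustrative).
timesteps, dimensions, batch_size = 12, 17, 32
task_mask = torch.zeros(timesteps, dimensions, dtype=torch.bool)
task_mask[:, 10:] = True  # e.g. mask the same channels for every sample

batch_mask = task_mask.unsqueeze(0).expand(batch_size, timesteps, dimensions)
# Because the mask is identical for every sample, the number of masked tokens
# automatically matches across the batch.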

@kvantricht (Contributor Author)

We keep this ratio constant throughout training.

This is indeed what I thought, so I don't really understand why every call to make_mask with a fixed mask_ratio of 0.75 yields a different number of masked values?

from presto.dataops.masking import make_mask

for i in range(10):
    print(f'Number of masked elements: {make_mask("group_bands", mask_ratio=0.75)[0].sum()}')

Number of masked elements: 148
Number of masked elements: 135
Number of masked elements: 156
Number of masked elements: 150
Number of masked elements: 142
Number of masked elements: 129
Number of masked elements: 126
Number of masked elements: 125
Number of masked elements: 160
Number of masked elements: 176

However, this is just the case for pre-training; our finetuning code is available here. This code assumes a static mask for the whole task (of shape [timesteps, dimensions]), which we then expand so that it has shape [batch_size, timesteps, dimensions].

Gotcha, but I'm puzzled by things like num_outputs, which seems to tell me that instances of FineTuningModel are actually meant to have a classification head for finetuning Presto on a downstream task. What I was trying to do is "continue" Presto pre-training a bit in a self-supervised way, without labels, on a lot of samples from WorldCereal, to get the model used to some different preprocessing techniques that were used. Do you think that doesn't make sense?

@gabrieltseng (Collaborator)

mask_ratio ensures the same number of tokens are masked. Since a token represents a grouping of a certain number of input variables, the same number of masked tokens can correspond to a different number of masked values.

For example, one token is S1 VV and VH (2 values) and one token is S2 RGB (3 values). So if I were to only mask one token, I might sometimes mask 2 values or 3 values.
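
A toy sketch of this, using those example groups:

import random

# Toy sketch of the token/value distinction; the groups below are the examples
# from this comment, not Presto's actual channel grouping.
band_groups = {"S1": 2, "S2_RGB": 3}          # values per token
tokens = list(band_groups) * 6                 # 12 tokens in total

for _ in range(3):
    masked = random.sample(tokens, k=6)        # always mask exactly 6 tokens
    print(sum(band_groups[g] for g in masked))  # number of masked values varies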

The FineTuningModel replaces the decoder with a classification / regression head for finetuning. If you plan on continuing to pre-train the model, then you should use the full Presto model (i.e. what you obtain if you call encoder_decoder = Presto.load_pretrained()).
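
For example (the import path here is inferred from the repository layout and may differ):

from presto.presto import Presto

# Sketch of continuing self-supervised pre-training with the full model, as
# advised above. Only Presto.load_pretrained() is mentioned in this thread.
encoder_decoder = Presto.load_pretrained()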

@kvantricht (Contributor Author)

Right, but that brings me back to the original issue: the mask tensor holds original mask values (so not tokens yet), which, as you say, can differ in quantity between samples in a batch. Then the part below fails:

File ~/git/presto/presto/presto.py:372, in Encoder.mask_tokens(x, mask)
    367 @staticmethod
    368 def mask_tokens(x, mask):
    369     summed = mask.sum(
    370         dim=(1, 2)
    371     )  # summed tells me the number of masked elements per batch idx
--> 372     assert summed.max() == summed.min(), f"{summed.max()}, {summed.min()}"
    374     batch_size = x.shape[0]
    375     removed_elements_per_batch = int(summed.max() / mask.shape[2])

AssertionError: 11520, 10368

And yes, at the moment we work with the full Presto model.
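
One way we could surface this earlier is a small check on the collated batch before it reaches the encoder; a sketch:

import torch

# A pre-flight check mirroring the assert in Encoder.mask_tokens (a sketch, not
# part of the Presto codebase): every sample in the batch must have the same
# number of masked elements.
def check_batch_mask(mask: torch.Tensor) -> None:
    # mask: [batch_size, timesteps, dimensions], with 1 meaning masked
    summed = mask.sum(dim=(1, 2))
    if summed.max() != summed.min():
        raise ValueError(
            f"Unequal masking within batch: {summed.min().item()} vs "
            f"{summed.max().item()} masked elements"
        )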

@gabrieltseng (Collaborator) commented Oct 25, 2023

Then the part below fails:

Could you share the code snippet you use to generate the masks? Within a batch, the same number of tokens must be passed / masked (this is what the assert statement was designed to check for).

As you noted before, using batch_size == 1 is an easy fix for this.

@kvantricht (Contributor Author)

Okay yeah indeed, some example code is more practical to discuss. I made a notebook here which takes a sample WorldCereal dataframe and attempts to transform it into something Presto accepts: https://github.com/kvantricht/presto/blob/worldcereal/notebooks/presto_pretrain_finetune.ipynb

As you can see at the end of the notebook, the error message is about the masks. Hopefully this helps to identify where I'm doing something awfully wrong!

@kvantricht (Contributor Author)

FYI, when taking a batch_size of 1 and setting all Dynamic World data to 0 (currently it gets set to 9, which gives an out-of-bounds error, as the Presto decoder seems to expect DW classes between 0 and 8), the training loop starts, so technically we're getting close.
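
For reference, a sketch of that stopgap remapping (the values and handling here are assumptions on my part):

import numpy as np

# Sketch of the stopgap described above: if the decoder only accepts Dynamic
# World classes 0-8, remap the out-of-range no-data value (9 in this example)
# to a valid class before feeding it to the model. Whether Presto reserves a
# dedicated missing-class value is not confirmed in this thread.
dynamic_world = np.array([0, 3, 9, 7, 9, 1])
dynamic_world = np.where(dynamic_world > 8, 0, dynamic_world)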

@gabrieltseng (Collaborator)

Okay - I think the cause of the error in the dataloader is that the masking function assumes nothing is masked. However, when using construct_single_presto_input some things are masked (e.g. dynamic world), so the calculations in the masking function to mask a certain number of tokens become incorrect.

This requires a re-write of the masking function, but shouldn't be too complicated. I can give this a shot in the next few days.

@kvantricht (Contributor Author)

I'm not exactly sure what you mean. Indeed, I use construct_single_presto_input now, but I don't use the mask it returns; it's just the x that goes into Presto. Besides, make_mask works independently of the data.

@gabrieltseng (Collaborator)

@kvantricht , I opened a PR into the branch you shared (kvantricht#1) which shows those changes.

I've added a description there of exactly which changes I made.

@gabrieltseng changed the title from "Handle missing timesteps in construct_single_presto_input" to "Pre-training on a subset of the original input channels" on Oct 26, 2023
@kvantricht (Contributor Author)

Amazing Gabi, it runs here now as well with batch size 4096. I will keep you posted on the outcome!
