
The Mask Tensor Doesn't update #3

Closed
siwang2011 opened this issue Aug 19, 2021 · 1 comment

@siwang2011
Hi! Thanks for your work!
I want to use your program with my model and my data, but I find that the mask tensor doesn't update.
I use the Informer model (https://github.com/zhouhaoyi/Informer2020) and my data is stock data.
I use the MaskGroup with GaussianBlur, and the loss is MSE.
T is 80, and I want to predict a horizon of 20 steps into the future.
The loss is fixed at 0.000125 and the mask tensor stays at almost 0.5.
I don't know why. Please help me.
Thanks

@JonathanCrabbe
Owner

Hi,

Thanks for using Dynamask. I have to say that I did not try Dynamask on Transformer-type architectures.
Although I do not see any obstruction to using Dynamask in this context, it might require a bit of work to ensure that everything is consistent. It goes without saying that pull requests are welcome if there are nontrivial modifications that I did not take into account in my implementation.

That being said, allow me to propose a few easy tests that could help detect whether there is indeed a problem and, if so, fix it.

  1. First, you might want to ensure that your model complies with the assumptions made by Dynamask. In particular, the model needs to be differentiable with respect to its input. If the output of the model is a scalar, you can check that this is indeed the case with torch.autograd. Sample code to adapt to your model follows; if it prints a non-vanishing gradient, your model should be compatible with Dynamask.
import torch

T, D = 80, 5 # Number of time steps and features (adapt to your setting)
model = ... # Your model
input = torch.rand(T, D, requires_grad=True) # Tensor with T time steps and D features
prediction = model(input) # Forward pass
scalar = torch.mean(prediction) # Obtain a scalar from the output
scalar.backward() # Backpropagate
print(input.grad) # The model's output should induce a gradient on the input
  2. A common mistake is to use the wrong order for the tensor indices. Dynamask requires the input to be of shape (T, D), where T is the number of time steps and D is the number of features. Note that there is no batch index for this input tensor, since Dynamask produces a mask one example at a time. If your model requires time series of shape (B, D, T), where B is a batch index, you have to modify the black box f that is fed to the mask. This can be done easily by defining an auxiliary black-box function:
import torch

T, D = 80, 5 # Number of time steps and features (adapt to your setting)
model = ... # Model that takes time series input of shape (B, D, T)
input = torch.rand(T, D, requires_grad=True) # Tensor with T time steps and D features
def f(x):
    x = x.unsqueeze(0) # (T, D) -> (1, T, D)
    x = x.transpose(1, 2) # (1, T, D) -> (1, D, T)
    out = model(x) # The model can now be fed the reshaped tensor
    return out

# Now use f in the mask fitting: mask.fit(f=f)
  3. As mentioned in the paper, Dynamask is a perturbation-based approach. To work, such methods require the input and the model to be sensitive to perturbations. A good way to understand this is by considering an extreme case: a constant time series will not be sensitive to a dynamic perturbation operator. I would suggest that you evaluate whether the time series and the model prediction are affected by dynamic perturbations. This can be done by adapting the following code to your setting. If the perturbation operator has little or no effect, you might want to play with its hyperparameters. For instance, you could increase the width of the Gaussian blur or the size of the window.
import torch
from attribution.perturbation import GaussianBlur

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
T, D = 80, 5 # Number of time steps and features (adapt to your setting)
input = torch.rand(T, D, device=device) # Tensor with T time steps and D features
model = ... # Model that takes tensors of shape (T, D) as input (see point 2 above)
out = model(input) # Compute the model output
pert = GaussianBlur(device) # Gaussian blur or another perturbation operator
mask = 0.5 * torch.ones(input.shape, device=device) # Initial mask
input_pert = pert.apply(input, mask) # Perturb the input with the initial mask
out_pert = model(input_pert) # Apply the model to the perturbed input

# Now compare the perturbed and unperturbed tensors. Are input and input_pert different? Are out and out_pert different?
  4. The error part of the mask objective is completely determined by the argument loss_function that is fed to the Mask.fit method. This loss function is used in the following line of the code. Please make sure that the loss function that you gave to the mask is compatible with this use. In particular, this loss function needs to be differentiable with respect to its first argument. Examples are provided in the following file.
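As a toy illustration (not the repository's code), here is how you can check that a loss such as MSE is differentiable in its first argument; the shapes are assumed for the example:

```python
import torch

# Toy check (assumed shapes, not Dynamask code): a loss fed to Mask.fit
# must be differentiable with respect to its first argument.
def mse_loss(y_pred, y_true):
    return torch.mean((y_pred - y_true) ** 2)

y_true = torch.rand(20)                      # e.g. a horizon of 20 targets
y_pred = torch.rand(20, requires_grad=True)  # stand-in for a model output
loss = mse_loss(y_pred, y_true)
loss.backward()                              # gradient flows to y_pred
print(y_pred.grad is not None)               # a non-None gradient means the loss is usable
```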

  5. The regularization part of the objective might impede the reduction of its error part. To check whether this is the case, please try to fit a mask with the regularization coefficients set to zero. This can be done by setting the parameters size_reg_factor_init and time_reg_factor to zero when you call Mask.fit. In this case, since the objective only contains the error part, you should observe that the error decreases with the epochs. To check this, you can plot the learning curves by calling the Mask.plot_hist method after fitting the mask; the error part of the plot should decrease with the epochs. If this is indeed the case, you can start increasing the size_reg_factor_init, time_reg_factor and size_reg_factor_dilation parameters of the fit method until you obtain learning curves that look like the following.
    [Image: mask-learning_curve-example]
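To see why the error should decrease once regularization is off, here is a toy sketch (a hand-rolled loop with a stand-in linear black box, not the Dynamask implementation) that optimizes a mask against the error term only:

```python
import torch

# Toy sketch (not Dynamask): fit a multiplicative mask by minimizing only the
# error term, with no size/time regularization.
torch.manual_seed(0)
T, D = 80, 5                          # assumed dimensions for illustration
x = torch.rand(T, D)                  # input time series
w = torch.rand(D)                     # stand-in linear "black box" weights
f = lambda inp: inp @ w               # black-box prediction, shape (T,)

mask = torch.full((T, D), 0.5, requires_grad=True)
opt = torch.optim.SGD([mask], lr=0.1, momentum=0.9)
errors = []
for epoch in range(200):
    opt.zero_grad()
    error = torch.mean((f(mask * x) - f(x)) ** 2)  # error part of the objective
    error.backward()
    opt.step()
    errors.append(error.item())

print(errors[0] > errors[-1])  # with no regularization, the error decreases
```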

  6. In case none of the above points brings any insight, I suggest you try other hyperparameters. For instance, you could set the initial_mask_coeff parameter of Mask.fit to 0.9 to start with a less extreme perturbation. Similarly, you could increase the learning_rate and/or the momentum of the same method to see if the learning curves look better.
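As a toy illustration of why an initial mask coefficient of 0.9 gives a less extreme starting perturbation than 0.5 (using a simple multiplicative stand-in, not the actual perturbation operators):

```python
import torch

# Toy stand-in perturbation (masked = mask * input); not the Dynamask operators.
torch.manual_seed(0)
x = torch.rand(80, 5)                 # assumed (T, D) = (80, 5)
perturb = lambda m: m * x             # a mask of all ones leaves the input unchanged
d_half = torch.norm(perturb(torch.full_like(x, 0.5)) - x)
d_high = torch.norm(perturb(torch.full_like(x, 0.9)) - x)
print(bool(d_high < d_half))          # True: 0.9 starts closer to the original input
```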

Hope this helps.

Best,
Jonathan
