Input-wise masks for mask gradients #4
Conversation
Thanks for the PR! This is a nice idea, although I think executing it in a clean way is somewhat challenging.
I'm nervous about a couple of things here:
- It's not ideal to create a new `patch_mask_batch` parameter. It seems this should be integrated into the existing `mask` parameter. Changing the `mask` parameter for individual edges would no longer work because the main `mask` parameter would no longer be used.
- But some existing functions will break if you just overwrite the `mask` parameter with the new batch dimension (because edge indices will no longer work).
- The `patch_mode` context manager shouldn't permanently alter the state of the model.
I would propose making this change in two steps:
- Add a new dummy batch dimension (with size `1`) onto the `mask` parameters that is always present. Update the edge index functions to work as before with this. The core logic in `PatchWrapperImpl` can stay mostly the same - we will just broadcast the batch dimension by default (and update `sample_hard_concrete` to not add a new dimension).
- Add a new context manager in `graph_utils`, `set_mask_batch_size`, that temporarily adjusts the size of the mask batch dimension, allowing you to get the gradients for each batch element separately.
Thanks for the feedback! I mostly implemented your suggestions, but instead of adding a dimension to `patch_mask` by default, I only add the batch dimension if
Thanks for the changes! I think this is a big improvement.
I'm still not super happy with this implementation, as it breaks a bunch of other features. But I recognize that a more general solution is a bigger project which is not relevant to your use case, so I don't want to block this much longer.
My main requests are to add a warning about these problems in the docstrings, and to improve the test a little bit.
```python
if not mask_expanded:
    mask = mask.repeat(batch_size, *([1] * mask.ndim))
else:
    assert mask.size(0) == batch_size
```
Instead of adding a new parameter, it would be better to just check if it's already the correct shape and adjust accordingly.
Some of the masks are 1-D though, so it's unclear how to distinguish an expanded 1-D mask from a 2-D mask whose first dimension happens to equal the batch size.
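The ambiguity can be made concrete with a hypothetical example (names here are illustrative, not from the PR): once a 1-D edge mask is expanded across the batch, its shape can coincide exactly with that of an unexpanded 2-D mask.

```python
import torch

batch_size = 8
n_edges = 8  # happens to equal batch_size

# A 1-D mask of n_edges edges, expanded so each input has its own copy.
expanded_1d = torch.zeros(n_edges).repeat(batch_size, 1)  # shape (8, 8)

# An unexpanded 2-D mask whose first dimension is coincidentally 8.
plain_2d = torch.zeros(8, n_edges)                        # shape (8, 8)

# Shape alone cannot tell these apart, hence the explicit flag.
assert expanded_1d.shape == plain_2d.shape
```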
Ah yes, good point. Seems fine how it is then.
Looks great, thanks for this!
I will just do a couple of checks locally and then merge in the next hour or so.
@oliveradk Could you please fix the merge conflicts and address the final comment? Then I will merge.
Fixed merge conflicts and addressed the comment (I had been running the test incorrectly and had to make some more tweaks).
Adds support for computing input-wise mask gradients, useful e.g. for anomaly detection using edge attribution scores.
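As a rough sketch of that use case (function names and shapes here are hypothetical, not the library's API): once each input has its own mask gradients, they can be flattened into a per-input vector of edge attribution scores and fed to any outlier detector.

```python
import torch


def edge_attribution_scores(mask_grads: list[torch.Tensor]) -> torch.Tensor:
    """Flatten per-input mask gradients into one score vector per input.

    `mask_grads` is assumed to be a list of tensors, one per patched module,
    each shaped (batch_size, ...).
    """
    batch_size = mask_grads[0].shape[0]
    return torch.cat([g.reshape(batch_size, -1) for g in mask_grads], dim=1)


def anomaly_scores(train_scores: torch.Tensor, test_scores: torch.Tensor) -> torch.Tensor:
    # A deliberately simple outlier score: distance to the mean attribution
    # vector of in-distribution inputs (kNN or Mahalanobis would also work).
    mean = train_scores.mean(dim=0)
    return (test_scores - mean).norm(dim=1)
```

Inputs whose attribution vectors sit far from the in-distribution mean then get high anomaly scores.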