
Input wise masks for mask gradients #4

Merged Jul 19, 2024 (10 commits)

Conversation

oliveradk
Contributor

Adds support for computing input-wise mask gradients, which is useful e.g. for anomaly detection using edge attribution scores.

@oliveradk oliveradk marked this pull request as draft July 12, 2024 21:15
Owner

@UFO-101 UFO-101 left a comment


Thanks for the PR! This is a nice idea, although I think executing it in a clean way is somewhat challenging.

I'm nervous about a couple of things here:

  • It's not ideal to create a new patch_mask_batch parameter; this should be integrated into the existing mask parameter. As written, changing the mask for individual edges would no longer work, because the main mask parameter would no longer be used.
  • However, some existing functions will break if you simply overwrite the mask parameter with the new batch dimension (because edge indices will no longer work).
  • The patch_mode context manager shouldn't permanently alter the state of the model.

I would propose making this change in two steps:

  1. Add a new dummy batch dimension (of size 1) onto the mask parameters that is always present. Update the edge index functions to work as before with this. The core logic in PatchWrapperImpl can stay mostly the same; we will just broadcast the batch dimension by default (and update sample_hard_concrete to not add a new dimension).
  2. Add a new context manager in graph_utils, set_mask_batch_size, that temporarily adjusts the size of the mask batch dimension, allowing you to get the gradients for each batch element separately.
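The second step could be sketched roughly as follows. This is a hypothetical illustration, not the actual auto_circuit implementation: `TinyPatchableModel` and its `patch_mask` attribute are stand-ins, and only the name `set_mask_batch_size` comes from the proposal above.

```python
from contextlib import contextmanager

import torch


class TinyPatchableModel:
    """Stand-in for a patchable model holding a single edge mask."""

    def __init__(self, n_edges: int):
        # Shape [n_edges]: one mask value shared across the whole batch.
        self.patch_mask = torch.nn.Parameter(torch.zeros(n_edges))


@contextmanager
def set_mask_batch_size(model: TinyPatchableModel, batch_size: int):
    """Temporarily give the mask a leading batch dimension so that
    .grad holds a separate gradient for each batch element."""
    original = model.patch_mask
    try:
        # repeat (not expand) so each batch row is an independent value
        model.patch_mask = torch.nn.Parameter(
            original.detach().clone().unsqueeze(0).repeat(batch_size, 1)
        )
        yield model.patch_mask
    finally:
        model.patch_mask = original  # restore: no permanent state change


model = TinyPatchableModel(n_edges=3)
with set_mask_batch_size(model, batch_size=4) as batched_mask:
    assert batched_mask.shape == (4, 3)
    loss = (batched_mask * torch.randn(4, 3)).sum()
    loss.backward()
    per_input_grads = batched_mask.grad  # shape [4, 3]: one row per input
assert model.patch_mask.shape == (3,)  # original mask restored on exit
```

The key design point is that the temporary parameter is a `repeat` of the original rather than a broadcast view, so autograd accumulates a distinct gradient per batch row, and the `finally` block guarantees the model is unchanged after the context exits.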

Review threads (outdated, resolved): auto_circuit/data.py, auto_circuit/utils/graph_utils.py, auto_circuit/utils/patchable_model.py
@oliveradk oliveradk marked this pull request as ready for review July 15, 2024 18:26
@oliveradk
Contributor Author

Thanks for the feedback! I mostly implemented your suggestions, but instead of adding a dimension to patch_mask by default, I only add the batch dimension when set_mask_batch_size is called. This minimizes the chance of introducing downstream bugs, and I don't think any of the edge-indexing functionality is critical if the main use case is collecting attribution scores over batches (please let me know if I'm missing something crucial there).

Owner

@UFO-101 UFO-101 left a comment


Thanks for the changes! I think this is a big improvement.

I'm still not entirely happy with this implementation, as it breaks a number of other features. But I recognize that a more general solution is a bigger project that isn't relevant to your use case, so I don't want to block this much longer.

My main requests are to add a warning about these problems to the docstrings, and to improve the test a little.

if not mask_expanded:
    mask = mask.repeat(batch_size, *([1] * mask.ndim))
else:
    assert mask.size(0) == batch_size
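In isolation, the expand-or-validate logic above behaves roughly like this (a hypothetical `broadcast_mask` helper, with shapes chosen purely for illustration):

```python
import torch


def broadcast_mask(
    mask: torch.Tensor, batch_size: int, mask_expanded: bool
) -> torch.Tensor:
    """Ensure the mask has a leading batch dimension of size batch_size."""
    if not mask_expanded:
        # Prepend a batch dim by repetition: [*dims] -> [batch_size, *dims]
        mask = mask.repeat(batch_size, *([1] * mask.ndim))
    else:
        # Mask already carries a batch dim; just sanity-check its size
        assert mask.size(0) == batch_size
    return mask


unexpanded = torch.zeros(2, 3)
out = broadcast_mask(unexpanded, batch_size=4, mask_expanded=False)
assert out.shape == (4, 2, 3)  # batch dim prepended

already_expanded = torch.zeros(4, 2, 3)
out = broadcast_mask(already_expanded, batch_size=4, mask_expanded=True)
assert out.shape == (4, 2, 3)  # passed through unchanged
```

Note that `tensor.repeat(n, 1, ..., 1)` with one more argument than the tensor has dimensions implicitly treats the tensor as having a leading size-1 dimension, which is what prepends the batch axis here.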
Owner


Instead of adding a new parameter, it would be better to just check if it's already the correct shape and adjust accordingly.

Contributor Author


Some of the masks are 1-d though, so it's unclear how to distinguish an expanded 1-d mask from a 2-d mask whose first dimension happens to equal the batch size.
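A small illustration of the ambiguity (shapes chosen arbitrarily): once a 1-d mask is expanded over the batch, its shape is identical to that of a natively 2-d mask whose first dimension coincidentally equals the batch size, so a shape check alone cannot tell them apart.

```python
import torch

batch_size = 3
n_edges = 5

# A 1-d mask of n_edges values, expanded over the batch: [batch_size, n_edges]
expanded_1d = torch.zeros(n_edges).repeat(batch_size, 1)

# A natively 2-d mask whose first dimension happens to equal batch_size
native_2d = torch.zeros(batch_size, n_edges)

# Shape alone cannot tell these apart, hence the explicit flag
assert expanded_1d.shape == native_2d.shape == (batch_size, n_edges)
```

This is why the snippet above keeps an explicit `mask_expanded` flag instead of inferring expansion from the tensor shape.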

Owner


Ah yes, good point. Seems fine how it is then.

Review threads (resolved): auto_circuit/utils/patch_wrapper.py, auto_circuit/utils/graph_utils.py; seven outdated, resolved threads on tests/utils/test_instance_grads.py
Owner

@UFO-101 UFO-101 left a comment


Looks great, thanks for this!

@UFO-101
Owner

UFO-101 commented Jul 18, 2024

I will just do a couple of checks locally and then merge in the next hour or so.

@UFO-101
Owner

UFO-101 commented Jul 19, 2024

@oliveradk Could you please fix merge conflicts and final comment? Then I will merge.

@oliveradk
Contributor Author

Fixed merge conflicts and addressed the comment (I had been running the test incorrectly and had to make some more tweaks).

@UFO-101 UFO-101 merged commit 3ad51b5 into UFO-101:main Jul 19, 2024