
Refactor IJEPA to use timm. #1612

Merged
merged 6 commits into from
Jul 30, 2024
Conversation

radiradev
Contributor

@radiradev radiradev commented Jul 26, 2024

Changes

This PR addresses #1367. I have refactored IJEPA to use timm, staying close to the original implementation, and I have also added type annotations. There might be some structural changes needed; for instance, the apply_masks function should probably be moved to utils. Any suggestions on how to improve this are welcome.

As with the MAE timm implementation, I have created a separate file instead of directly replacing the torchvision implementation. I assume that once this is benchmarked, the plan would be to replace it completely.

How was it tested?

Unit tests for the predictor, encoder, and backbone classes.


codecov bot commented Jul 26, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 85.61%. Comparing base (78f59fc) to head (9708f36).

Additional details and impacted files
@@            Coverage Diff             @@
##           master    #1612      +/-   ##
==========================================
+ Coverage   85.49%   85.61%   +0.11%     
==========================================
  Files         147      148       +1     
  Lines        6281     6333      +52     
==========================================
+ Hits         5370     5422      +52     
  Misses        911      911              


@guarin
Contributor

guarin commented Jul 26, 2024

Hi! Thanks a lot for this extensive PR, it looks really well made!

Looking at the code, I see many parallels to our MaskedVisionTransformer implementation. I didn't have time yet to go through the full PR, but I have a suspicion that IJEPABackboneTIMM and IJEPAEncoderTIMM are compatible with our implementation of the MaskedVisionTransformer (see here and here). Do you think it would be possible to use MaskedVisionTransformer directly instead of IJEPABackboneTIMM, or am I missing something? If possible, this would simplify the code a lot.

What I imagine is something like this:

target_encoder = MaskedVisionTransformer()
context_encoder = MaskedVisionTransformer()
predictor = IJEPAPredictorTIMM()

# This encodes all patches.
target = target_encoder.encode(images)
# This encodes only the unmasked context patches.
context = context_encoder.encode(images, idx_keep=idx_keep)

predictions = predictor(context, masks_x, masks)
target = get_targets_at_masks(target, masks_x, masks)
loss(predictions, target)
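For illustration, the gathering step that a helper such as get_targets_at_masks would perform (selecting the target encoder's token embeddings at the masked patch positions) can be sketched with plain NumPy. This is a conceptual sketch, not the lightly implementation; the function name gather_tokens is hypothetical.

```python
import numpy as np

def gather_tokens(tokens: np.ndarray, idx: np.ndarray) -> np.ndarray:
    """Select token embeddings at the given patch indices.

    tokens: (batch, num_patches, dim) array of patch embeddings.
    idx:    (batch, num_idx) integer array of patch indices per sample.
    Returns an array of shape (batch, num_idx, dim).
    """
    # Broadcast the index over the embedding dimension and gather along
    # the patch axis.
    return np.take_along_axis(tokens, idx[:, :, None], axis=1)

# Toy example: 1 image, 4 patches, 2-dimensional embeddings.
tokens = np.arange(8, dtype=float).reshape(1, 4, 2)
idx_keep = np.array([[0, 2]])  # keep patches 0 and 2 (the "context")
context = gather_tokens(tokens, idx_keep)
assert context.shape == (1, 2, 2)
```

The same gather, applied with the target mask indices instead of `idx_keep`, yields the prediction targets that the loss compares against.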

@radiradev
Contributor Author

Hi @guarin. I think you are correct, and we can reuse the MaskedVisionTransformer, making this PR a lot smaller. I have updated it with your suggested changes and also added drop_path_rate for IJepaPredictor.

Contributor

@guarin guarin left a comment


Awesome, thanks so much!

@guarin guarin mentioned this pull request Jul 29, 2024
@guarin
Contributor

guarin commented Jul 29, 2024

@guarin guarin enabled auto-merge (squash) July 30, 2024 11:39
@guarin guarin merged commit 1bde34e into lightly-ai:master Jul 30, 2024
10 checks passed
@radiradev radiradev deleted the ijepa_timm branch July 30, 2024 13:52
@guarin guarin mentioned this pull request Aug 16, 2024