
Refactor MACE subclasses - reduce code duplication & clearer logic #97

Draft · wants to merge 7 commits into base: develop
Conversation

@stenczelt (Contributor) commented Apr 2, 2023:

Supersedes #65; also implements the JIT part of #95 for free.

Refactored subclasses of MACE:

  • ScaleShiftMACE: only adds scaling & shifting on top of MACE
  • AtomicDipolesMACE: only defines the dipole-calculation capabilities
  • EnergyDipolesMACE: energy & dipole calculation

Main changes:

  • __init__ is essentially the same across the subclasses; the small differences are handled as follows:
    • readout block classes are kept as class members
    • the irreps of the last layer are computed by a function (overridden by the subclasses)
  • parts of the forward-pass calculation are separated into internal functions
  • the MACE core architecture, the energy calculation, and the dipole calculation are extracted into separate classes, which are combined to obtain the model classes (much easier for others to re-use as well, I suspect); see the sketch below
  • docstrings were added in a number of places
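
For readers of this thread, here is a minimal sketch of the composition pattern described above. The class names MaceCoreModel and DipoleModelMixin appear in this PR's diff; the method bodies, method names, and data keys below are illustrative assumptions, not the actual implementation:

    import torch
    from typing import Dict
    from e3nn.util.jit import compile_mode


    class MaceCoreModel(torch.nn.Module):
        """Backbone: embeddings + interaction layers, agnostic of the target quantity."""

        def _compute_node_features(self, data: Dict[str, torch.Tensor]) -> torch.Tensor:
            # stand-in for the real embedding + message-passing stack
            return data["node_attrs"]


    class DipoleModelMixin:
        """Quantity-specific readout: turns node features into dipoles."""

        def _compute_dipoles(self, node_feats: torch.Tensor) -> torch.Tensor:
            # stand-in for the real dipole readout blocks
            return node_feats.sum(dim=0)


    @compile_mode("script")
    class AtomicDipolesMACE(MaceCoreModel, DipoleModelMixin):
        """Concrete model: backbone combined with one readout mixin."""

        def forward(self, data: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
            node_feats = self._compute_node_features(data)
            return {"dipole": self._compute_dipoles(node_feats)}

The point of the split is that a new target quantity only needs a new mixin; the backbone is untouched.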

- extracted the layer calculations
- "ScaleShift" applied to non-duplicated layer calculation code
- DipolesMACE & EnergyDipolesMACE inherit from MACE
- EnergyDipolesMACE accepting dict for forward pass + closer match to MACE class
- DipoleOnly versions of blocks added explicitly (to be refactored)
- JIT for DipolesMACE & EnergyDipolesMACE
@stenczelt (Contributor Author) commented:

@davkovacs there's a more general refactor in https://github.com/stenczelt/mace/tree/ENH/refactor-model-v3 where I have separated out the backbone of the MACE model (agnostic of what you're calculating) and the quantity-specific readout blocks and gradient calculations.
That one is a somewhat larger refactor, though I think it's quite neat. Shall I update this PR to that instead?

@davkovacs (Collaborator) commented:

I would say let’s have one refactor and test that thoroughly

@stenczelt (Contributor Author) commented:

Roger that, @davkovacs!

interaction_cls: Type[InteractionBlock],
interaction_cls_first: Type[InteractionBlock],
num_interactions: int,
num_elements: int,
@stenczelt (Contributor Author):

The number of elements is redundant here; I propose removing it. OK?
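
As a quick illustration (a hedged sketch, assuming the constructor keeps the atomic_numbers argument shown further down in this diff), the value could be derived rather than passed in:

    from typing import List

    atomic_numbers: List[int] = [1, 6, 8]  # illustrative values
    num_elements = len(atomic_numbers)     # derivable, so a separate argument is redundant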

hidden irreducible representations, basically the size of the layer features
and hence direct control on the size of the model
MLP_irreps
avg_num_neighbors
@stenczelt (Contributor Author):

These docstrings are duplicated across the classes exposed to the user as well. I think this is fine: whichever class you are using, you see the docstring directly. If you disagree, feel free to remove the redundant ones and perhaps keep only the one on the main class.

atomic_numbers: List[int],
correlation: int,
gate: Optional[Callable],
radial_MLP: Optional[List[int]] = None,
@stenczelt (Contributor Author):

I have not changed this, for the sake of consistency with the existing code, but method parameters are conventionally lowercase, so I propose renaming these MLP parameters to lowercase.
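
Concretely, the proposal would amount to something like this (a sketch covering only the two names visible in this diff; the other arguments are omitted):

    # proposed PEP 8-style renaming of the parameters shown above:
    #   MLP_irreps -> mlp_irreps
    #   radial_MLP -> radial_mlp
    def __init__(self, mlp_irreps, radial_mlp=None):  # sketched signature only
        ...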



@compile_mode("script")
class AtomicDipolesMACE(MaceCoreModel, DipoleModelMixin):
@stenczelt (Contributor Author):

Why is this called "AtomicDipolesMACE"? Should it not just be DipolesMACE?


def forward(
    self,
    data: Dict[str, torch.Tensor],
    training: bool = False,
-   compute_force: bool = True,
+   compute_force: bool = False,
@stenczelt (Contributor Author):

If this class cannot calculate forces, why don't we remove the associated parameters? Are they intended for compatibility with wrappers?
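
One possible reason for keeping them (an assumption, not confirmed in this thread): a wrapper can then drive every model class through one uniform call, e.g.:

    import torch
    from typing import Dict

    def evaluate(model: torch.nn.Module, data: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        # hypothetical wrapper: identical keyword arguments for every model class;
        # a dipole-only model would accept compute_force but ignore it
        return model(data, training=False, compute_force=False)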

self,
data: Dict[str, torch.Tensor],
training: bool = False,
compute_force: bool = True,
@stenczelt (Contributor Author):

Same here: these parameters are traps only. Understandable if it is for compatibility, and perhaps force computation will be implemented later on.

return output

@compile_mode("script")
class ScaleShiftEnergyDipoleMACE(EnergyDipolesMACE, ScaleShiftEnergyModelMixin):
@stenczelt (Contributor Author):

This is just an extra, which was easy to add thanks to the separated logic of dipoles & energy.

assert key in output


def test_scaled_and_shifted(dipole_model_config, data_batch_1):
@stenczelt (Contributor Author):

While the tests were enormously useful during the refactoring, there is more to test on all of these classes. That should be somewhat easier now.
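
For instance, a follow-up test along these lines could be added (a hedged sketch: the fixture names come from the diff above and ScaleShiftEnergyDipoleMACE from this PR, while the test name, the config unpacking, and the output keys are assumptions):

    import torch

    def test_output_keys_are_finite(dipole_model_config, data_batch_1):
        # hypothetical follow-up test: build the model from the shared config
        # fixture, run one forward pass, and sanity-check the outputs
        model = ScaleShiftEnergyDipoleMACE(**dipole_model_config)
        output = model(data_batch_1)
        for key in ("energy", "dipole"):  # illustrative output keys
            assert key in output
            assert torch.isfinite(output[key]).all()  # values should be finite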

@stenczelt (Contributor Author) commented:

@davkovacs @ilyes319 I have added comments explaining what I have done; you may be able to answer them or make the appropriate judgement calls. Please let me know if you need clarification on anything, or if there are changes you would like me to make.
