
Rewrite initialization #607

Merged: 26 commits from rewrite-init into main on Jun 10, 2024
Conversation

@AkshitaB (Contributor) commented Jun 10, 2024

Simplifies our inscrutable initialization

  • IMPORTANT: currently, the implementation matches the old buggy values for init in several places. See below.
  • Removes init_weights with its complex if-else logic.
  • Adds init_normal which only takes the module, the std, and optionally a cutoff_factor.
  • std and cutoff_factor computation is now handled in each module's reset_parameters() (a minimal sketch follows this list).
  • Adds unit tests for initialization.
  • Removes implementation for kaiming_normal and fan_in InitFnType as these aren't being used anywhere. Can be added later if needed.
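
For reference, here is a minimal sketch of what the new helper and a module-level reset_parameters() might look like, based only on the description above. The names init_normal, init_cutoff_factor, and reset_parameters come from this PR; the exact signature, the FeedForwardSketch class, and the example std formula are illustrative assumptions, not the actual implementation.

```python
import math
from typing import Optional

import torch.nn as nn


def init_normal(module: nn.Module, std: float, init_cutoff_factor: Optional[float] = None) -> None:
    """Initialize module.weight from N(0, std^2), truncated at cutoff_factor * std if a cutoff is given."""
    if init_cutoff_factor is not None:
        cutoff = init_cutoff_factor * std
        nn.init.trunc_normal_(module.weight, mean=0.0, std=std, a=-cutoff, b=cutoff)
    else:
        nn.init.normal_(module.weight, mean=0.0, std=std)
    if getattr(module, "bias", None) is not None:
        nn.init.zeros_(module.bias)


class FeedForwardSketch(nn.Module):
    """Illustrative module (not from the PR): each module computes its own std/cutoff in reset_parameters()."""

    def __init__(self, d_model: int, init_cutoff_factor: Optional[float] = None):
        super().__init__()
        self.ff_out = nn.Linear(d_model, d_model)
        self.d_model = d_model
        self.init_cutoff_factor = init_cutoff_factor
        self.reset_parameters()

    def reset_parameters(self) -> None:
        # The module itself, not a central if-else, decides its std (example formula only).
        std = 1.0 / math.sqrt(self.d_model)
        init_normal(self.ff_out, std, self.init_cutoff_factor)
```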

Potential bugs found in initialization as a result of the refactoring (these will be fixed after feedback):

  • OLMoBlock.ff_out's normal initialization multiplies std by an extra factor of 1 / math.sqrt(2 * self.config.n_layers). This potentially came from trying to incorporate full_megatron into the same function.
  • Hardcoded values: mitchell hardcodes a cutoff_factor of 3.0 (always truncated_normal_ with 3.0). full_megatron hardcodes a default cutoff_factor of 3.0 (truncated_normal_ with config.init_cutoff_factor or 3.0). Again, this may be a result of trying to incorporate multiple inits into the same function. Ideally, the cutoff_factor should always come from the configurable config.init_cutoff_factor; do we want to always set this value to 3.0 for mitchell and megatron?
  • Need clarification: Why do we scale the embedding with the following factor if scale_logits=True?
    emb_std_factor = (0.5 * math.sqrt(self.config.d_model)) if self.config.scale_logits else 1.0
  • Additionally, in the case of mitchell init, because the factor is supplied in multiple places in the old code, the embedding std always ends up being 0.5 when scale_logits=True! (A worked example follows this list.)
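
To make the last point concrete, a small worked example. It assumes the old mitchell-style embedding std of 1 / math.sqrt(d_model); the emb_std_factor expression is quoted from the PR description, and the concrete d_model value is only illustrative.

```python
import math

d_model = 4096                             # illustrative; any value gives the same result
base_std = 1.0 / math.sqrt(d_model)        # assumed mitchell-style embedding std
emb_std_factor = 0.5 * math.sqrt(d_model)  # factor applied when scale_logits=True (from the old code)
print(base_std * emb_std_factor)           # 0.5 -- the sqrt(d_model) terms cancel
```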

@AkshitaB AkshitaB requested review from dirkgr and epwalsh June 10, 2024 06:56
@epwalsh (Member) left a comment
No major concerns. I'm glad we're cleaning this up.

> Why do we scale the embedding with the following factor if scale_logits=True?
> emb_std_factor = (0.5 * math.sqrt(self.config.d_model)) if self.config.scale_logits else 1.0

This was another "trick" we heard works from someone else (not sure who).

(Five inline review comments on olmo/model.py, all resolved.)
@AkshitaB (Contributor, Author) commented

> No major concerns. I'm glad we're cleaning this up.
>
> > Why do we scale the embedding with the following factor if scale_logits=True?
> > emb_std_factor = (0.5 * math.sqrt(self.config.d_model)) if self.config.scale_logits else 1.0
>
> This was another "trick" we heard works from someone else (not sure who).

Wouldn't this make more sense if we did this when weight_tying was on? I'm trying to get a sense of intuition for some of these choices/tricks.

@epwalsh (Member) commented Jun 10, 2024


> Wouldn't this make more sense if we did this when weight_tying was on? I'm trying to get a sense of intuition for some of these choices/tricks.

Yea I'm guessing that's the only scenario where we tried it? It might have come from PaLM.

@epwalsh (Member) left a comment

LGTM

(One further inline review comment on olmo/model.py, resolved.)
@AkshitaB merged commit c2cedbc into main on Jun 10, 2024 — 12 checks passed.
@AkshitaB deleted the rewrite-init branch on June 10, 2024 at 23:52.