fix logic and cite partial rotary embeddings from Wang & Komatsuzaki et al
lucidrains committed Apr 14, 2022
1 parent 4f99e31 commit f2d2815
Showing 2 changed files with 9 additions and 3 deletions.
10 changes: 8 additions & 2 deletions retro_pytorch/retro_pytorch.py
@@ -322,7 +322,10 @@ def __init__(
         super().__init__()
         self.layers = nn.ModuleList([])

-        rotary_emb_dim = max(dim_head // 2, MIN_DIM_HEAD)
+        # partial rotary embeddings, which is better than full rotary
+        # Wang and Komatsuzaki et al https://github.com/kingoflolz/mesh-transformer-jax/
+
+        rotary_emb_dim = min(dim_head, MIN_DIM_HEAD)
         self.rotary_pos_emb = RotaryEmbedding(rotary_emb_dim)

         wrapper = partial(PreNorm, dim, norm_klass = norm_klass) if not post_norm else partial(PostNorm, dim, scale_residual = scale_residual, norm_klass = norm_klass)
@@ -377,7 +380,10 @@ def __init__(
         super().__init__()
         self.layers = nn.ModuleList([])

-        rotary_emb_dim = max(dim_head // 2, MIN_DIM_HEAD)
+        # partial rotary embeddings, which is better than full rotary
+        # Wang and Komatsuzaki et al https://github.com/kingoflolz/mesh-transformer-jax/
+
+        rotary_emb_dim = min(dim_head, MIN_DIM_HEAD)
         self.rotary_pos_emb = RotaryEmbedding(rotary_emb_dim)

         wrapper = partial(PreNorm, dim, norm_klass = norm_klass) if not post_norm else partial(PostNorm, dim, scale_residual = scale_residual, norm_klass = norm_klass)
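For context: partial rotary embeddings rotate only the first rotary_emb_dim channels of each attention head and pass the remaining channels through unchanged, following Wang & Komatsuzaki's GPT-J. With MIN_DIM_HEAD = 32 and, say, dim_head = 64, half of each head is rotated rather than all of it. A minimal sketch of the idea in PyTorch (the helpers rotary_frequencies, rotate_half, and apply_partial_rotary below are illustrative assumptions, not this repository's API):

import torch

def rotary_frequencies(seq_len, rotary_dim):
    # standard rotary frequencies, duplicated so they cover all rotary_dim channels
    inv_freq = 1. / (10000 ** (torch.arange(0, rotary_dim, 2).float() / rotary_dim))
    t = torch.arange(seq_len).float()
    freqs = torch.einsum('i,j->ij', t, inv_freq)   # (seq, rotary_dim // 2)
    return torch.cat((freqs, freqs), dim = -1)     # (seq, rotary_dim)

def rotate_half(x):
    # split the last dimension into two halves and rotate: (x1, x2) -> (-x2, x1)
    x1, x2 = x.chunk(2, dim = -1)
    return torch.cat((-x2, x1), dim = -1)

def apply_partial_rotary(x, freqs, rotary_dim):
    # x: (batch, heads, seq, dim_head) - only the first rotary_dim channels are rotated
    x_rot, x_pass = x[..., :rotary_dim], x[..., rotary_dim:]
    x_rot = x_rot * freqs.cos() + rotate_half(x_rot) * freqs.sin()
    return torch.cat((x_rot, x_pass), dim = -1)

q = torch.randn(1, 8, 1024, 64)          # (batch, heads, seq, dim_head)
freqs = rotary_frequencies(1024, 32)     # rotary_emb_dim = min(64, 32) = 32
q = apply_partial_rotary(q, freqs, rotary_dim = 32)

As for the "fix logic" part of the message: the old expression max(dim_head // 2, MIN_DIM_HEAD) could exceed the head dimension for small heads (e.g. dim_head = 16 yields 32), whereas min(dim_head, MIN_DIM_HEAD) caps the rotated width at the head dimension itself.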
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'retro-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.3.0',
+  version = '0.3.1',
   license='MIT',
   description = 'RETRO - Retrieval Enhanced Transformer - Pytorch',
   author = 'Phil Wang',
