Add a test for Moe layer, modify install script to make GPU run
qihqi committed Jun 18, 2024
1 parent c3293c4 commit b0f95f3
Showing 3 changed files with 7 additions and 6 deletions.
install_everything.sh: 1 change (1 addition, 0 deletions)

@@ -39,3 +39,4 @@ pip show google-jetstream && pip uninstall -y google-jetstream
 pip show torch_xla2 && pip uninstall -y torch_xla2
 pip install -e .
 pip install -U jax[tpu]==0.4.30 -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
+pip install -U torch==2.3.1+cpu --index-url https://download.pytorch.org/whl/cpu
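The substantive change: a CPU-only torch wheel is now pinned as the final install step, after JAX, so no later resolution step can swap in a different torch build. A post-install sanity check might look like the sketch below (hypothetical; the expected versions are simply the ones this script pins):

# Hypothetical post-install check for the TPU script's pins.
import jax
import torch

print("torch:", torch.__version__)  # expect "2.3.1+cpu"
print("jax:", jax.__version__)      # expect "0.4.30"
assert torch.__version__.startswith("2.3.1")
assert jax.__version__ == "0.4.30"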
install_everything_gpu.sh: 2 changes (1 addition, 1 deletion)

@@ -30,7 +30,6 @@ pip install tensorflow

 pip install ray[default]==2.22.0
 # torch cpu
-pip install torch==2.2.1+cpu --index-url https://download.pytorch.org/whl/cpu
 pip install tensorflow flatbuffers absl-py sentencepiece seqio google-cloud-storage
 pip install safetensors colorama coverage humanize

@@ -39,3 +38,4 @@ pip show google-jetstream && pip uninstall -y google-jetstream
 pip show torch_xla2 && pip uninstall -y torch_xla2
 pip install -e .
 pip install -U jax[cuda12]==0.4.30
+pip install -U torch==2.3.1+cpu --index-url https://download.pytorch.org/whl/cpu
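The GPU script gets the same treatment, plus one removal: the earlier torch==2.2.1+cpu install is dropped and torch==2.3.1+cpu is installed after jax[cuda12], presumably so the CPU-only torch wheel cannot interfere with JAX's CUDA runtime. A hedged check on a CUDA host:

# Hypothetical check: JAX should enumerate the GPU, while the +cpu torch
# wheel ships without CUDA support, so torch.cuda reports unavailable.
import jax
import torch

print(jax.devices())                # expect CUDA device entries
print("torch:", torch.__version__)  # expect "2.3.1+cpu"
assert not torch.cuda.is_available()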
tests/test_model_impl.py: 10 changes (5 additions, 5 deletions)

@@ -23,7 +23,7 @@
 from jetstream_pt.third_party.llama import model_original
 from jetstream_pt.third_party.gemma import model_original as gemma_orig
 from jetstream_pt.third_party.gemma import model as gemma
-from jetstream_pt.third_party.mixtral import model as mixtral
+from jetstream_pt.third_party.mixtral import model as mixtral
 from jetstream_pt.third_party.mixtral import config as mixtral_config
 from jetstream_pt import torchjax
 from jetstream_pt import layers

@@ -370,19 +370,19 @@ def test_mixtral_moe(self):
     # random init
     states = m.state_dict()
     for k, v in states.items():
-      states[k].normal_()
+      states[k].normal_()
     m.load_state_dict(states, assign=True)

     seqlen = 3
-    num_expert = 8
+    num_expert = 8
     num_active_expert = 2
-    x = torch.randn(10, config.dim)
+    x = torch.randn(seqlen, config.dim)
     exp_index = torch.randint(0, num_expert, (seqlen, num_active_expert))

     res1 = m.forward_for_short_seq_len(x, exp_index)
     res2 = m.forward_for_long_seq_len(x, exp_index)

-    torch.testing.assert_close(res1, res2, atol=1e-4, rtol=1e-4)
+    torch.testing.assert_close(res1, res2)


 if __name__ == "__main__":
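For context, the test drives two implementations of the Mixtral MoE layer (a path specialized for short sequences and one for long sequences, per the function names) with identical random weights and expert assignments, then asserts their outputs match. A naive reference for the kind of computation being cross-checked is sketched below; moe_reference and its experts list are hypothetical illustrations rather than the repo's API, and whether the real layer returns per-expert outputs or their weighted sum is not visible in this diff (the sketch returns per-expert outputs):

import torch

def moe_reference(x, exp_index, experts):
  """Naive MoE evaluation: run each token through each selected expert.

  x:         (seqlen, dim) input activations
  exp_index: (seqlen, num_active_expert) integer expert ids
  experts:   list of callables mapping (dim,) -> (dim,)
  Returns a (seqlen, num_active_expert, dim) tensor of per-expert outputs.
  """
  rows = []
  for t in range(x.shape[0]):
    rows.append(torch.stack([experts[int(e)](x[t]) for e in exp_index[t]]))
  return torch.stack(rows)

# Example wiring mirroring the test's shapes (dim chosen arbitrarily):
experts = [torch.nn.Linear(16, 16) for _ in range(8)]  # num_expert = 8
x = torch.randn(3, 16)                                 # seqlen = 3
exp_index = torch.randint(0, 8, (3, 2))                # num_active_expert = 2
ref = moe_reference(x, exp_index, experts)             # shape (3, 2, 16)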
