make sure GPU works #130

Merged · 3 commits · Jun 19, 2024
2 changes: 1 addition & 1 deletion deps/xla
Submodule xla updated 52 files
+9 −7 .github/workflows/_build_plugin.yml
+11 −10 .github/workflows/_build_torch_with_cuda.yml
+10 −19 .github/workflows/_build_torch_xla.yml
+1 −1 .github/workflows/_docs.yml
+0 −25 .github/workflows/_get_torch_commit.yml
+13 −33 .github/workflows/_test.yml
+12 −30 .github/workflows/_test_requiring_torch_cuda.yml
+8 −10 .github/workflows/_tpu_ci.yml
+8 −2 .github/workflows/build_and_test.yml
+86 −0 .github/workflows/setup/action.yml
+16 −2 README.md
+6 −2 benchmarks/benchmark_model.py
+19 −6 benchmarks/torchbench_model.py
+37 −79 docs/fori_loop.md
+82 −0 docs/plugins.md
+1 −1 docs/source/index.rst
+0 −554 docs/spmd.md
+150 −0 docs/spmd_advanced.md
+83 −0 docs/spmd_basic.md
+125 −0 docs/spmd_distributed_checkpoint.md
+3 −2 examples/decoder_only_model.py
+157 −0 experimental/torch_xla2/docs/support_a_new_model.md
+7 −0 experimental/torch_xla2/examples/eager_mode.py
+49 −0 experimental/torch_xla2/examples/torchbench_models/BERT_pytorch.py
+0 −9 experimental/torch_xla2/test/test_ops.py
+19 −8 experimental/torch_xla2/torch_xla2/__init__.py
+1 −0 experimental/torch_xla2/torch_xla2/config.py
+49 −0 experimental/torch_xla2/torch_xla2/ops/jaten.py
+2 −2 experimental/torch_xla2/torch_xla2/ops/jtorch.py
+25 −11 experimental/torch_xla2/torch_xla2/tensor.py
+1 −0 infra/ansible/config/apt.yaml
+1 −1 infra/ansible/config/env.yaml
+3 −0 plugins/cpu/README.md
+2 −2 plugins/cpu/pyproject.toml
+3 −0 plugins/cuda/README.md
+9 −10 test/debug_tool/test_pt_xla_debug.py
+35 −5 test/dynamo/test_dynamo.py
+3 −2 test/run_tests.sh
+1 −48 test/spmd/test_dynamo_spmd.py
+2 −2 test/spmd/test_sharding_strategies.py
+6 −4 test/spmd/test_xla_sharding.py
+0 −106 test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py
+16 −17 test/test_metrics.py
+116 −0 test/test_while_loop.py
+43 −0 test/torch_distributed/test_torch_distributed_bucketed_all_reduce_xla_backend.py
+1 −1 test/tpu/run_tests.sh
+2 −0 torch_xla/_internal/tpu.py
+56 −7 torch_xla/core/xla_model.py
+20 −15 torch_xla/csrc/init_python_bindings.cpp
+3 −0 torch_xla/csrc/runtime/pjrt_registry.cc
+7 −14 torch_xla/distributed/spmd/xla_sharding.py
+112 −45 torch_xla/experimental/fori_loop.py
3 changes: 2 additions & 1 deletion install_everything.sh
@@ -38,4 +38,5 @@ git submodule update --init --recursive
 pip show google-jetstream && pip uninstall -y google-jetstream
 pip show torch_xla2 && pip uninstall -y torch_xla2
 pip install -e .
-pip install -U jax[tpu]==0.4.29 -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
+pip install -U jax[tpu]==0.4.30 -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
+pip install -U torch==2.3.1+cpu --index-url https://download.pytorch.org/whl/cpu
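
The pinned jax[tpu] wheel moves from 0.4.29 to 0.4.30, and an explicit torch==2.3.1+cpu install is added after `pip install -e .`, presumably so the CPU wheel wins over whatever the editable install pulls in. A quick post-install sanity check might look like this (hypothetical snippet, not part of the diff):

```python
# Post-install sanity check (illustrative): confirm the pinned versions
# resolved and that jax can enumerate its accelerator backend.
import jax
import torch

print("jax:", jax.__version__)      # expect 0.4.30
print("torch:", torch.__version__)  # expect 2.3.1+cpu
print("devices:", jax.devices())    # TpuDevice entries on a TPU VM
```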
6 changes: 3 additions & 3 deletions install_everything_gpu.sh
@@ -24,18 +24,18 @@ pip show tensorboard && pip uninstall -y tensorboard
 pip show tensorflow-text && pip uninstall -y tensorflow-text
 pip show torch_xla2 && pip uninstall -y torch_xla2

-pip install flax==0.8.3
-pip install -U "jax[cuda12]==0.4.28"
+pip install flax==0.8.4
 pip install tensorflow-text
 pip install tensorflow

 pip install ray[default]==2.22.0
 # torch cpu
-pip install torch==2.2.1+cpu --index-url https://download.pytorch.org/whl/cpu
 pip install tensorflow flatbuffers absl-py sentencepiece seqio google-cloud-storage
 pip install safetensors colorama coverage humanize

 git submodule update --init --recursive
 pip show google-jetstream && pip uninstall -y google-jetstream
 pip show torch_xla2 && pip uninstall -y torch_xla2
 pip install -e .
+pip install -U jax[cuda12]==0.4.30
+pip install -U torch==2.3.1+cpu --index-url https://download.pytorch.org/whl/cpu
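
Matching the PR title, one way to confirm the GPU path works after running this script is to check that jax enumerates a CUDA device (a minimal sketch, assuming a CUDA machine; not part of the diff):

```python
# GPU sanity check (illustrative): jax should see CUDA devices, while torch
# stays on the CPU wheel pinned above.
import jax

devices = jax.devices()
print(devices)  # expect [CudaDevice(id=0), ...]
assert any(d.platform == "gpu" for d in devices), "jax does not see a GPU"
```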
3 changes: 2 additions & 1 deletion run_server.py
@@ -16,8 +16,9 @@
 import os
 from typing import Sequence

+# import torch_xla2 first!
+import torch_xla2  # pylint: disable
 import jax
 import jetstream_pt
 from absl import app, flags
 from jetstream.core import server_lib
 from jetstream.core.config_lib import ServerConfig, MetricsServerConfig
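
The new comment requires torch_xla2 to be imported before jax, presumably so torch_xla2 can apply its global configuration before jax is first used; the diff does not state the reason. A hypothetical guard that would make the ordering explicit:

```python
# Hypothetical import-order guard (not in the PR): fail fast if jax was
# imported before torch_xla2 had a chance to configure it.
import sys

if "jax" in sys.modules and "torch_xla2" not in sys.modules:
  raise ImportError("import torch_xla2 before jax (see run_server.py)")
```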
24 changes: 24 additions & 0 deletions tests/test_model_impl.py
@@ -23,6 +23,8 @@
 from jetstream_pt.third_party.llama import model_original
 from jetstream_pt.third_party.gemma import model_original as gemma_orig
 from jetstream_pt.third_party.gemma import model as gemma
+from jetstream_pt.third_party.mixtral import model as mixtral
+from jetstream_pt.third_party.mixtral import config as mixtral_config
 from jetstream_pt import torchjax
 from jetstream_pt import layers
 from jetstream_pt import cache_manager
@@ -360,6 +362,28 @@ def test_transformer(self):
     print("Transformer: Diff norm", (result_torch - expected_out).norm())
     self.assertTrue(torch.allclose(result_torch, expected_out, atol=1e-4))

+  def test_mixtral_moe(self):
+    config = mixtral_config.ModelArgs()
+    config.intermediate_size = 16
+    config.dim = 16
+    m = mixtral.ConditionalFeedForward(config)
+    # random init
+    states = m.state_dict()
+    for k, v in states.items():
+      states[k].normal_()
+    m.load_state_dict(states, assign=True)
+
+    seqlen = 3
+    num_expert = 8
+    num_active_expert = 2
+    x = torch.randn(seqlen, config.dim)
+    exp_index = torch.randint(0, num_expert, (seqlen, num_active_expert))
+
+    res1 = m.forward_for_short_seq_len(x, exp_index)
+    res2 = m.forward_for_long_seq_len(x, exp_index)
+
+    torch.testing.assert_close(res1, res2)
+

 if __name__ == "__main__":
   unittest.main()
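
The new test randomizes the layer's weights, samples random expert indices, and asserts that the short-sequence and long-sequence forward paths of `ConditionalFeedForward` agree. Here is the same randomized-equivalence pattern on a toy MoE pair, for illustration (hypothetical code, not the jetstream_pt implementation; `moe_loop` and `moe_gather` are made-up names):

```python
# Two ways to apply per-token expert weights that must agree numerically:
# an explicit loop and a batched gather + einsum.
import torch

def moe_loop(x, w, idx):
  # x: (seq, dim), w: (experts, dim, dim), idx: (seq, k)
  out = torch.zeros_like(x)
  for t in range(x.shape[0]):
    for e in idx[t]:
      out[t] += x[t] @ w[e]
  return out

def moe_gather(x, w, idx):
  w_sel = w[idx]                       # (seq, k, dim, dim)
  y = torch.einsum("sd,skdo->sko", x, w_sel)
  return y.sum(dim=1)                  # (seq, dim)

torch.manual_seed(0)
x = torch.randn(3, 16)
w = torch.randn(8, 16, 16)
idx = torch.randint(0, 8, (3, 2))
torch.testing.assert_close(moe_loop(x, w, idx), moe_gather(x, w, idx))
```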