From dbebc7c02432932a658cf4feb83080bec9585059 Mon Sep 17 00:00:00 2001 From: Setepenre Date: Thu, 19 Sep 2024 13:14:18 -0400 Subject: [PATCH] Batch resizing (#286) * Add utility to help manage batch resizer * Add scaling config * cleanrljax pins * Update scaling files * Update batch sizer model * Latest batch size model * Update pins --------- Co-authored-by: pierre.delaunay --- .pin/constraints-cuda-torch.txt | 64 ++--- benchmarks/brax/main.py | 3 + benchmarks/brax/requirements.cuda.txt | 20 +- benchmarks/brax/voirfile.py | 4 +- benchmarks/diffusion/requirements.cuda.txt | 16 +- benchmarks/dinov2/requirements.cuda.txt | 14 +- benchmarks/flops/requirements.cuda.txt | 12 +- benchmarks/geo_gnn/dev.yaml | 2 +- benchmarks/geo_gnn/main.py | 17 +- benchmarks/geo_gnn/requirements-pre.cuda.txt | 12 +- benchmarks/geo_gnn/requirements.cuda.txt | 12 +- benchmarks/huggingface/requirements.cuda.txt | 14 +- benchmarks/lightning/main.py | 8 +- benchmarks/lightning/requirements.cuda.txt | 12 +- benchmarks/llama/requirements.cuda.txt | 14 +- benchmarks/llava/requirements.cuda.txt | 14 +- benchmarks/llm/requirements.cuda.txt | 20 +- benchmarks/llm/requirements.in | 1 + benchmarks/purejaxrl/dqn.py | 46 +++- benchmarks/purejaxrl/ppo.py | 4 + benchmarks/purejaxrl/requirements.cuda.txt | 26 +-- benchmarks/purejaxrl/voirfile.py | 2 +- benchmarks/recursiongfn/requirements.cuda.txt | 36 +-- benchmarks/rlhf/requirements.cuda.txt | 18 +- benchmarks/timm/requirements.cuda.txt | 14 +- benchmarks/torchatari/requirements.cuda.txt | 16 +- benchmarks/torchvision/requirements.cuda.txt | 12 +- benchmarks/torchvision/voirfile.py | 2 +- .../torchvision_ddp/requirements.cuda.txt | 12 +- benchmarks/vjepa/main.py | 2 +- benchmarks/vjepa/requirements.cuda.txt | 16 +- benchmate/benchmate/jaxmem.py | 30 +++ benchmate/benchmate/monitor.py | 7 +- config/base.yaml | 55 ++++- config/scaling.yaml | 221 +++++++++++++++--- config/standard.yaml | 10 +- milabench/cli/gather.py | 1 + milabench/cli/list.py | 56 +++++ milabench/sizer.py | 19 +- scripts/article/run_cuda.sh | 21 +- 40 files changed, 626 insertions(+), 259 deletions(-) create mode 100644 benchmate/benchmate/jaxmem.py create mode 100644 milabench/cli/list.py diff --git a/.pin/constraints-cuda-torch.txt b/.pin/constraints-cuda-torch.txt index 2717ed4ef..96876ac75 100644 --- a/.pin/constraints-cuda-torch.txt +++ b/.pin/constraints-cuda-torch.txt @@ -70,7 +70,7 @@ blobfile==3.0.0 # torchtune blosc2==2.7.1 # via tables -botorch==0.11.3 +botorch==0.12.0 # via -r benchmarks/recursiongfn/requirements.in braceexpand==0.1.7 # via @@ -129,7 +129,7 @@ decorator==5.1.1 # via tensorflow-probability decord==0.6.0 # via -r benchmarks/vjepa/requirements.in -diffusers[torch]==0.30.2 +diffusers[torch]==0.30.3 # via -r benchmarks/diffusion/requirements.in dill==0.3.8 # via @@ -179,7 +179,7 @@ fairscale==0.4.13 # -r benchmarks/llm/requirements.txt farama-notifications==0.0.4 # via gymnasium -filelock==3.16.0 +filelock==3.16.1 # via # blobfile # datasets @@ -241,7 +241,7 @@ giving==0.4.3 # voir glfw==2.7.0 # via mujoco -gpytorch==1.12 +gpytorch==1.13 # via # -r benchmarks/recursiongfn/requirements.in # botorch @@ -267,7 +267,7 @@ gymnax==0.0.8 # -r benchmarks/purejaxrl/requirements.in hjson==3.1.0 # via argklass -huggingface-hub==0.24.7 +huggingface-hub==0.25.0 # via # -r benchmarks/timm/requirements.in # accelerate @@ -301,7 +301,7 @@ isort==5.13.2 # via pylint itsdangerous==2.2.0 # via flask -jax[cuda12]==0.4.31 +jax[cuda12]==0.4.33 # via # -r benchmarks/brax/requirements.in # -r 
benchmarks/purejaxrl/requirements.in @@ -318,11 +318,11 @@ jax[cuda12]==0.4.31 # optax # orbax-checkpoint # rlax -jax-cuda12-pjrt==0.4.31 +jax-cuda12-pjrt==0.4.33 # via jax-cuda12-plugin -jax-cuda12-plugin[with-cuda]==0.4.31 +jax-cuda12-plugin[with-cuda]==0.4.33 # via jax -jaxlib==0.4.31 +jaxlib==0.4.33 # via # brax # chex @@ -338,8 +338,10 @@ jaxlib==0.4.31 # rlax jaxopt==0.8.3 # via brax -jaxtyping==0.2.34 - # via linear-operator +jaxtyping==0.2.19 + # via + # gpytorch + # linear-operator jinja2==3.1.4 # via # brax @@ -357,7 +359,7 @@ lightning-utilities==0.11.7 # lightning # pytorch-lightning # torchmetrics -linear-operator==0.5.2 +linear-operator==0.5.3 # via # botorch # gpytorch @@ -393,17 +395,18 @@ mpmath==1.3.0 # via # botorch # gpytorch + # linear-operator # sympy msgpack==1.1.0 # via # blosc2 # flax # orbax-checkpoint -mujoco==3.2.2 +mujoco==3.2.3 # via # brax # mujoco-mjx -mujoco-mjx==3.2.2 +mujoco-mjx==3.2.3 # via brax multidict==6.1.0 # via @@ -438,7 +441,6 @@ numpy==1.26.4 # -r benchmarks/vjepa/requirements.in # accelerate # blosc2 - # botorch # brax # chex # contourpy @@ -457,6 +459,7 @@ numpy==1.26.4 # jax # jaxlib # jaxopt + # jaxtyping # matplotlib # ml-dtypes # mujoco @@ -557,7 +560,7 @@ optax==0.2.3 # flax optree==0.12.1 # via envpool -orbax-checkpoint==0.6.3 +orbax-checkpoint==0.6.4 # via # brax # flax @@ -601,7 +604,7 @@ pillow==10.4.0 # navix # rdkit # torchvision -platformdirs==4.3.3 +platformdirs==4.3.6 # via # black # pylint @@ -610,7 +613,7 @@ pluggy==1.5.0 # via pytest portalocker==2.10.1 # via iopath -protobuf==5.28.1 +protobuf==5.28.2 # via # orbax-checkpoint # tensorboard @@ -765,11 +768,11 @@ six==1.16.0 # tensorflow-probability smmap==5.0.1 # via gitdb -submitit==1.5.1 +submitit==1.5.2 # via # -r benchmarks/dinov2/requirements.in # -r benchmarks/vjepa/requirements.in -sympy==1.13.2 +sympy==1.13.3 # via torch tables==3.10.1 # via -r benchmarks/recursiongfn/requirements.in @@ -861,8 +864,8 @@ torch-sparse==0.6.18+pt24cu121 # via # -r benchmarks/geo_gnn/requirements.in # -r benchmarks/recursiongfn/requirements.in -torchao==0.3.1+cu121 - # via torchtune +torchao==0.5.0+cu121 + # via -r benchmarks/llm/requirements.in torchcompat==1.1.4 # via # -c .pin/../constraints/cuda.txt @@ -876,7 +879,7 @@ torchmetrics==1.4.2 # -r benchmarks/dinov2/requirements.in # lightning # pytorch-lightning -torchtune==0.2.1+cu121 +torchtune==0.3.0+cu121 # via -r benchmarks/llm/requirements.in torchvision==0.19.0+cu121 # via @@ -920,18 +923,17 @@ trimesh==4.4.9 # mujoco-mjx triton==3.0.0 # via torch -trl==0.10.1 +trl==0.11.0 # via -r benchmarks/rlhf/requirements.in -typeguard==2.13.3 - # via - # jaxtyping - # linear-operator +typeguard==4.3.0 + # via jaxtyping types-protobuf==5.27.0.20240907 # via envpool typing-extensions==4.12.2 # via # astroid # black + # botorch # brax # chex # envpool @@ -941,6 +943,7 @@ typing-extensions==4.12.2 # gymnasium # huggingface-hub # iopath + # jaxtyping # lightning # lightning-utilities # multidict @@ -952,8 +955,9 @@ typing-extensions==4.12.2 # submitit # tables # torch + # typeguard # tyro -tyro==0.8.10 +tyro==0.8.11 # via # -r benchmarks/torchatari/requirements.in # navix @@ -988,7 +992,7 @@ voir==0.2.19 # -r benchmarks/torchvision/requirements.in # -r benchmarks/torchvision_ddp/requirements.in # -r benchmarks/vjepa/requirements.in -wandb==0.18.0 +wandb==0.18.1 # via # -r benchmarks/recursiongfn/requirements.in # navix diff --git a/benchmarks/brax/main.py b/benchmarks/brax/main.py index 572ce739c..6625bcd04 100644 --- a/benchmarks/brax/main.py +++ 
b/benchmarks/brax/main.py @@ -85,6 +85,9 @@ def run(): args = parser.parse_args() + # args.num_envs = (args.batch_size * args.num_minibatches) + + train( environment=envs.get_environment(env_name=args.env), num_timesteps=args.num_timesteps, diff --git a/benchmarks/brax/requirements.cuda.txt b/benchmarks/brax/requirements.cuda.txt index aa883171c..89ebe8840 100644 --- a/benchmarks/brax/requirements.cuda.txt +++ b/benchmarks/brax/requirements.cuda.txt @@ -77,7 +77,7 @@ executing==2.1.0 # via # -c .pin/../.pin/constraints-cuda-torch.txt # varname -filelock==3.16.0 +filelock==3.16.1 # via # -c .pin/../.pin/constraints-cuda-torch.txt # torch @@ -133,7 +133,7 @@ itsdangerous==2.2.0 # via # -c .pin/../.pin/constraints-cuda-torch.txt # flask -jax[cuda12]==0.4.31 +jax[cuda12]==0.4.33 # via # -c .pin/../.pin/constraints-cuda-torch.txt # -r .pin/../constraints/extra/torch.cuda.txt @@ -145,15 +145,15 @@ jax[cuda12]==0.4.31 # mujoco-mjx # optax # orbax-checkpoint -jax-cuda12-pjrt==0.4.31 +jax-cuda12-pjrt==0.4.33 # via # -c .pin/../.pin/constraints-cuda-torch.txt # jax-cuda12-plugin -jax-cuda12-plugin[with-cuda]==0.4.31 +jax-cuda12-plugin[with-cuda]==0.4.33 # via # -c .pin/../.pin/constraints-cuda-torch.txt # jax -jaxlib==0.4.31 +jaxlib==0.4.33 # via # -c .pin/../.pin/constraints-cuda-torch.txt # brax @@ -205,12 +205,12 @@ msgpack==1.1.0 # -c .pin/../.pin/constraints-cuda-torch.txt # flax # orbax-checkpoint -mujoco==3.2.2 +mujoco==3.2.3 # via # -c .pin/../.pin/constraints-cuda-torch.txt # brax # mujoco-mjx -mujoco-mjx==3.2.2 +mujoco-mjx==3.2.3 # via # -c .pin/../.pin/constraints-cuda-torch.txt # brax @@ -324,7 +324,7 @@ optax==0.2.3 # -c .pin/../.pin/constraints-cuda-torch.txt # brax # flax -orbax-checkpoint==0.6.3 +orbax-checkpoint==0.6.4 # via # -c .pin/../.pin/constraints-cuda-torch.txt # brax @@ -341,7 +341,7 @@ pillow==10.4.0 # via # -c .pin/../.pin/constraints-cuda-torch.txt # brax -protobuf==5.28.1 +protobuf==5.28.2 # via # -c .pin/../.pin/constraints-cuda-torch.txt # orbax-checkpoint @@ -395,7 +395,7 @@ six==1.16.0 # -c .pin/../.pin/constraints-cuda-torch.txt # asttokens # ml-collections -sympy==1.13.2 +sympy==1.13.3 # via # -c .pin/../.pin/constraints-cuda-torch.txt # torch diff --git a/benchmarks/brax/voirfile.py b/benchmarks/brax/voirfile.py index fce6f66d0..3397dcb31 100644 --- a/benchmarks/brax/voirfile.py +++ b/benchmarks/brax/voirfile.py @@ -20,10 +20,10 @@ class Config: skip: int = 5 # Number of rates to log before stopping - stop: int = 20 + stop: int = 60 # Number of seconds between each gpu poll - gpu_poll: int = 3 + gpu_poll: int = 1 @configurable diff --git a/benchmarks/diffusion/requirements.cuda.txt b/benchmarks/diffusion/requirements.cuda.txt index 6a062a7a0..34a92c65d 100644 --- a/benchmarks/diffusion/requirements.cuda.txt +++ b/benchmarks/diffusion/requirements.cuda.txt @@ -64,7 +64,7 @@ datasets==3.0.0 # via # -c .pin/../.pin/constraints-cuda-torch.txt # -r benchmarks/diffusion/requirements.in -diffusers[torch]==0.30.2 +diffusers[torch]==0.30.3 # via # -c .pin/../.pin/constraints-cuda-torch.txt # -r benchmarks/diffusion/requirements.in @@ -77,7 +77,7 @@ executing==2.1.0 # via # -c .pin/../.pin/constraints-cuda-torch.txt # varname -filelock==3.16.0 +filelock==3.16.1 # via # -c .pin/../.pin/constraints-cuda-torch.txt # datasets @@ -106,7 +106,7 @@ hjson==3.1.0 # via # -c .pin/../.pin/constraints-cuda-torch.txt # argklass -huggingface-hub==0.24.7 +huggingface-hub==0.25.0 # via # -c .pin/../.pin/constraints-cuda-torch.txt # accelerate @@ -127,19 +127,19 @@ 
importlib-resources==6.4.5 # via # -c .pin/../.pin/constraints-cuda-torch.txt # argklass -jax[cuda12]==0.4.31 +jax[cuda12]==0.4.33 # via # -c .pin/../.pin/constraints-cuda-torch.txt # -r .pin/../constraints/extra/torch.cuda.txt -jax-cuda12-pjrt==0.4.31 +jax-cuda12-pjrt==0.4.33 # via # -c .pin/../.pin/constraints-cuda-torch.txt # jax-cuda12-plugin -jax-cuda12-plugin[with-cuda]==0.4.31 +jax-cuda12-plugin[with-cuda]==0.4.33 # via # -c .pin/../.pin/constraints-cuda-torch.txt # jax -jaxlib==0.4.31 +jaxlib==0.4.33 # via # -c .pin/../.pin/constraints-cuda-torch.txt # jax @@ -363,7 +363,7 @@ six==1.16.0 # -c .pin/../.pin/constraints-cuda-torch.txt # asttokens # python-dateutil -sympy==1.13.2 +sympy==1.13.3 # via # -c .pin/../.pin/constraints-cuda-torch.txt # torch diff --git a/benchmarks/dinov2/requirements.cuda.txt b/benchmarks/dinov2/requirements.cuda.txt index aef36dbf3..9b3940ff2 100644 --- a/benchmarks/dinov2/requirements.cuda.txt +++ b/benchmarks/dinov2/requirements.cuda.txt @@ -30,7 +30,7 @@ executing==2.1.0 # via # -c .pin/../.pin/constraints-cuda-torch.txt # varname -filelock==3.16.0 +filelock==3.16.1 # via # -c .pin/../.pin/constraints-cuda-torch.txt # torch @@ -53,19 +53,19 @@ iopath==0.1.10 # -c .pin/../.pin/constraints-cuda-torch.txt # -r benchmarks/dinov2/requirements.in # fvcore -jax[cuda12]==0.4.31 +jax[cuda12]==0.4.33 # via # -c .pin/../.pin/constraints-cuda-torch.txt # -r .pin/../constraints/extra/torch.cuda.txt -jax-cuda12-pjrt==0.4.31 +jax-cuda12-pjrt==0.4.33 # via # -c .pin/../.pin/constraints-cuda-torch.txt # jax-cuda12-plugin -jax-cuda12-plugin[with-cuda]==0.4.31 +jax-cuda12-plugin[with-cuda]==0.4.33 # via # -c .pin/../.pin/constraints-cuda-torch.txt # jax -jaxlib==0.4.31 +jaxlib==0.4.33 # via # -c .pin/../.pin/constraints-cuda-torch.txt # jax @@ -246,11 +246,11 @@ six==1.16.0 # via # -c .pin/../.pin/constraints-cuda-torch.txt # asttokens -submitit==1.5.1 +submitit==1.5.2 # via # -c .pin/../.pin/constraints-cuda-torch.txt # -r benchmarks/dinov2/requirements.in -sympy==1.13.2 +sympy==1.13.3 # via # -c .pin/../.pin/constraints-cuda-torch.txt # torch diff --git a/benchmarks/flops/requirements.cuda.txt b/benchmarks/flops/requirements.cuda.txt index afb7ff130..e529152e3 100644 --- a/benchmarks/flops/requirements.cuda.txt +++ b/benchmarks/flops/requirements.cuda.txt @@ -26,7 +26,7 @@ executing==2.1.0 # via # -c .pin/../.pin/constraints-cuda-torch.txt # varname -filelock==3.16.0 +filelock==3.16.1 # via # -c .pin/../.pin/constraints-cuda-torch.txt # torch @@ -44,19 +44,19 @@ importlib-resources==6.4.5 # via # -c .pin/../.pin/constraints-cuda-torch.txt # torchcompat -jax[cuda12]==0.4.31 +jax[cuda12]==0.4.33 # via # -c .pin/../.pin/constraints-cuda-torch.txt # -r .pin/../constraints/extra/torch.cuda.txt -jax-cuda12-pjrt==0.4.31 +jax-cuda12-pjrt==0.4.33 # via # -c .pin/../.pin/constraints-cuda-torch.txt # jax-cuda12-plugin -jax-cuda12-plugin[with-cuda]==0.4.31 +jax-cuda12-plugin[with-cuda]==0.4.33 # via # -c .pin/../.pin/constraints-cuda-torch.txt # jax -jaxlib==0.4.31 +jaxlib==0.4.33 # via # -c .pin/../.pin/constraints-cuda-torch.txt # jax @@ -217,7 +217,7 @@ six==1.16.0 # via # -c .pin/../.pin/constraints-cuda-torch.txt # asttokens -sympy==1.13.2 +sympy==1.13.3 # via # -c .pin/../.pin/constraints-cuda-torch.txt # torch diff --git a/benchmarks/geo_gnn/dev.yaml b/benchmarks/geo_gnn/dev.yaml index 6f261c895..67cb5bd2d 100644 --- a/benchmarks/geo_gnn/dev.yaml +++ b/benchmarks/geo_gnn/dev.yaml @@ -19,6 +19,6 @@ dimenet: method: per_gpu argv: --model: 'DimeNet' - --num-samples: 10000 + 
--num-samples: 100000 --use3d: True --batch-size: 512 \ No newline at end of file diff --git a/benchmarks/geo_gnn/main.py b/benchmarks/geo_gnn/main.py index 71e1c8827..b8875d2bf 100644 --- a/benchmarks/geo_gnn/main.py +++ b/benchmarks/geo_gnn/main.py @@ -78,14 +78,20 @@ def train_degree(train_dataset): # Compute the maximum in-degree in the training data. max_degree = -1 for data in train_dataset: - d = degree(data.edge_index[1], num_nodes=data.num_nodes, dtype=torch.long) - max_degree = max(max_degree, int(d.max())) + try: + d = degree(data.edge_index[1], num_nodes=data.num_nodes, dtype=torch.long) + max_degree = max(max_degree, int(d.max())) + except TypeError: + pass # Compute the in-degree histogram tensor deg = torch.zeros(max_degree + 1, dtype=torch.long) for data in train_dataset: - d = degree(data.edge_index[1], num_nodes=data.num_nodes, dtype=torch.long) - deg += torch.bincount(d, minlength=deg.numel()) + try: + d = degree(data.edge_index[1], num_nodes=data.num_nodes, dtype=torch.long) + deg += torch.bincount(d, minlength=deg.numel()) + except TypeError: + pass return deg @@ -109,13 +115,14 @@ def batch_size(x): observer = BenchObserver(batch_size_fn=batch_size) train_dataset = PCQM4Mv2Subset(args.num_samples, args.root) + degree = train_degree(train_dataset) sample = next(iter(train_dataset)) info = models[args.model]( args, sample=sample, - degree=lambda: train_degree(train_dataset), + degree=lambda: degree, ) TRAIN_mean, TRAIN_std = ( diff --git a/benchmarks/geo_gnn/requirements-pre.cuda.txt b/benchmarks/geo_gnn/requirements-pre.cuda.txt index 0ec4d88dd..6c76b0c91 100644 --- a/benchmarks/geo_gnn/requirements-pre.cuda.txt +++ b/benchmarks/geo_gnn/requirements-pre.cuda.txt @@ -10,7 +10,7 @@ --find-links https://data.pyg.org/whl/torch-2.4.0+cu121.html --trusted-host pypi.ngc.nvidia.com -filelock==3.16.0 +filelock==3.16.1 # via # -c .pin/../.pin/constraints-cuda-torch.txt # torch @@ -19,19 +19,19 @@ fsspec==2024.6.1 # via # -c .pin/../.pin/constraints-cuda-torch.txt # torch -jax[cuda12]==0.4.31 +jax[cuda12]==0.4.33 # via # -c .pin/../.pin/constraints-cuda-torch.txt # -r .pin/../constraints/extra/torch.cuda.txt -jax-cuda12-pjrt==0.4.31 +jax-cuda12-pjrt==0.4.33 # via # -c .pin/../.pin/constraints-cuda-torch.txt # jax-cuda12-plugin -jax-cuda12-plugin[with-cuda]==0.4.31 +jax-cuda12-plugin[with-cuda]==0.4.33 # via # -c .pin/../.pin/constraints-cuda-torch.txt # jax -jaxlib==0.4.31 +jaxlib==0.4.33 # via # -c .pin/../.pin/constraints-cuda-torch.txt # jax @@ -139,7 +139,7 @@ scipy==1.14.1 # -c .pin/../.pin/constraints-cuda-torch.txt # jax # jaxlib -sympy==1.13.2 +sympy==1.13.3 # via # -c .pin/../.pin/constraints-cuda-torch.txt # torch diff --git a/benchmarks/geo_gnn/requirements.cuda.txt b/benchmarks/geo_gnn/requirements.cuda.txt index 88e329e6d..37b77babb 100644 --- a/benchmarks/geo_gnn/requirements.cuda.txt +++ b/benchmarks/geo_gnn/requirements.cuda.txt @@ -54,7 +54,7 @@ executing==2.1.0 # via # -c .pin/../.pin/constraints-cuda-torch.txt # varname -filelock==3.16.0 +filelock==3.16.1 # via # -c .pin/../.pin/constraints-cuda-torch.txt # -r benchmarks/geo_gnn/requirements-pre.cuda.txt @@ -81,22 +81,22 @@ idna==3.10 # -c .pin/../.pin/constraints-cuda-torch.txt # requests # yarl -jax[cuda12]==0.4.31 +jax[cuda12]==0.4.33 # via # -c .pin/../.pin/constraints-cuda-torch.txt # -r .pin/../constraints/extra/torch.cuda.txt # -r benchmarks/geo_gnn/requirements-pre.cuda.txt -jax-cuda12-pjrt==0.4.31 +jax-cuda12-pjrt==0.4.33 # via # -c .pin/../.pin/constraints-cuda-torch.txt # -r 
benchmarks/geo_gnn/requirements-pre.cuda.txt # jax-cuda12-plugin -jax-cuda12-plugin[with-cuda]==0.4.31 +jax-cuda12-plugin[with-cuda]==0.4.33 # via # -c .pin/../.pin/constraints-cuda-torch.txt # -r benchmarks/geo_gnn/requirements-pre.cuda.txt # jax -jaxlib==0.4.31 +jaxlib==0.4.33 # via # -c .pin/../.pin/constraints-cuda-torch.txt # -r benchmarks/geo_gnn/requirements-pre.cuda.txt @@ -316,7 +316,7 @@ six==1.16.0 # -c .pin/../.pin/constraints-cuda-torch.txt # asttokens # python-dateutil -sympy==1.13.2 +sympy==1.13.3 # via # -c .pin/../.pin/constraints-cuda-torch.txt # -r benchmarks/geo_gnn/requirements-pre.cuda.txt diff --git a/benchmarks/huggingface/requirements.cuda.txt b/benchmarks/huggingface/requirements.cuda.txt index d4bcacca7..d4323b4af 100644 --- a/benchmarks/huggingface/requirements.cuda.txt +++ b/benchmarks/huggingface/requirements.cuda.txt @@ -34,7 +34,7 @@ executing==2.1.0 # via # -c .pin/../.pin/constraints-cuda-torch.txt # varname -filelock==3.16.0 +filelock==3.16.1 # via # -c .pin/../.pin/constraints-cuda-torch.txt # huggingface-hub @@ -51,7 +51,7 @@ giving==0.4.3 # -c .pin/../.pin/constraints-cuda-torch.txt # ptera # voir -huggingface-hub==0.24.7 +huggingface-hub==0.25.0 # via # -c .pin/../.pin/constraints-cuda-torch.txt # tokenizers @@ -60,19 +60,19 @@ idna==3.10 # via # -c .pin/../.pin/constraints-cuda-torch.txt # requests -jax[cuda12]==0.4.31 +jax[cuda12]==0.4.33 # via # -c .pin/../.pin/constraints-cuda-torch.txt # -r .pin/../constraints/extra/torch.cuda.txt -jax-cuda12-pjrt==0.4.31 +jax-cuda12-pjrt==0.4.33 # via # -c .pin/../.pin/constraints-cuda-torch.txt # jax-cuda12-plugin -jax-cuda12-plugin[with-cuda]==0.4.31 +jax-cuda12-plugin[with-cuda]==0.4.33 # via # -c .pin/../.pin/constraints-cuda-torch.txt # jax -jaxlib==0.4.31 +jaxlib==0.4.33 # via # -c .pin/../.pin/constraints-cuda-torch.txt # jax @@ -253,7 +253,7 @@ six==1.16.0 # via # -c .pin/../.pin/constraints-cuda-torch.txt # asttokens -sympy==1.13.2 +sympy==1.13.3 # via # -c .pin/../.pin/constraints-cuda-torch.txt # torch diff --git a/benchmarks/lightning/main.py b/benchmarks/lightning/main.py index b31f3880c..aca89ee47 100644 --- a/benchmarks/lightning/main.py +++ b/benchmarks/lightning/main.py @@ -40,7 +40,7 @@ def prepare_voir(): observer = BenchObserver( accelerator.Event, - earlystop=65, + earlystop=100, batch_size_fn=lambda x: len(x[0]), raise_stop_program=False, stdout=True, @@ -73,8 +73,6 @@ def main(): model = TorchvisionLightning(model) - - accelerator.set_enable_tf32(True) observer, monitor = prepare_voir() @@ -91,10 +89,10 @@ def main(): enable_checkpointing=False, enable_progress_bar=False, reload_dataloaders_every_n_epochs=1, - max_steps=100 + max_steps=120 ) - with monitor(): + with monitor(poll_interval=0.1): trainer.fit(model=model, train_dataloaders=loader) print("finished: ", rank) diff --git a/benchmarks/lightning/requirements.cuda.txt b/benchmarks/lightning/requirements.cuda.txt index d6823c252..db0745882 100644 --- a/benchmarks/lightning/requirements.cuda.txt +++ b/benchmarks/lightning/requirements.cuda.txt @@ -46,7 +46,7 @@ executing==2.1.0 # via # -c .pin/../.pin/constraints-cuda-torch.txt # varname -filelock==3.16.0 +filelock==3.16.1 # via # -c .pin/../.pin/constraints-cuda-torch.txt # torch @@ -75,19 +75,19 @@ importlib-resources==6.4.5 # via # -c .pin/../.pin/constraints-cuda-torch.txt # torchcompat -jax[cuda12]==0.4.31 +jax[cuda12]==0.4.33 # via # -c .pin/../.pin/constraints-cuda-torch.txt # -r .pin/../constraints/extra/torch.cuda.txt -jax-cuda12-pjrt==0.4.31 +jax-cuda12-pjrt==0.4.33 # via # 
-c .pin/../.pin/constraints-cuda-torch.txt # jax-cuda12-plugin -jax-cuda12-plugin[with-cuda]==0.4.31 +jax-cuda12-plugin[with-cuda]==0.4.33 # via # -c .pin/../.pin/constraints-cuda-torch.txt # jax -jaxlib==0.4.31 +jaxlib==0.4.33 # via # -c .pin/../.pin/constraints-cuda-torch.txt # jax @@ -277,7 +277,7 @@ six==1.16.0 # via # -c .pin/../.pin/constraints-cuda-torch.txt # asttokens -sympy==1.13.2 +sympy==1.13.3 # via # -c .pin/../.pin/constraints-cuda-torch.txt # torch diff --git a/benchmarks/llama/requirements.cuda.txt b/benchmarks/llama/requirements.cuda.txt index 7d972b40f..1f52de100 100644 --- a/benchmarks/llama/requirements.cuda.txt +++ b/benchmarks/llama/requirements.cuda.txt @@ -68,7 +68,7 @@ fairscale==0.4.13 # via # -c .pin/../.pin/constraints-cuda-torch.txt # -r benchmarks/llama/requirements.in -filelock==3.16.0 +filelock==3.16.1 # via # -c .pin/../.pin/constraints-cuda-torch.txt # datasets @@ -96,7 +96,7 @@ giving==0.4.3 # -c .pin/../.pin/constraints-cuda-torch.txt # ptera # voir -huggingface-hub==0.24.7 +huggingface-hub==0.25.0 # via # -c .pin/../.pin/constraints-cuda-torch.txt # datasets @@ -107,19 +107,19 @@ idna==3.10 # -c .pin/../.pin/constraints-cuda-torch.txt # requests # yarl -jax[cuda12]==0.4.31 +jax[cuda12]==0.4.33 # via # -c .pin/../.pin/constraints-cuda-torch.txt # -r .pin/../constraints/extra/torch.cuda.txt -jax-cuda12-pjrt==0.4.31 +jax-cuda12-pjrt==0.4.33 # via # -c .pin/../.pin/constraints-cuda-torch.txt # jax-cuda12-plugin -jax-cuda12-plugin[with-cuda]==0.4.31 +jax-cuda12-plugin[with-cuda]==0.4.33 # via # -c .pin/../.pin/constraints-cuda-torch.txt # jax -jaxlib==0.4.31 +jaxlib==0.4.33 # via # -c .pin/../.pin/constraints-cuda-torch.txt # jax @@ -334,7 +334,7 @@ six==1.16.0 # asttokens # fire # python-dateutil -sympy==1.13.2 +sympy==1.13.3 # via # -c .pin/../.pin/constraints-cuda-torch.txt # torch diff --git a/benchmarks/llava/requirements.cuda.txt b/benchmarks/llava/requirements.cuda.txt index 02cc24fbc..91e94c4bf 100644 --- a/benchmarks/llava/requirements.cuda.txt +++ b/benchmarks/llava/requirements.cuda.txt @@ -68,7 +68,7 @@ executing==2.1.0 # via # -c .pin/../.pin/constraints-cuda-torch.txt # varname -filelock==3.16.0 +filelock==3.16.1 # via # -c .pin/../.pin/constraints-cuda-torch.txt # datasets @@ -92,7 +92,7 @@ giving==0.4.3 # -c .pin/../.pin/constraints-cuda-torch.txt # ptera # voir -huggingface-hub==0.24.7 +huggingface-hub==0.25.0 # via # -c .pin/../.pin/constraints-cuda-torch.txt # accelerate @@ -104,19 +104,19 @@ idna==3.10 # -c .pin/../.pin/constraints-cuda-torch.txt # requests # yarl -jax[cuda12]==0.4.31 +jax[cuda12]==0.4.33 # via # -c .pin/../.pin/constraints-cuda-torch.txt # -r .pin/../constraints/extra/torch.cuda.txt -jax-cuda12-pjrt==0.4.31 +jax-cuda12-pjrt==0.4.33 # via # -c .pin/../.pin/constraints-cuda-torch.txt # jax-cuda12-plugin -jax-cuda12-plugin[with-cuda]==0.4.31 +jax-cuda12-plugin[with-cuda]==0.4.33 # via # -c .pin/../.pin/constraints-cuda-torch.txt # jax -jaxlib==0.4.31 +jaxlib==0.4.33 # via # -c .pin/../.pin/constraints-cuda-torch.txt # jax @@ -335,7 +335,7 @@ six==1.16.0 # -c .pin/../.pin/constraints-cuda-torch.txt # asttokens # python-dateutil -sympy==1.13.2 +sympy==1.13.3 # via # -c .pin/../.pin/constraints-cuda-torch.txt # torch diff --git a/benchmarks/llm/requirements.cuda.txt b/benchmarks/llm/requirements.cuda.txt index 0e1e0010a..3abff6b50 100644 --- a/benchmarks/llm/requirements.cuda.txt +++ b/benchmarks/llm/requirements.cuda.txt @@ -82,7 +82,7 @@ fairscale==0.4.13 # -c .pin/../.pin/constraints-cuda-torch.txt # -r 
benchmarks/llm/requirements.in # -r benchmarks/llm/requirements.txt -filelock==3.16.0 +filelock==3.16.1 # via # -c .pin/../.pin/constraints-cuda-torch.txt # blobfile @@ -115,7 +115,7 @@ hjson==3.1.0 # via # -c .pin/../.pin/constraints-cuda-torch.txt # argklass -huggingface-hub==0.24.7 +huggingface-hub==0.25.0 # via # -c .pin/../.pin/constraints-cuda-torch.txt # accelerate @@ -132,19 +132,19 @@ importlib-resources==6.4.5 # via # -c .pin/../.pin/constraints-cuda-torch.txt # argklass -jax[cuda12]==0.4.31 +jax[cuda12]==0.4.33 # via # -c .pin/../.pin/constraints-cuda-torch.txt # -r .pin/../constraints/extra/torch.cuda.txt -jax-cuda12-pjrt==0.4.31 +jax-cuda12-pjrt==0.4.33 # via # -c .pin/../.pin/constraints-cuda-torch.txt # jax-cuda12-plugin -jax-cuda12-plugin[with-cuda]==0.4.31 +jax-cuda12-plugin[with-cuda]==0.4.33 # via # -c .pin/../.pin/constraints-cuda-torch.txt # jax -jaxlib==0.4.31 +jaxlib==0.4.33 # via # -c .pin/../.pin/constraints-cuda-torch.txt # jax @@ -378,7 +378,7 @@ six==1.16.0 # asttokens # fire # python-dateutil -sympy==1.13.2 +sympy==1.13.3 # via # -c .pin/../.pin/constraints-cuda-torch.txt # torch @@ -402,11 +402,11 @@ torch==2.4.0+cu121 # accelerate # fairscale # xformers -torchao==0.3.1+cu121 +torchao==0.5.0+cu121 # via # -c .pin/../.pin/constraints-cuda-torch.txt - # torchtune -torchtune==0.2.1+cu121 + # -r benchmarks/llm/requirements.in +torchtune==0.3.0+cu121 # via # -c .pin/../.pin/constraints-cuda-torch.txt # -r benchmarks/llm/requirements.in diff --git a/benchmarks/llm/requirements.in b/benchmarks/llm/requirements.in index 91b62c073..36832ad67 100644 --- a/benchmarks/llm/requirements.in +++ b/benchmarks/llm/requirements.in @@ -4,6 +4,7 @@ torch PyYAML argklass fairscale +torchao # Prepare accelerate diff --git a/benchmarks/purejaxrl/dqn.py b/benchmarks/purejaxrl/dqn.py index 17c839147..fc0a97b8d 100644 --- a/benchmarks/purejaxrl/dqn.py +++ b/benchmarks/purejaxrl/dqn.py @@ -50,7 +50,9 @@ def make_train(config): config["NUM_UPDATES"] = config["TOTAL_TIMESTEPS"] // config["NUM_ENVS"] from benchmate.timings import StepTimer + from benchmate.jaxmem import memory_peak_fetcher step_timer = StepTimer(give_push()) + fetch_memory_peak = memory_peak_fetcher() basic_env, env_params = gymnax.make(config["ENV_NAME"]) env = FlattenObservationWrapper(basic_env) @@ -238,6 +240,7 @@ def callback(metrics): step_timer.step(delta.item()) step_timer.log(returns=returns, loss=loss) + step_timer.log(memory_peak=fetch_memory_peak(), units="MiB") step_timer.end() jax.debug.callback(callback, metrics) @@ -258,12 +261,49 @@ def callback(metrics): return train +# When using nvidia-smi to monitor memory +# arg: --buffer_size +# model: +# 256: 61900.25 MiB +# 1000: 61900.25 MiB +# 10000: 61900.25 MiB + +# dqn: +# arg: --num_envs +# model: +# 2: 61900.25 MiB +# 4: 61900.25 MiB +# 16: 61900.25 MiB +# 32: 61900.25 MiB +# 64: 61900.25 MiB +# 128: 61900.25 MiB + +# arg: --total_timesteps +# model: +# 32768: 61900.25 MiB +# 65536: 61900.25 MiB + +# When using Jax to monitor memory + +# dqn.D0 [stdout] Device: cuda:0 +# dqn.D0 [stdout] num_allocs: 0.0006799697875976562 MiB +# dqn.D0 [stdout] bytes_in_use: 0.915771484375 MiB +# dqn.D0 [stdout] peak_bytes_in_use: 80.41552734375 MiB +# dqn.D0 [stdout] largest_alloc_size: 16.07958984375 MiB +# dqn.D0 [stdout] bytes_limit: 60832.359375 MiB +# dqn.D0 [stdout] bytes_reserved: 0.0 MiB +# dqn.D0 [stdout] peak_bytes_reserved: 0.0 MiB +# dqn.D0 [stdout] largest_free_block_bytes: 0.0 MiB +# dqn.D0 [stdout] pool_bytes: 60832.359375 MiB +# dqn.D0 [stdout] peak_pool_bytes: 
60832.359375 MiB + + @dataclass class Arguments: - num_envs: int = 10 - buffer_size: int = 10000 + num_envs: int = 10 # No impact on memory + buffer_size: int = 10000 # No impact on memory buffer_batch_size: int = 128 - total_timesteps: int = 100_000 + total_timesteps: int = 100_000 # No impact on memory epsilon_start: float = 1.0 epsilon_finish: float = 0.05 epsilon_anneal_time: int = 25e4 diff --git a/benchmarks/purejaxrl/ppo.py b/benchmarks/purejaxrl/ppo.py index a053373f3..0cc8896cc 100644 --- a/benchmarks/purejaxrl/ppo.py +++ b/benchmarks/purejaxrl/ppo.py @@ -75,7 +75,10 @@ class Transition(NamedTuple): def make_train(config): from benchmate.timings import StepTimer + from benchmate.jaxmem import memory_peak_fetcher + step_timer = StepTimer(give_push()) + fetch_memory_peak = memory_peak_fetcher() config["NUM_UPDATES"] = ( config["TOTAL_TIMESTEPS"] // config["NUM_STEPS"] // config["NUM_ENVS"] @@ -280,6 +283,7 @@ def callback(info): step_timer.step(config["NUM_ENVS"] * config["NUM_STEPS"]) step_timer.log(loss=loss) + step_timer.log(memory_peak=fetch_memory_peak(), units="MiB") step_timer.end() jax.debug.callback(callback, metrics) diff --git a/benchmarks/purejaxrl/requirements.cuda.txt b/benchmarks/purejaxrl/requirements.cuda.txt index a59468762..d495163a9 100644 --- a/benchmarks/purejaxrl/requirements.cuda.txt +++ b/benchmarks/purejaxrl/requirements.cuda.txt @@ -157,7 +157,7 @@ farama-notifications==0.0.4 # via # -c .pin/../.pin/constraints-cuda-torch.txt # gymnasium -filelock==3.16.0 +filelock==3.16.1 # via # -c .pin/../.pin/constraints-cuda-torch.txt # torch @@ -269,7 +269,7 @@ itsdangerous==2.2.0 # via # -c .pin/../.pin/constraints-cuda-torch.txt # flask -jax[cuda12]==0.4.31 +jax[cuda12]==0.4.33 # via # -c .pin/../.pin/constraints-cuda-torch.txt # -r .pin/../constraints/extra/torch.cuda.txt @@ -286,15 +286,15 @@ jax[cuda12]==0.4.31 # optax # orbax-checkpoint # rlax -jax-cuda12-pjrt==0.4.31 +jax-cuda12-pjrt==0.4.33 # via # -c .pin/../.pin/constraints-cuda-torch.txt # jax-cuda12-plugin -jax-cuda12-plugin[with-cuda]==0.4.31 +jax-cuda12-plugin[with-cuda]==0.4.33 # via # -c .pin/../.pin/constraints-cuda-torch.txt # jax -jaxlib==0.4.31 +jaxlib==0.4.33 # via # -c .pin/../.pin/constraints-cuda-torch.txt # brax @@ -366,12 +366,12 @@ msgpack==1.1.0 # -c .pin/../.pin/constraints-cuda-torch.txt # flax # orbax-checkpoint -mujoco==3.2.2 +mujoco==3.2.3 # via # -c .pin/../.pin/constraints-cuda-torch.txt # brax # mujoco-mjx -mujoco-mjx==3.2.2 +mujoco-mjx==3.2.3 # via # -c .pin/../.pin/constraints-cuda-torch.txt # brax @@ -506,7 +506,7 @@ optax==0.2.3 # -r benchmarks/purejaxrl/requirements.in # brax # flax -orbax-checkpoint==0.6.3 +orbax-checkpoint==0.6.4 # via # -c .pin/../.pin/constraints-cuda-torch.txt # brax @@ -537,7 +537,7 @@ pillow==10.4.0 # brax # matplotlib # navix -platformdirs==4.3.3 +platformdirs==4.3.6 # via # -c .pin/../.pin/constraints-cuda-torch.txt # black @@ -547,7 +547,7 @@ pluggy==1.5.0 # via # -c .pin/../.pin/constraints-cuda-torch.txt # pytest -protobuf==5.28.1 +protobuf==5.28.2 # via # -c .pin/../.pin/constraints-cuda-torch.txt # orbax-checkpoint @@ -671,7 +671,7 @@ smmap==5.0.1 # via # -c .pin/../.pin/constraints-cuda-torch.txt # gitdb -sympy==1.13.2 +sympy==1.13.3 # via # -c .pin/../.pin/constraints-cuda-torch.txt # torch @@ -734,7 +734,7 @@ typing-extensions==4.12.2 # reactivex # torch # tyro -tyro==0.8.10 +tyro==0.8.11 # via # -c .pin/../.pin/constraints-cuda-torch.txt # navix @@ -756,7 +756,7 @@ voir==0.2.19 # -c .pin/../.pin/constraints-cuda-torch.txt # -c 
.pin/../constraints/cuda.txt # -r benchmarks/purejaxrl/requirements.in -wandb==0.18.0 +wandb==0.18.1 # via # -c .pin/../.pin/constraints-cuda-torch.txt # navix diff --git a/benchmarks/purejaxrl/voirfile.py b/benchmarks/purejaxrl/voirfile.py index 5305be3f4..a94eb7646 100644 --- a/benchmarks/purejaxrl/voirfile.py +++ b/benchmarks/purejaxrl/voirfile.py @@ -32,7 +32,7 @@ def instrument_main(ov, options: Config): ov.require(dash) ov.require( - log("value", "progress", "rate", "units", "loss", "gpudata", context="task"), + log("value", "progress", "rate", "units", "loss", "gpudata", "memory_peak", "cpudata", context="task"), # early_stop(n=options.stop, key="rate", task="train"), monitor_monogpu(poll_interval=options.gpu_poll), ) diff --git a/benchmarks/recursiongfn/requirements.cuda.txt b/benchmarks/recursiongfn/requirements.cuda.txt index 89c02624f..2c852b71d 100644 --- a/benchmarks/recursiongfn/requirements.cuda.txt +++ b/benchmarks/recursiongfn/requirements.cuda.txt @@ -46,7 +46,7 @@ blosc2==2.7.1 # via # -c .pin/../.pin/constraints-cuda-torch.txt # tables -botorch==0.11.3 +botorch==0.12.0 # via # -c .pin/../.pin/constraints-cuda-torch.txt # -r benchmarks/recursiongfn/requirements.in @@ -79,7 +79,7 @@ executing==2.1.0 # via # -c .pin/../.pin/constraints-cuda-torch.txt # varname -filelock==3.16.0 +filelock==3.16.1 # via # -c .pin/../.pin/constraints-cuda-torch.txt # torch @@ -108,7 +108,7 @@ giving==0.4.3 # -c .pin/../.pin/constraints-cuda-torch.txt # ptera # voir -gpytorch==1.12 +gpytorch==1.13 # via # -c .pin/../.pin/constraints-cuda-torch.txt # -r benchmarks/recursiongfn/requirements.in @@ -122,25 +122,26 @@ idna==3.10 # -c .pin/../.pin/constraints-cuda-torch.txt # requests # yarl -jax[cuda12]==0.4.31 +jax[cuda12]==0.4.33 # via # -c .pin/../.pin/constraints-cuda-torch.txt # -r .pin/../constraints/extra/torch.cuda.txt -jax-cuda12-pjrt==0.4.31 +jax-cuda12-pjrt==0.4.33 # via # -c .pin/../.pin/constraints-cuda-torch.txt # jax-cuda12-plugin -jax-cuda12-plugin[with-cuda]==0.4.31 +jax-cuda12-plugin[with-cuda]==0.4.33 # via # -c .pin/../.pin/constraints-cuda-torch.txt # jax -jaxlib==0.4.31 +jaxlib==0.4.33 # via # -c .pin/../.pin/constraints-cuda-torch.txt # jax -jaxtyping==0.2.34 +jaxtyping==0.2.19 # via # -c .pin/../.pin/constraints-cuda-torch.txt + # gpytorch # linear-operator jinja2==3.1.4 # via @@ -151,7 +152,7 @@ joblib==1.4.2 # via # -c .pin/../.pin/constraints-cuda-torch.txt # scikit-learn -linear-operator==0.5.2 +linear-operator==0.5.3 # via # -c .pin/../.pin/constraints-cuda-torch.txt # botorch @@ -183,6 +184,7 @@ mpmath==1.3.0 # -c .pin/../.pin/constraints-cuda-torch.txt # botorch # gpytorch + # linear-operator # sympy msgpack==1.1.0 # via @@ -215,9 +217,9 @@ numpy==1.26.4 # via # -c .pin/../.pin/constraints-cuda-torch.txt # blosc2 - # botorch # jax # jaxlib + # jaxtyping # ml-dtypes # numexpr # opt-einsum @@ -327,11 +329,11 @@ pillow==10.4.0 # via # -c .pin/../.pin/constraints-cuda-torch.txt # rdkit -platformdirs==4.3.3 +platformdirs==4.3.6 # via # -c .pin/../.pin/constraints-cuda-torch.txt # wandb -protobuf==5.28.1 +protobuf==5.28.2 # via # -c .pin/../.pin/constraints-cuda-torch.txt # tensorboard @@ -437,7 +439,7 @@ smmap==5.0.1 # via # -c .pin/../.pin/constraints-cuda-torch.txt # gitdb -sympy==1.13.2 +sympy==1.13.3 # via # -c .pin/../.pin/constraints-cuda-torch.txt # torch @@ -490,18 +492,20 @@ triton==3.0.0 # via # -c .pin/../.pin/constraints-cuda-torch.txt # torch -typeguard==2.13.3 +typeguard==4.3.0 # via # -c .pin/../.pin/constraints-cuda-torch.txt # jaxtyping - # linear-operator 
typing-extensions==4.12.2 # via # -c .pin/../.pin/constraints-cuda-torch.txt + # botorch + # jaxtyping # multidict # reactivex # tables # torch + # typeguard tzdata==2024.1 # via # -c .pin/../.pin/constraints-cuda-torch.txt @@ -520,7 +524,7 @@ voir==0.2.19 # -c .pin/../.pin/constraints-cuda-torch.txt # -c .pin/../constraints/cuda.txt # -r benchmarks/recursiongfn/requirements.in -wandb==0.18.0 +wandb==0.18.1 # via # -c .pin/../.pin/constraints-cuda-torch.txt # -r benchmarks/recursiongfn/requirements.in diff --git a/benchmarks/rlhf/requirements.cuda.txt b/benchmarks/rlhf/requirements.cuda.txt index 12a24c6c4..acc448aee 100644 --- a/benchmarks/rlhf/requirements.cuda.txt +++ b/benchmarks/rlhf/requirements.cuda.txt @@ -74,7 +74,7 @@ executing==2.1.0 # via # -c .pin/../.pin/constraints-cuda-torch.txt # varname -filelock==3.16.0 +filelock==3.16.1 # via # -c .pin/../.pin/constraints-cuda-torch.txt # datasets @@ -98,7 +98,7 @@ giving==0.4.3 # -c .pin/../.pin/constraints-cuda-torch.txt # ptera # voir -huggingface-hub==0.24.7 +huggingface-hub==0.25.0 # via # -c .pin/../.pin/constraints-cuda-torch.txt # accelerate @@ -110,19 +110,19 @@ idna==3.10 # -c .pin/../.pin/constraints-cuda-torch.txt # requests # yarl -jax[cuda12]==0.4.31 +jax[cuda12]==0.4.33 # via # -c .pin/../.pin/constraints-cuda-torch.txt # -r .pin/../constraints/extra/torch.cuda.txt -jax-cuda12-pjrt==0.4.31 +jax-cuda12-pjrt==0.4.33 # via # -c .pin/../.pin/constraints-cuda-torch.txt # jax-cuda12-plugin -jax-cuda12-plugin[with-cuda]==0.4.31 +jax-cuda12-plugin[with-cuda]==0.4.33 # via # -c .pin/../.pin/constraints-cuda-torch.txt # jax -jaxlib==0.4.31 +jaxlib==0.4.33 # via # -c .pin/../.pin/constraints-cuda-torch.txt # jax @@ -342,7 +342,7 @@ six==1.16.0 # -c .pin/../.pin/constraints-cuda-torch.txt # asttokens # python-dateutil -sympy==1.13.2 +sympy==1.13.3 # via # -c .pin/../.pin/constraints-cuda-torch.txt # torch @@ -372,7 +372,7 @@ triton==3.0.0 # via # -c .pin/../.pin/constraints-cuda-torch.txt # torch -trl==0.10.1 +trl==0.11.0 # via # -c .pin/../.pin/constraints-cuda-torch.txt # -r benchmarks/rlhf/requirements.in @@ -384,7 +384,7 @@ typing-extensions==4.12.2 # reactivex # torch # tyro -tyro==0.8.10 +tyro==0.8.11 # via # -c .pin/../.pin/constraints-cuda-torch.txt # trl diff --git a/benchmarks/timm/requirements.cuda.txt b/benchmarks/timm/requirements.cuda.txt index 4554f91ec..1ac873600 100644 --- a/benchmarks/timm/requirements.cuda.txt +++ b/benchmarks/timm/requirements.cuda.txt @@ -34,7 +34,7 @@ executing==2.1.0 # via # -c .pin/../.pin/constraints-cuda-torch.txt # varname -filelock==3.16.0 +filelock==3.16.1 # via # -c .pin/../.pin/constraints-cuda-torch.txt # huggingface-hub @@ -50,7 +50,7 @@ giving==0.4.3 # -c .pin/../.pin/constraints-cuda-torch.txt # ptera # voir -huggingface-hub==0.24.7 +huggingface-hub==0.25.0 # via # -c .pin/../.pin/constraints-cuda-torch.txt # -r benchmarks/timm/requirements.in @@ -58,19 +58,19 @@ idna==3.10 # via # -c .pin/../.pin/constraints-cuda-torch.txt # requests -jax[cuda12]==0.4.31 +jax[cuda12]==0.4.33 # via # -c .pin/../.pin/constraints-cuda-torch.txt # -r .pin/../constraints/extra/torch.cuda.txt -jax-cuda12-pjrt==0.4.31 +jax-cuda12-pjrt==0.4.33 # via # -c .pin/../.pin/constraints-cuda-torch.txt # jax-cuda12-plugin -jax-cuda12-plugin[with-cuda]==0.4.31 +jax-cuda12-plugin[with-cuda]==0.4.33 # via # -c .pin/../.pin/constraints-cuda-torch.txt # jax -jaxlib==0.4.31 +jaxlib==0.4.33 # via # -c .pin/../.pin/constraints-cuda-torch.txt # jax @@ -245,7 +245,7 @@ six==1.16.0 # via # -c 
.pin/../.pin/constraints-cuda-torch.txt # asttokens -sympy==1.13.2 +sympy==1.13.3 # via # -c .pin/../.pin/constraints-cuda-torch.txt # torch diff --git a/benchmarks/torchatari/requirements.cuda.txt b/benchmarks/torchatari/requirements.cuda.txt index 2b0aa99d6..0ed7f915c 100644 --- a/benchmarks/torchatari/requirements.cuda.txt +++ b/benchmarks/torchatari/requirements.cuda.txt @@ -64,7 +64,7 @@ farama-notifications==0.0.4 # via # -c .pin/../.pin/constraints-cuda-torch.txt # gymnasium -filelock==3.16.0 +filelock==3.16.1 # via # -c .pin/../.pin/constraints-cuda-torch.txt # torch @@ -100,19 +100,19 @@ importlib-resources==6.4.5 # -c .pin/../.pin/constraints-cuda-torch.txt # cantilever # torchcompat -jax[cuda12]==0.4.31 +jax[cuda12]==0.4.33 # via # -c .pin/../.pin/constraints-cuda-torch.txt # -r .pin/../constraints/extra/torch.cuda.txt -jax-cuda12-pjrt==0.4.31 +jax-cuda12-pjrt==0.4.33 # via # -c .pin/../.pin/constraints-cuda-torch.txt # jax-cuda12-plugin -jax-cuda12-plugin[with-cuda]==0.4.31 +jax-cuda12-plugin[with-cuda]==0.4.33 # via # -c .pin/../.pin/constraints-cuda-torch.txt # jax -jaxlib==0.4.31 +jaxlib==0.4.33 # via # -c .pin/../.pin/constraints-cuda-torch.txt # jax @@ -255,7 +255,7 @@ packaging==24.1 # -c .pin/../.pin/constraints-cuda-torch.txt # envpool # tensorboard -protobuf==5.28.1 +protobuf==5.28.2 # via # -c .pin/../.pin/constraints-cuda-torch.txt # tensorboard @@ -298,7 +298,7 @@ six==1.16.0 # -c .pin/../.pin/constraints-cuda-torch.txt # asttokens # tensorboard -sympy==1.13.2 +sympy==1.13.3 # via # -c .pin/../.pin/constraints-cuda-torch.txt # torch @@ -337,7 +337,7 @@ typing-extensions==4.12.2 # reactivex # torch # tyro -tyro==0.8.10 +tyro==0.8.11 # via # -c .pin/../.pin/constraints-cuda-torch.txt # -r benchmarks/torchatari/requirements.in diff --git a/benchmarks/torchvision/requirements.cuda.txt b/benchmarks/torchvision/requirements.cuda.txt index 6b1a837f0..3b994c798 100644 --- a/benchmarks/torchvision/requirements.cuda.txt +++ b/benchmarks/torchvision/requirements.cuda.txt @@ -26,7 +26,7 @@ executing==2.1.0 # via # -c .pin/../.pin/constraints-cuda-torch.txt # varname -filelock==3.16.0 +filelock==3.16.1 # via # -c .pin/../.pin/constraints-cuda-torch.txt # torch @@ -44,19 +44,19 @@ importlib-resources==6.4.5 # via # -c .pin/../.pin/constraints-cuda-torch.txt # torchcompat -jax[cuda12]==0.4.31 +jax[cuda12]==0.4.33 # via # -c .pin/../.pin/constraints-cuda-torch.txt # -r .pin/../constraints/extra/torch.cuda.txt -jax-cuda12-pjrt==0.4.31 +jax-cuda12-pjrt==0.4.33 # via # -c .pin/../.pin/constraints-cuda-torch.txt # jax-cuda12-plugin -jax-cuda12-plugin[with-cuda]==0.4.31 +jax-cuda12-plugin[with-cuda]==0.4.33 # via # -c .pin/../.pin/constraints-cuda-torch.txt # jax -jaxlib==0.4.31 +jaxlib==0.4.33 # via # -c .pin/../.pin/constraints-cuda-torch.txt # jax @@ -217,7 +217,7 @@ six==1.16.0 # via # -c .pin/../.pin/constraints-cuda-torch.txt # asttokens -sympy==1.13.2 +sympy==1.13.3 # via # -c .pin/../.pin/constraints-cuda-torch.txt # torch diff --git a/benchmarks/torchvision/voirfile.py b/benchmarks/torchvision/voirfile.py index ed3f0af7c..a05c99774 100644 --- a/benchmarks/torchvision/voirfile.py +++ b/benchmarks/torchvision/voirfile.py @@ -24,7 +24,7 @@ class Config: stop: int = 20 # Number of seconds between each gpu poll - gpu_poll: int = 3 + gpu_poll: float = 1 @configurable diff --git a/benchmarks/torchvision_ddp/requirements.cuda.txt b/benchmarks/torchvision_ddp/requirements.cuda.txt index 28c6198b2..4e6a2a2b8 100644 --- a/benchmarks/torchvision_ddp/requirements.cuda.txt +++ 
b/benchmarks/torchvision_ddp/requirements.cuda.txt @@ -26,7 +26,7 @@ executing==2.1.0 # via # -c .pin/../.pin/constraints-cuda-torch.txt # varname -filelock==3.16.0 +filelock==3.16.1 # via # -c .pin/../.pin/constraints-cuda-torch.txt # torch @@ -44,19 +44,19 @@ importlib-resources==6.4.5 # via # -c .pin/../.pin/constraints-cuda-torch.txt # torchcompat -jax[cuda12]==0.4.31 +jax[cuda12]==0.4.33 # via # -c .pin/../.pin/constraints-cuda-torch.txt # -r .pin/../constraints/extra/torch.cuda.txt -jax-cuda12-pjrt==0.4.31 +jax-cuda12-pjrt==0.4.33 # via # -c .pin/../.pin/constraints-cuda-torch.txt # jax-cuda12-plugin -jax-cuda12-plugin[with-cuda]==0.4.31 +jax-cuda12-plugin[with-cuda]==0.4.33 # via # -c .pin/../.pin/constraints-cuda-torch.txt # jax -jaxlib==0.4.31 +jaxlib==0.4.33 # via # -c .pin/../.pin/constraints-cuda-torch.txt # jax @@ -217,7 +217,7 @@ six==1.16.0 # via # -c .pin/../.pin/constraints-cuda-torch.txt # asttokens -sympy==1.13.2 +sympy==1.13.3 # via # -c .pin/../.pin/constraints-cuda-torch.txt # torch diff --git a/benchmarks/vjepa/main.py b/benchmarks/vjepa/main.py index 74ca606f7..18377b92e 100644 --- a/benchmarks/vjepa/main.py +++ b/benchmarks/vjepa/main.py @@ -650,7 +650,7 @@ def main(): if os.getenv("RANK", -1) != -1: acc.destroy_process_group() - sys.exit(0) + # sys.exit(0) if __name__ == "__main__": main() diff --git a/benchmarks/vjepa/requirements.cuda.txt b/benchmarks/vjepa/requirements.cuda.txt index c6e6ebb0e..2386bbd24 100644 --- a/benchmarks/vjepa/requirements.cuda.txt +++ b/benchmarks/vjepa/requirements.cuda.txt @@ -55,7 +55,7 @@ executing==2.1.0 # via # -c .pin/../.pin/constraints-cuda-torch.txt # varname -filelock==3.16.0 +filelock==3.16.1 # via # -c .pin/../.pin/constraints-cuda-torch.txt # huggingface-hub @@ -71,7 +71,7 @@ giving==0.4.3 # -c .pin/../.pin/constraints-cuda-torch.txt # ptera # voir -huggingface-hub==0.24.7 +huggingface-hub==0.25.0 # via # -c .pin/../.pin/constraints-cuda-torch.txt # timm @@ -79,19 +79,19 @@ idna==3.10 # via # -c .pin/../.pin/constraints-cuda-torch.txt # requests -jax[cuda12]==0.4.31 +jax[cuda12]==0.4.33 # via # -c .pin/../.pin/constraints-cuda-torch.txt # -r .pin/../constraints/extra/torch.cuda.txt -jax-cuda12-pjrt==0.4.31 +jax-cuda12-pjrt==0.4.33 # via # -c .pin/../.pin/constraints-cuda-torch.txt # jax-cuda12-plugin -jax-cuda12-plugin[with-cuda]==0.4.31 +jax-cuda12-plugin[with-cuda]==0.4.33 # via # -c .pin/../.pin/constraints-cuda-torch.txt # jax -jaxlib==0.4.31 +jaxlib==0.4.33 # via # -c .pin/../.pin/constraints-cuda-torch.txt # jax @@ -290,11 +290,11 @@ six==1.16.0 # -c .pin/../.pin/constraints-cuda-torch.txt # asttokens # python-dateutil -submitit==1.5.1 +submitit==1.5.2 # via # -c .pin/../.pin/constraints-cuda-torch.txt # -r benchmarks/vjepa/requirements.in -sympy==1.13.2 +sympy==1.13.3 # via # -c .pin/../.pin/constraints-cuda-torch.txt # torch diff --git a/benchmate/benchmate/jaxmem.py b/benchmate/benchmate/jaxmem.py new file mode 100644 index 000000000..1ead3bff5 --- /dev/null +++ b/benchmate/benchmate/jaxmem.py @@ -0,0 +1,30 @@ + + + +def memory_peak_fetcher(): + import jax + + def fetch_memory_peak(): + # 'memory', 'memory_stats' + devices = jax.devices() + max_mem = -1 + for device in devices: + # dqn.D0 [stdout] Device: cuda:0 + # dqn.D0 [stdout] num_allocs: 0.0006799697875976562 MiB + # dqn.D0 [stdout] bytes_in_use: 0.915771484375 MiB + # dqn.D0 [stdout] peak_bytes_in_use: 80.41552734375 MiB + # dqn.D0 [stdout] largest_alloc_size: 16.07958984375 MiB + # dqn.D0 [stdout] bytes_limit: 60832.359375 MiB + # dqn.D0 [stdout] 
bytes_reserved: 0.0 MiB + # dqn.D0 [stdout] peak_bytes_reserved: 0.0 MiB + # dqn.D0 [stdout] largest_free_block_bytes: 0.0 MiB + # dqn.D0 [stdout] pool_bytes: 60832.359375 MiB + # dqn.D0 [stdout] peak_pool_bytes: 60832.359375 MiB + + # device_name = str(device) + mem = device.memory_stats().get("peak_bytes_in_use", 0) / (1024 ** 2) + max_mem = max(mem, max_mem) + + return max_mem + + return fetch_memory_peak diff --git a/benchmate/benchmate/monitor.py b/benchmate/benchmate/monitor.py index 5d2624201..0ad34a3d3 100644 --- a/benchmate/benchmate/monitor.py +++ b/benchmate/benchmate/monitor.py @@ -17,7 +17,7 @@ @instrument_definition -def monitor_monogpu(ov, poll_interval=10, arch=None): +def monitor_monogpu(ov, poll_interval=1, arch=None): return monitor( ov, poll_interval=poll_interval, @@ -28,7 +28,7 @@ def monitor_monogpu(ov, poll_interval=10, arch=None): @instrument_definition -def monitor_node(ov, poll_interval=10, arch=None): +def monitor_node(ov, poll_interval=1, arch=None): return monitor( ov, poll_interval=poll_interval, @@ -49,7 +49,8 @@ def mblog(data): try: print(json.dumps(data), file=data_file) except ValueError: - print("Is bench ending?, ignoring ValueError") + pass + # print("Is bench ending?, ignoring ValueError") def get(): t = time.time() diff --git a/config/base.yaml b/config/base.yaml index 28a72afb7..b06cbea58 100644 --- a/config/base.yaml +++ b/config/base.yaml @@ -209,6 +209,11 @@ resnet50: resnet50-noio: inherits: _torchvision + voir: + options: + stop: 1000 + interval: "1s" + tags: - vision - classification @@ -372,12 +377,15 @@ focalnet: --model: focalnet_base_lrf brax: + # Brax requires very specific sizes to work, + # so the batch resizer cannot resize this bench inherits: _defaults tags: - rl - jax - multigpu - gym + - nobatch definition: ../benchmarks/brax group: brax install_group: torch @@ -699,7 +707,8 @@ dqn: argv: dqn: true --num_envs: auto({cpu_per_gpu}, 128) - --buffer_batch_size: 128 + --buffer_size: 131072 + --buffer_batch_size: 65536 --env_name: CartPole-v1 --training_interval: 10 @@ -712,7 +721,7 @@ ppo: --num_minibatches: 32 --update_epochs: 4 --env_name: hopper - --total_timesteps: 200000 + --total_timesteps: 2000000 _geo_gnn: inherits: _defaults @@ -724,14 +733,22 @@ _geo_gnn: plan: method: per_gpu +pna: + inherits: _geo_gnn + argv: + --model: 'PNA' + --num-samples: 100000 + --batch-size: 4096 + --num-workers: "auto({n_worker}, 0)" + dimenet: inherits: _geo_gnn - tags: - - monogpu argv: --model: 'DimeNet' - --num-samples: 10000 + --num-samples: 100000 --use3d: True + --batch-size: 16 + --num-workers: "auto({n_worker}, 0)" recursiongfn: inherits: _defaults @@ -745,7 +762,7 @@ recursiongfn: argv: --batch_size: 128 - --num_workers: 8 + --num_workers: "auto({n_worker}, 8)" --num_steps: 100 --layer_width: 128 --num_layers: 4 @@ -779,7 +796,7 @@ _llava: - monogpu argv: --batch_size: 1 - --num_workers: 4 + --num_workers: "auto({n_worker}, 4)" --gradient_accumulation_steps: 1 llava-single: @@ -788,7 +805,7 @@ llava-single: method: per_gpu argv: --batch_size: 1 - --num_workers: 4 + --num_workers: "auto({n_worker}, 4)" --gradient_accumulation_steps: 1 llava-gpus: @@ -800,7 +817,7 @@ llava-gpus: n: 1 argv: --batch_size: 1 - --num_workers: 4 + --num_workers: "auto({n_worker}, 4)" --gradient_accumulation_steps: 1 @@ -862,3 +879,23 @@ vjepa-gpus: plan: method: njobs n: 1 + +cleanrljax: + inherits: _defaults + install_group: torch + definition: ../benchmarks/cleanrl_jax + plan: + method: per_gpu + + # args.batch_size = int(args.num_envs * 
args.num_steps) + # args.minibatch_size = int(args.batch_size // args.num_minibatches) + # args.num_iterations = args.total_timesteps // args.batch_size + # --total_timesteps + # --num_steps + # --num_minibatches + + argv: + --num_envs: auto({cpu_per_gpu}, 128) + --num_steps: 128 + --num_minibatches: 4 + --total_timesteps: 10000000 \ No newline at end of file diff --git a/config/scaling.yaml b/config/scaling.yaml index 09f3f9ae5..d9d3dbf9e 100644 --- a/config/scaling.yaml +++ b/config/scaling.yaml @@ -55,7 +55,13 @@ bert-tf32-fp16: 112: 81140.75 MiB optimized: 128 bf16: {} -brax: {} +brax: + arg: --batch-size + model: + 1024: 4912.25 MiB +cleanrljax: + arg: --num_steps + optimized: 128 convnext_large-fp16: arg: --batch-size model: @@ -178,21 +184,59 @@ diffusion-nodes: 1: 21686.75 MiB 2: 21930.75 MiB 4: 23510.75 MiB + 16: 40054.25 MiB + 32: 61512.25 MiB diffusion-single: arg: --batch_size model: 1: 21654.75 MiB 2: 21818.75 MiB 4: 23478.75 MiB -dimenet: {} + 16: 33850.25 MiB + 32: 55354.25 MiB +dimenet: + arg: --batch-size + model: + 2: 452.6875 MiB + 4: 1604.25 MiB + 24: 4776.25 MiB + 56: 6330.25 MiB + 64: 12274.25 MiB + 112: 15294.25 MiB + 128: 13002.25 MiB + 240: 67506.25 MiB + 280: 56556.25 MiB + 488: 80406.25 MiB dinov2-giant-gpus: arg: train.batch_size_per_gpu={batch_size} model: - 32: 69614 MiB + 1: 32240.25 MiB + 2: 32252.25 MiB + 4: 32404.25 MiB + 16: 38350.25 MiB + 24: 48856.25 MiB + 32: 72102.25 MiB optimized: 32 +dinov2-giant-nodes: + arg: train.batch_size_per_gpu={batch_size} dinov2-giant-single: arg: train.batch_size_per_gpu={batch_size} + model: + 1: 20682.25 MiB + 2: 20682.25 MiB + 4: 20682.25 MiB + 16: 52748.25 MiB + 24: 60792.25 MiB + 32: 74544.25 MiB dlrm: {} +dqn: + arg: --buffer_batch_size + model: + 1024: 81.81005859375 MiB + 2048: 83.40380859375 MiB + 32768: 131.21630859375 MiB + 65536: 182.21630859375 MiB + optimized: 128 focalnet: arg: --batch-size model: @@ -216,6 +260,20 @@ fp16: {} fp32: {} lightning: arg: --batch-size + model: + 1: 1054.25 MiB + 2: 1054.25 MiB + 4: 1856.25 MiB + 16: 4728.25 MiB + 24: 5482.25 MiB + 32: 6352.25 MiB + 56: 1054.25 MiB + 64: 1856.25 MiB + 120: 14522.25 MiB + 128: 14818.25 MiB + 240: 25480.25 MiB + 488: 49042.25 MiB + 664: 65914.25 MiB lightning-gpus: arg: --batch-size model: @@ -224,21 +282,59 @@ lightning-gpus: 4: 1156.75 MiB 8: 1260.75 MiB 16: 4150.75 MiB + 48: 11056.25 MiB + 112: 16776.25 MiB 128: 15858 MiB + 240: 28942.25 MiB + 504: 54100.25 MiB + 624: 65386.25 MiB optimized: 16 llama: {} +llava-gpus: + arg: --batch_size + optimized: 1 +llava-single: + arg: --batch_size + model: + 1: 72614.25 MiB + 2: 15168.25 MiB + 4: 72362.25 MiB + optimized: 1 llm-full-mp-gpus: arg: batch_size={batch_size} + model: + 1: 48964.25 MiB + 2: 49214.25 MiB + 4: 51310.25 MiB + 16: 81536.25 MiB llm-full-mp-nodes: arg: batch_size={batch_size} + model: + 1: 37340.25 MiB + 2: 38112.25 MiB + 4: 39110.25 MiB + 16: 80638.25 MiB llm-lora-ddp-gpus: arg: batch_size={batch_size} model: 1: 12418.75 MiB + 2: 19026.25 MiB + 4: 25464.25 MiB + 16: 55834.25 MiB + 32: 80268.25 MiB llm-lora-ddp-nodes: arg: batch_size={batch_size} + model: + 2: 17202.25 MiB + 4: 23956.25 MiB + 16: 59730.25 MiB + 32: 68932.25 MiB llm-lora-mp-gpus: arg: batch_size={batch_size} + model: + 2: 38166.25 MiB + 4: 43464.25 MiB + 16: 77116.25 MiB llm-lora-single: arg: batch_size={batch_size} model: @@ -262,11 +358,32 @@ opt-6_7b-multinode: model: 1: 55380 MiB optimized: 1 +pna: + arg: --batch-size + model: + 4096: 39554.25 MiB +ppo: + arg: --num_steps + model: + 8: 80.791748046875 MiB + 16: 
80.916748046875 MiB + 32: 81.166748046875 MiB + 64: 81.666748046875 MiB + 128: 82.666748046875 MiB + 1024: 96.666748046875 MiB + 2048: 132.484619140625 MiB + 4096: 205.328369140625 MiB + 2517448: 62094.25 MiB + optimized: 32 recursiongfn: arg: --batch_size model: 2: 1134.75 MiB 4: 1140.75 MiB + 16: 1830.25 MiB + 32: 1342.25 MiB + 64: 4410.25 MiB + 128: 9160.25 MiB reformer: arg: --batch-size model: @@ -376,6 +493,46 @@ resnet50: optimized: 64 resnet50-noio: arg: --batch-size + model: + 1: 1594.25 MiB + 2: 1652.25 MiB + 4: 1854.25 MiB + 16: 3052.25 MiB + 32: 4690.25 MiB + 56: 7114.25 MiB + 136: 15194.25 MiB + 288: 30632.25 MiB + 592: 64483.8125 MiB + 736: 76050.25 MiB +rlhf-gpus: + arg: --per_device_train_batch_size + model: + 1: 13448.25 MiB + 2: 13594.25 MiB + 4: 13686.25 MiB + 16: 14606.25 MiB + 32: 17918.25 MiB + 64: 24374.25 MiB + 128: 25830.25 MiB + 136: 29442.25 MiB + 392: 15372.25 MiB + 520: 15808.25 MiB + optimized: 64 +rlhf-single: + arg: --per_device_train_batch_size + model: + 1: 8590.25 MiB + 2: 8650.25 MiB + 4: 8822.25 MiB + 16: 9694.25 MiB + 32: 12952.25 MiB + 40: 14638.25 MiB + 64: 19422.25 MiB + 120: 31048.25 MiB + 128: 32442.25 MiB + 280: 63262.25 MiB + 352: 77536.25 MiB + optimized: 64 rwkv: arg: --micro_bsz model: @@ -424,8 +581,29 @@ torchatari: arg: --num-steps model: 1: 1124.75 MiB - 2: 1138.75 MiB - 4: 1166.75 MiB + 1024: 20176.25 MiB + 2048: 39020.25 MiB + 4096: 76708.25 MiB +vjepa-gpus: + arg: --batch_size + model: + 1: 27196.25 MiB + 2: 28896.25 MiB + 4: 30784.25 MiB + 16: 52722.25 MiB + 32: 77124.25 MiB + optimized: 24 +vjepa-single: + arg: --batch_size + model: + 1: 6644.25 MiB + 2: 18984.25 MiB + 4: 11860.25 MiB + 8: 30764.25 MiB + 16: 45516.25 MiB + 24: 57574.25 MiB + 32: 67122.25 MiB + optimized: 24 whisper: arg: --batch-size model: @@ -442,36 +620,3 @@ whisper: 128: 71634.375 MiB 144: 80412.75 MiB optimized: 128 - - -llava-single: - arg: --batch_size - optimized: 1 - -llava-gpus: - arg: --batch_size - optimized: 1 - -rlhf-single: - arg: --per_device_train_batch_size - optimized: 64 - -rlhf-gpus: - arg: --per_device_train_batch_size - optimized: 64 - -vjepa-single: - arg: --batch_size - optimized: 24 - -vjepa-gpus: - arg: --batch_size - optimized: 24 - -ppo: - arg: --num_minibatches - optimized: 32 - -dqn: - arg: --buffer_batch_size - optimized: 128 \ No newline at end of file diff --git a/config/standard.yaml b/config/standard.yaml index 588e35e9a..f32685cbc 100644 --- a/config/standard.yaml +++ b/config/standard.yaml @@ -161,12 +161,20 @@ dqn: ppo: enabled: true weight: 1.0 - + +cleanrljax: + enabled: false + weight: 1.0 + # Geo dimenet: enabled: true weight: 1.0 +pna: + enabled: true + weight: 1.0 + recursiongfn: enabled: true weight: 1.0 diff --git a/milabench/cli/gather.py b/milabench/cli/gather.py index d3058d65c..316b6bfb4 100644 --- a/milabench/cli/gather.py +++ b/milabench/cli/gather.py @@ -39,6 +39,7 @@ def arguments(): "--tags", type=str, help="Tags defined in run names", + nargs="+", default=default_tags(), ) return parser.parse_args() # Arguments() diff --git a/milabench/cli/list.py b/milabench/cli/list.py new file mode 100644 index 000000000..fda73bdf5 --- /dev/null +++ b/milabench/cli/list.py @@ -0,0 +1,56 @@ +import os +import yaml + +from milabench.config import build_config + + +this = os.path.dirname(__file__) +config = os.path.join(this, "..", "..", "config") + + +def list_missing_batch_resizer(): + standard = os.path.join(config, "standard.yaml") + scaling = os.path.join(config, "scaling.yaml") + + base_conf = build_config(standard) + + 
with open(scaling, "r") as fp: + scaling = yaml.safe_load(fp) + + missing_benches = [] + def add_bench(k, tags): + print(k, tags) + missing_benches.append(k) + + for k, v in base_conf.items(): + if k[0] == "_": + continue + + if not v.get("enabled", False): + continue + + tags = set(v.get("tags", [])) + + if "nobatch" in tags: + continue + + if k in scaling: + s = scaling[k].get("model", {}) + + if len(s) <= 1: + add_bench(k, tags) + else: + add_bench(k, tags) + + + + b = [f"\"{b}\"" for b in missing_benches] + + + + + print(" ".join(b)) + + +if __name__ == "__main__": + list_missing_batch_resizer() diff --git a/milabench/sizer.py b/milabench/sizer.py index b1f717247..75002edb3 100644 --- a/milabench/sizer.py +++ b/milabench/sizer.py @@ -248,7 +248,8 @@ def __init__(self): self.scaling = None self.benchname = None self.batch_size = 0 - self.max_usage = float("-inf") + self.max_usage = float("-inf") # Usage from the gpu monitor + self.peak_usage = float("-inf") # Usage provided by the bench itself (for jax) self.early_stopped = False def on_start(self, entry): @@ -259,6 +260,7 @@ def on_start(self, entry): self.benchname = entry.pack.config["name"] self.batch_size = None self.max_usage = float("-inf") + self.peak_usage = float("-inf") config = self.memory.setdefault(self.benchname, dict()) template = config.get("arg", None) @@ -300,6 +302,11 @@ def on_data(self, entry): if entry.data is None: return + memorypeak = entry.data.get("memory_peak") + if memorypeak is not None: + self.peak_usage = max(memorypeak, self.peak_usage) + return + gpudata = entry.data.get("gpudata") if gpudata is not None: current_usage = [] @@ -312,6 +319,11 @@ def on_data(self, entry): def on_stop(self, entry): self.early_stopped = True + def max_memory_usage(self): + if self.peak_usage != float("-inf"): + return self.peak_usage + return self.max_usage + def on_end(self, entry): if self.filepath is None: return @@ -319,7 +331,7 @@ def on_end(self, entry): if ( self.benchname is None or self.batch_size is None - or self.max_usage == float("-inf") + or self.max_memory_usage() == float("-inf") ): return @@ -328,12 +340,13 @@ def on_end(self, entry): if rc == 0 or self.early_stopped: config = self.memory.setdefault(self.benchname, dict()) model = config.setdefault("model", dict()) - model[self.batch_size] = f"{self.max_usage} MiB" + model[self.batch_size] = f"{self.max_memory_usage()} MiB" config["model"] = dict(sorted(model.items(), key=lambda x: x[0])) self.benchname = None self.batch_size = None self.max_usage = float("-inf") + self.peak_usage = float("-inf") def report(self, *args): if self.filepath is not None: diff --git a/scripts/article/run_cuda.sh b/scripts/article/run_cuda.sh index b7b31eed3..7328ca54b 100644 --- a/scripts/article/run_cuda.sh +++ b/scripts/article/run_cuda.sh @@ -84,13 +84,28 @@ if [ "$MILABENCH_PREPARE" -eq 0 ]; then . 
$MILABENCH_WORDIR/env/bin/activate + milabench install --system $MILABENCH_WORDIR/system.yaml + # milabench prepare --system $MILABENCH_WORDIR/system.yaml $ARGS # pip install torch - # milabench pin --variant cuda --from-scratch $ARGS + # milabench pin --variant cuda --from-scratch # milabench install --system $MILABENCH_WORDIR/system.yaml --force $ARGS + # milabench prepare --system $MILABENCH_WORDIR/system.yaml $ARGS + + # ARGS="--select resnet50-noio,brax,lightning,dinov2-giant-single,dinov2-giant-gpus,llm-lora-ddp-gpus,llm-lora-ddp-nodes,llm-lora-mp-gpus,llm-full-mp-gpus,llm-full-mp-nodes,dqn,ppo,dimenet,llava-single,rlhf-single,rlhf-gpus,vjepa-single,vjepa-gpus" + + # MEMORY_CAPACITY=("4Go" "8Go" "16Go" "32Go" "64Go" "80Go") + # # MEMORY_CAPACITY=("2048" "4096" "8192") + + # # Run the benchmarks + # for CAPACITY in "${MEMORY_CAPACITY[@]}"; do + # export MILABENCH_SIZER_AUTO=1 + # export MILABENCH_SIZER_MULTIPLE=8 + # export MILABENCH_SIZER_CAPACITY=$CAPACITY + # # export MILABENCH_SIZER_BATCH_SIZE=$CAPACITY + # milabench run --run-name "c$CAPACITY.{time}" --system $MILABENCH_WORDIR/system.yaml $ARGS || true + # done - # - # Run the benchmakrs milabench run --system $MILABENCH_WORDIR/system.yaml $ARGS #
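
Note on the scaling model: the model tables added to config/scaling.yaml map a measured batch size to the peak GPU memory it used (in MiB), and the milabench/sizer.py hunk above appends a new batch_size: "peak MiB" entry after each successful run, preferring the bench-reported memory_peak over the GPU monitor reading when one is available. A minimal sketch of how a batch size could be resolved from such a table for a target capacity, assuming linear interpolation between the two bracketing measurements; the helper pick_batch_size below is illustrative only, not milabench's actual sizer implementation, and it assumes a monotonic table (several measured tables above are noisy, e.g. dimenet, so a real resolver would need to be more defensive):

import bisect

def pick_batch_size(model: dict, capacity_mib: float, multiple: int = 8) -> int:
    # model: observed {batch_size: peak_mib} pairs, as in config/scaling.yaml
    points = sorted(model.items())               # sort by batch size
    sizes = [s for s, _ in points]
    mems = [m for _, m in points]                # assumed non-decreasing
    if capacity_mib >= mems[-1]:
        return sizes[-1]                         # clamp to the largest measurement
    if capacity_mib <= mems[0]:
        return sizes[0]                          # nothing smaller was measured
    i = bisect.bisect_right(mems, capacity_mib)  # find bracketing measurements
    s0, s1 = sizes[i - 1], sizes[i]
    m0, m1 = mems[i - 1], mems[i]
    # interpolate a batch size at the requested capacity, then round down
    bs = s0 + (capacity_mib - m0) * (s1 - s0) / (m1 - m0)
    return max(multiple, int(bs) // multiple * multiple)

With the resnet50-noio table above, for example, a 32768 MiB (32 GiB) capacity falls between the 288 and 592 measurements and resolves to roughly 304 after rounding down to a multiple of 8, which is the kind of value the commented MILABENCH_SIZER_CAPACITY sweep in run_cuda.sh (with MILABENCH_SIZER_MULTIPLE=8) is meant to exercise.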