We recommend starting with a single host first and then moving to multihost.
- Create a GCS bucket in your project for storing logs and checkpoints. To run MaxText, the TPU/GPU VMs must have permission to read from and write to the GCS bucket. These permissions are granted by service account roles, such as the `STORAGE ADMIN` role.
- MaxText reads a YAML file for configuration. We recommend reviewing the configurable options in `configs/base.yml`; this config includes a decoder-only model of ~1B parameters. The configurable options can be overridden from the command line. For instance, you may change `steps` or `log_period` either by modifying `configs/base.yml` or by passing `steps` and `log_period` as additional args to the `train.py` call. `base_output_directory` should be set to a folder in the bucket you just created.
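The override behavior described above can be pictured with a small sketch. This is not MaxText's actual config loader: the `DEFAULTS` dictionary and its values are hypothetical stand-ins for what would be read from `configs/base.yml`, and `apply_overrides` is an illustrative helper. It only shows the general idea that `key=value` arguments take precedence over file defaults.

```python
# Hypothetical defaults standing in for values read from configs/base.yml.
# The numbers are illustrative, not the real contents of that file.
DEFAULTS = {"steps": 1000, "log_period": 100, "base_output_directory": ""}


def apply_overrides(defaults, args):
    """Return a copy of `defaults` with key=value command-line args applied on top."""
    config = dict(defaults)
    for arg in args:
        key, sep, raw = arg.partition("=")
        if not sep:
            raise ValueError(f"expected key=value, got {arg!r}")
        if key not in config:
            raise KeyError(f"unknown config key: {key}")
        # Cast the string to the type of the default, so "10" becomes the int 10.
        config[key] = type(config[key])(raw)
    return config


config = apply_overrides(
    DEFAULTS, ["steps=10", "base_output_directory=gs://my-bucket/runs"]
)
print(config["steps"], config["log_period"])
```

Here `steps` comes from the command line while `log_period` keeps its file default, which mirrors the two ways of changing an option described above.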
Local development is a convenient way to run MaxText on a single host. It doesn't scale to multiple hosts.
- Create and SSH to the single-host VM of your choice. We recommend a `v4-8`.
- Clone MaxText onto that TPU VM.
- Within the root directory of that git repo, install dependencies by running:
bash setup.sh
- After installation completes, run training on synthetic data with the following command:
python3 MaxText/train.py MaxText/configs/base.yml \
run_name=$YOUR_JOB_NAME \
base_output_directory=gs://<my-bucket> \
dataset_type=synthetic \
steps=10
Next, you can try training on a Hugging Face dataset; see Data Input Pipeline for data input options.
- To run decoding, use the following command:
python3 MaxText/decode.py MaxText/configs/base.yml \
run_name=$YOUR_JOB_NAME \
base_output_directory=gs://<my-bucket> \
per_device_batch_size=1
Note that these decodings will be random. To get high-quality decodings, you need to pass in a checkpoint, typically via the `load_parameters_path` argument.
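As a rough illustration of passing a checkpoint, the sketch below assembles the decode command with `load_parameters_path` appended. The helper `decode_command`, the job name, and the bucket and checkpoint paths are all hypothetical placeholders; the script only builds and prints the command, it does not run it.

```python
import shlex


def decode_command(run_name, output_dir, checkpoint_path=None):
    """Build the argv list for MaxText/decode.py; nothing is executed here."""
    argv = [
        "python3", "MaxText/decode.py", "MaxText/configs/base.yml",
        f"run_name={run_name}",
        f"base_output_directory={output_dir}",
        "per_device_batch_size=1",
    ]
    if checkpoint_path is not None:
        # Without a checkpoint the decodings will be random, as noted above.
        argv.append(f"load_parameters_path={checkpoint_path}")
    return argv


# Placeholder job name, bucket, and checkpoint path; substitute your own.
argv = decode_command(
    "my-job", "gs://my-bucket", "gs://my-bucket/path/to/checkpoint"
)
print(shlex.join(argv))
```

Building the command as a list keeps each `key=value` override as a separate argument, so it can be passed to `subprocess.run` without shell quoting issues.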
- Use `bash docker_build_dependency_image.sh DEVICE=gpu` to build a container with the required dependencies.
- After installation completes, run training on synthetic data with the following command:
python3 MaxText/train.py MaxText/configs/base.yml \
run_name=$YOUR_JOB_NAME \
base_output_directory=gs://<my-bucket> \
dataset_type=synthetic \
steps=10
- To run decoding, use the following command:
python3 MaxText/decode.py MaxText/configs/base.yml \
run_name=$YOUR_JOB_NAME \
base_output_directory=gs://<my-bucket> \
per_device_batch_size=1
- If you see the following error when running inside a container, set a larger `--shm-size` (e.g. `--shm-size=1g`):
Failed to execute XLA Runtime executable: run time error: custom call 'xla.gpu.all_reduce' failed: external/xla/xla/service/gpu/nccl_utils.cc:297: NCCL operation ncclCommInitRank(&comm, nranks, id, rank) failed: unhandled cuda error (run with NCCL_DEBUG=INFO for details); current tracing scope: all-reduce-start.2; current profiling annotation: XlaModule:#hlo_module=jit__unnamed_wrapped_function_,program_id=7#.
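To show where the flag goes, here is a minimal sketch that assembles a `docker run` invocation with a larger shared-memory size. The image name `maxtext_base_image` is an assumption about what the build script produces, and `docker_run_command` is a hypothetical helper; adjust both to your setup. `--shm-size`, `--gpus`, `--rm`, and `-it` are standard `docker run` options.

```python
import shlex


def docker_run_command(image, shm_size="1g", command=("bash",)):
    """Build a `docker run` argv with GPU access and a larger /dev/shm."""
    return [
        "docker", "run",
        "--gpus", "all",            # expose all GPUs to the container
        f"--shm-size={shm_size}",   # enlarge shared memory, as suggested above
        "--rm", "-it",
        image,
        *command,
    ]


# Assumed image name; replace with the tag of the image you actually built.
docker_argv = docker_run_command("maxtext_base_image")
print(shlex.join(docker_argv))
```

If the error persists at `1g`, the same helper can be called with a larger `shm_size` value.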
There are three patterns for running MaxText with more than one host.
- [GKE, recommended] Running MaxText with xpk - Quick experimentation and production support.
- [GCE] Running MaxText with Multihost Jobs - Long-running production jobs with Queued Resources.
- [GCE] Running MaxText with Multihost Runner - Fast experiments via multiple SSH connections.
Once your workloads are running, there are important optimizations you may want to apply to your cluster. See PREFLIGHT.md for details.