Official implementation of
TD-MPC2: Scalable, Robust World Models for Continuous Control by
Nicklas Hansen, Hao Su*, Xiaolong Wang* (UC San Diego)
TD-MPC2 is a scalable, robust model-based reinforcement learning algorithm. It compares favorably to existing model-free and model-based methods across 104 continuous control tasks spanning multiple domains, with a single set of hyperparameters (right). We further demonstrate the scalability of TD-MPC2 by training a single 317M parameter agent to perform 80 tasks across multiple domains, embodiments, and action spaces (left).
This repository contains code for training and evaluating both single-task online RL and multi-task offline RL TD-MPC2 agents. We additionally open-source 300+ model checkpoints (including 12 multi-task models) across 4 task domains: DMControl, Meta-World, ManiSkill2, and MyoSuite, as well as our 30-task and 80-task datasets used to train the multi-task models. Our codebase supports both state and pixel observations. We hope that this repository will serve as a useful community resource for future research on model-based RL.
You will need a machine with a GPU and at least 12 GB of RAM for single-task online RL with TD-MPC2, and 128 GB of RAM for multi-task offline RL on our provided 80-task dataset. A GPU with at least 8 GB of memory is recommended for single-task online RL and for evaluation of the provided multi-task models (up to 317M parameters). Training of the 317M parameter model requires a GPU with at least 24 GB of memory.
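The hardware figures above can be turned into a quick pre-run check. This is a minimal sketch, not part of the codebase: the setting names (`single-task-online`, `mt80-offline`, `multitask-eval`, `train-317M`) and the helpers `system_ram_gb` and `unmet_requirements` are hypothetical labels for the cases described in this paragraph. Note that some GPU figures are recommendations rather than hard requirements, and RAM detection only works on POSIX systems that expose `SC_PHYS_PAGES`.

```python
# Minimal sketch: compare a machine against the hardware figures quoted in this README.
# Setting names are hypothetical labels, not arguments of train.py or evaluate.py.
import os
from typing import Dict, List, Optional

# Figures in GB. RAM values are minimums; some GPU values are only recommendations.
REQUIREMENTS: Dict[str, Dict[str, int]] = {
    "single-task-online": {"ram_gb": 12, "gpu_gb": 8},
    "mt80-offline": {"ram_gb": 128},
    "multitask-eval": {"gpu_gb": 8},
    "train-317M": {"gpu_gb": 24},
}


def system_ram_gb() -> Optional[float]:
    """Total physical RAM in GB via sysconf (POSIX only); None if unavailable."""
    try:
        return os.sysconf("SC_PAGE_SIZE") * os.sysconf("SC_PHYS_PAGES") / 1024**3
    except (AttributeError, ValueError, OSError):
        return None


def unmet_requirements(setting: str, ram_gb: float, gpu_gb: float) -> List[str]:
    """Return human-readable shortfalls for `setting`; an empty list means all met."""
    needs = REQUIREMENTS[setting]
    problems = []
    if ram_gb < needs.get("ram_gb", 0):
        problems.append(f"RAM: {ram_gb:.0f} GB < {needs['ram_gb']} GB")
    if gpu_gb < needs.get("gpu_gb", 0):
        problems.append(f"GPU memory: {gpu_gb:.0f} GB < {needs['gpu_gb']} GB")
    return problems


if __name__ == "__main__":
    print("Detected RAM (GB):", system_ram_gb())
    print(unmet_requirements("mt80-offline", ram_gb=64, gpu_gb=24))
```

GPU memory is passed in by hand here to avoid depending on a specific GPU library.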
We provide a Dockerfile for easy installation. You can build the Docker image by running
cd docker && docker build . -t <user>/tdmpc2:1.0.0
This docker image contains all dependencies needed for running DMControl, Meta-World, and ManiSkill2 experiments.
If you prefer to install dependencies manually, start by installing them via conda with the following commands:
conda env create -f docker/environment.yaml
pip install gym==0.21.0
The environment.yaml file installs dependencies required for training on DMControl tasks. Other domains can be installed by following the instructions in environment.yaml.
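Because the manual install pins gym==0.21.0, it can help to confirm which version the active environment actually resolved. The following is a minimal sketch using only the standard library; `installed_version` and `check_pin` are hypothetical helper names, not part of the TD-MPC2 codebase.

```python
# Minimal sketch: confirm the active environment has the gym version pinned above.
from importlib.metadata import PackageNotFoundError, version
from typing import Optional


def installed_version(package: str) -> Optional[str]:
    """Return the installed version string of `package`, or None if it is absent."""
    try:
        return version(package)
    except PackageNotFoundError:
        return None


def check_pin(package: str, required: str) -> str:
    """Describe whether `package` matches the pinned `required` version."""
    found = installed_version(package)
    if found is None:
        return f"{package} is not installed (expected {package}=={required})"
    if found != required:
        return f"{package}=={found} installed, but {package}=={required} is expected"
    return f"{package}=={found} OK"


if __name__ == "__main__":
    print(check_pin("gym", "0.21.0"))
```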
If you want to run ManiSkill2, you will additionally need to download and link the necessary assets by running
python -m mani_skill2.utils.download_asset all
which downloads assets to ./data. You may move these assets to any location. Then, add the following line to your ~/.bashrc:
export MS2_ASSET_DIR=<path>/<to>/<data>
and restart your terminal. Meta-World additionally requires MuJoCo 2.1.0. We host the unrestricted MuJoCo 2.1.0 license (courtesy of Google DeepMind) at https://www.tdmpc2.com/files/mjkey.txt. You can download the license by running
wget https://www.tdmpc2.com/files/mjkey.txt -O ~/.mujoco/mjkey.txt
See docker/Dockerfile for installation instructions if you do not already have MuJoCo 2.1.0 installed. MyoSuite requires gym==0.13.0, which is incompatible with Meta-World and ManiSkill2. If desired, install it separately with pip install myosuite. Depending on your existing system packages, you may need to install other dependencies; see docker/Dockerfile for a list of recommended system packages.
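Before launching ManiSkill2 or Meta-World runs, a small preflight check can catch a missing asset directory or license file. This is a minimal sketch based only on the two locations named above (the MS2_ASSET_DIR variable and ~/.mujoco/mjkey.txt); `preflight` is a hypothetical helper, and it does not verify that MuJoCo 2.1.0 itself is installed.

```python
# Minimal sketch: preflight check for the ManiSkill2 asset directory and the
# MuJoCo 2.1.0 license file, using the locations given in this README.
import os
from pathlib import Path
from typing import List, Mapping, Optional


def preflight(env: Optional[Mapping[str, str]] = None,
              home: Optional[Path] = None) -> List[str]:
    """Return a list of problems; an empty list means both checks passed."""
    env = os.environ if env is None else env
    home = Path.home() if home is None else home
    problems = []

    # ManiSkill2: MS2_ASSET_DIR should point at the downloaded asset directory.
    asset_dir = env.get("MS2_ASSET_DIR")
    if not asset_dir:
        problems.append("MS2_ASSET_DIR is not set (needed for ManiSkill2)")
    elif not Path(asset_dir).is_dir():
        problems.append(f"MS2_ASSET_DIR={asset_dir} is not a directory")

    # Meta-World: the MuJoCo 2.1.0 license is expected at ~/.mujoco/mjkey.txt.
    if not (home / ".mujoco" / "mjkey.txt").is_file():
        problems.append("~/.mujoco/mjkey.txt not found (needed for Meta-World)")
    return problems


if __name__ == "__main__":
    for problem in preflight():
        print("WARNING:", problem)
```

Passing `env` and `home` explicitly makes the check easy to test without touching the real environment.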
This codebase currently supports 104 continuous control tasks from DMControl, Meta-World, ManiSkill2, and MyoSuite. Specifically, it supports 39 tasks from DMControl (including 11 custom tasks), 50 tasks from Meta-World, 5 tasks from ManiSkill2, and 10 tasks from MyoSuite, and covers all tasks used in the paper. See the table below for the expected name formatting in each task domain:
domain | task |
---|---|
dmcontrol | dog-run |
dmcontrol | cheetah-run-backwards |
metaworld | mw-assembly |
metaworld | mw-pick-place-wall |
maniskill | pick-cube |
maniskill | pick-ycb |
myosuite | myo-key-turn |
myosuite | myo-key-turn-hard |
which can be run by specifying the task argument for train.py or evaluate.py. Multi-task training and evaluation are specified by setting task=mt80 or task=mt30 for the 80-task and 30-task sets, respectively.
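The naming conventions in the table can be captured in a few lines. The sketch below is hypothetical and is not how the codebase dispatches tasks: it infers a domain from the prefixes shown (mw-, myo-), treats mt30/mt80 as multi-task sets, and falls back to DMControl. Only the two ManiSkill2 tasks listed in the table are included, so the set would need the remaining ManiSkill2 task names. Splitting DMControl names at the first hyphen (with later hyphens mapped to underscores) is also an assumption.

```python
# Minimal sketch (hypothetical helpers): infer a task's domain from the naming
# conventions in the table above.
from typing import Tuple

MULTITASK_SETS = {"mt30", "mt80"}
# Only the two ManiSkill2 tasks listed in the table; extend with the remaining ones.
MANISKILL_TASKS = {"pick-cube", "pick-ycb"}


def infer_domain(task: str) -> str:
    if task in MULTITASK_SETS:
        return "multitask"
    if task.startswith("mw-"):
        return "metaworld"
    if task.startswith("myo-"):
        return "myosuite"
    if task in MANISKILL_TASKS:
        return "maniskill"
    # Assumption: every remaining name is a DMControl '<domain>-<task>' name.
    return "dmcontrol"


def split_dmcontrol(task: str) -> Tuple[str, str]:
    """Split at the first hyphen, e.g. 'cheetah-run-backwards' -> ('cheetah', 'run_backwards')."""
    domain, _, rest = task.partition("-")
    return domain, rest.replace("-", "_")
```

Since DMControl is the fallback, a Meta-World name typed without its mw- prefix would be silently classified as DMControl; an explicit DMControl task list would make such typos fail loudly instead.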
As of Dec 27, 2023, the TD-MPC2 codebase also supports pixel observations for DMControl tasks; use the argument obs=rgb if you wish to train visual policies.
Below, we provide examples of how to evaluate our provided TD-MPC2 checkpoints and how to train your own TD-MPC2 agents.
The following examples evaluate downloaded single-task and multi-task checkpoints:
$ python evaluate.py task=mt80 model_size=48 checkpoint=/path/to/mt80-48M.pt
$ python evaluate.py task=mt30 model_size=317 checkpoint=/path/to/mt30-317M.pt
$ python evaluate.py task=dog-run checkpoint=/path/to/dog-1.pt save_video=true
All single-task checkpoints expect model_size=5. Multi-task checkpoints are available in multiple model sizes; available arguments are model_size={1, 5, 19, 48, 317}. Note that single-task evaluation of multi-task checkpoints is currently not supported. See config.yaml for a full list of arguments.
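When scripting evaluations, the model-size rules above can be enforced before launching evaluate.py. This is a minimal sketch; `build_eval_command` is a hypothetical helper that only assembles the key=value arguments shown in the examples. It cannot tell whether a checkpoint file is single-task or multi-task, so the unsupported case of evaluating a multi-task checkpoint on a single task is not caught.

```python
# Minimal sketch: assemble an evaluate.py command line, enforcing the model-size
# rules stated above. `build_eval_command` is a hypothetical helper.
import shlex
from typing import List

MODEL_SIZES = {1, 5, 19, 48, 317}
MULTITASK_SETS = {"mt30", "mt80"}


def build_eval_command(task: str, checkpoint: str, model_size: int = 5,
                       save_video: bool = False) -> List[str]:
    if model_size not in MODEL_SIZES:
        raise ValueError(f"model_size must be one of {sorted(MODEL_SIZES)}")
    # Single-task checkpoints are only released at 5M parameters.
    if task not in MULTITASK_SETS and model_size != 5:
        raise ValueError("single-task checkpoints expect model_size=5")
    cmd = ["python", "evaluate.py", f"task={task}", f"model_size={model_size}",
           f"checkpoint={checkpoint}"]
    if save_video:
        cmd.append("save_video=true")
    return cmd


if __name__ == "__main__":
    print(shlex.join(build_eval_command("mt80", "/path/to/mt80-48M.pt", model_size=48)))
```

The returned list can be passed directly to subprocess.run, which avoids shell-quoting issues with checkpoint paths.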
The following examples train TD-MPC2 on a single task (online RL) and on multi-task datasets (offline RL). We recommend configuring Weights and Biases (wandb) in config.yaml to track training progress.
$ python train.py task=mt80 model_size=48 batch_size=1024
$ python train.py task=mt30 model_size=317 batch_size=1024
$ python train.py task=dog-run steps=7000000
$ python train.py task=walker-walk obs=rgb
We recommend using default hyperparameters for single-task online RL, including the default model size of 5M parameters (model_size=5). Multi-task offline RL benefits from a larger model size, but larger models are also increasingly costly to train and evaluate. Available arguments are model_size={1, 5, 19, 48, 317}. See config.yaml for a full list of arguments.
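For running several single-task experiments back to back, the key=value arguments compose naturally into a sweep. This is a minimal sketch; `sweep_commands` and `run_sequentially` are hypothetical helpers, and the `seed` argument is an assumption about config.yaml that you should verify there before use. The sweep defaults to a dry run that only prints the commands.

```python
# Minimal sketch: generate (and optionally launch) a small single-task training sweep.
# Assumption: config.yaml exposes a `seed` argument; verify before use.
import itertools
import subprocess
from typing import Iterable, List


def sweep_commands(tasks: Iterable[str], seeds: Iterable[int], **overrides) -> List[List[str]]:
    """One train.py command per (task, seed) pair, plus shared key=value overrides."""
    extra = [f"{key}={value}" for key, value in overrides.items()]
    return [["python", "train.py", f"task={task}", f"seed={seed}", *extra]
            for task, seed in itertools.product(tasks, seeds)]


def run_sequentially(commands: List[List[str]], dry_run: bool = True) -> None:
    for cmd in commands:
        print(" ".join(cmd))
        if not dry_run:
            subprocess.run(cmd, check=True)  # stops the sweep on the first failure


if __name__ == "__main__":
    run_sequentially(sweep_commands(["dog-run", "walker-walk"], [1, 2]))
```

Running jobs sequentially keeps GPU memory usage predictable; launching them in parallel would require checking that the GPU can hold several 5M-parameter runs at once.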
As of Jan 7, 2024, the TD-MPC2 codebase also supports multi-GPU training for multi-task offline RL experiments; use the distributed branch and the argument world_size=N to train on N GPUs. We cannot guarantee that distributed training will yield the same results, but results appear similar in our limited testing.
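One way to keep world_size consistent with the GPUs a job may use is to derive it from CUDA_VISIBLE_DEVICES. This is a minimal sketch under two assumptions: that world_size should equal the number of visible devices, and that the command is run from a checkout of the distributed branch. `visible_gpu_count` and `distributed_train_command` are hypothetical helpers, and whether batch_size is per-GPU or global on that branch is not stated here.

```python
# Minimal sketch: derive world_size from CUDA_VISIBLE_DEVICES and build a multi-task
# training command. Assumes a checkout of the `distributed` branch.
import os
from typing import List, Optional


def visible_gpu_count(value: Optional[str]) -> Optional[int]:
    """Count comma-separated device IDs. None means the variable is unset
    (all GPUs visible, count unknown here); an empty string means no GPUs."""
    if value is None:
        return None
    return len([device for device in value.split(",") if device.strip()])


def distributed_train_command(world_size: int, task: str = "mt80", model_size: int = 48,
                              batch_size: int = 1024) -> List[str]:
    if world_size < 1:
        raise ValueError("world_size must be at least 1")
    # Unverified: whether batch_size is per-GPU or global on the distributed branch.
    return ["python", "train.py", f"task={task}", f"model_size={model_size}",
            f"batch_size={batch_size}", f"world_size={world_size}"]


if __name__ == "__main__":
    n = visible_gpu_count(os.environ.get("CUDA_VISIBLE_DEVICES"))
    if n:
        print(" ".join(distributed_train_command(n)))
    else:
        print("Set CUDA_VISIBLE_DEVICES (e.g. 0,1) to select the GPUs to train on.")
```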
If you find our work useful, please consider citing our paper as follows:
@inproceedings{hansen2024tdmpc2,
title={TD-MPC2: Scalable, Robust World Models for Continuous Control},
author={Nicklas Hansen and Hao Su and Xiaolong Wang},
booktitle={International Conference on Learning Representations (ICLR)},
year={2024}
}
as well as the original TD-MPC paper:
@inproceedings{hansen2022tdmpc,
title={Temporal Difference Learning for Model Predictive Control},
author={Nicklas Hansen and Xiaolong Wang and Hao Su},
booktitle={International Conference on Machine Learning (ICML)},
year={2022}
}
You are very welcome to contribute to this project. Feel free to open an issue or pull request if you have any suggestions or bug reports, but please review our guidelines first. Our goal is to build a codebase that can easily be extended to new environments and tasks, and we would love to hear about your experience!
This project is licensed under the MIT License; see the LICENSE file for details. Note that the repository relies on third-party code, which is subject to their respective licenses.