This repository provides the official implementation of the Q-Score Matching (QSM) algorithm from the paper "Learning a Diffusion Model Policy from Rewards via Q-Score Matching" by Michael Psenka, Alejandro Escontrela, Pieter Abbeel, and Yi Ma. QSM targets off-policy reinforcement learning in continuous state and action spaces, where the agent's policy is represented as a diffusion model. The core idea of QSM is to iteratively align the denoising model of the policy with the action gradient of the critic, ∇_a Q(s, a), so that denoising steps push sampled actions toward higher Q-values.
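To make the core idea concrete, here is a minimal NumPy sketch of a score-matching actor loss in the spirit of QSM. It is not the repository's learner code. It assumes a toy critic Q(s, a) = -||a - target(s)||² with an analytic action gradient, a noise-prediction (ε) denoiser whose output points opposite to the score, and a hypothetical scaling constant `alpha`. The function names are illustrative.

```python
import numpy as np


def critic_action_grad(actions, targets):
    """Analytic grad_a Q for a hypothetical toy critic Q(s, a) = -||a - target(s)||^2."""
    return -2.0 * (actions - targets)


def qsm_actor_loss(eps_pred, q_grad, alpha):
    """Squared error between the denoiser output and the scaled, negated critic gradient.

    A noise-prediction network points opposite to the score, so its regression
    target is -alpha * grad_a Q (alpha is an assumed scaling constant).
    """
    target = -alpha * q_grad
    return float(np.mean(np.sum((eps_pred - target) ** 2, axis=-1)))


rng = np.random.default_rng(0)
noisy_actions = rng.uniform(-1.0, 1.0, size=(4, 2))
targets = np.zeros((4, 2))  # the toy critic prefers a = 0 in every state
q_grad = critic_action_grad(noisy_actions, targets)

alpha = 0.1
perfect = -alpha * q_grad  # a denoiser output that exactly matches the QSM target
print(qsm_actor_loss(perfect, q_grad, alpha))  # 0.0
```

Because a DDPM-style update subtracts the predicted noise, a denoiser matching -α ∇_a Q moves actions uphill on the critic. With the toy critic above, `noisy_actions - perfect` equals `(1 - 2 * alpha) * noisy_actions`, which lies closer to the critic's optimum at 0.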
Diffusion models have gained popularity in generative tasks because they can represent complex distributions over continuous spaces. In reinforcement learning, they offer both expressiveness and ease of sampling, which makes them a promising choice for policy representation. While much prior work has studied diffusion model policies in the offline setting, the online/off-policy setting remains relatively underexplored.
The code is built on top of a re-implementation of the jaxrl framework.
To get started, you need to install the required dependencies. Ensure you have Python 3.8+ and a suitable GPU setup.
```bash
pip install --upgrade pip
pip install -r requirements.txt
pip install --upgrade "jax[cuda]" -f
```
For other versions of CUDA, follow the JAX installation instructions.
To reproduce the results or run the provided training scripts, navigate to the corresponding example directory and execute its training script:

```bash
cd examples/states
```
Script options are defined in the training script file. For example, to train on a different environment:

```bash
python3 --env_name walker_walk
```
Main Training Script: The main script for training a diffusion model agent with QSM, with options for choosing the environment and training scenario.
QSM Learner: The core implementation of the QSM algorithm, including methods for creating the learner, updating the critic and actor networks, and sampling actions. If you modify the learner after installation, you need to reinstall jaxrl5 locally by running the following from the root directory of the repository:

```bash
pip install ./
```
Training Configuration for Score Matching: Configuration file that sets the hyperparameters and model settings for the QSM learner.
DDPM Implementation: Contains the implementation of Denoising Diffusion Probabilistic Models (DDPM), Fourier Features, and various beta schedules essential for the score matching process.
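For readers unfamiliar with these components, the sketch below shows a standard DDPM toolkit in NumPy. It covers linear and cosine beta schedules, a sinusoidal Fourier-feature embedding of the timestep, the forward noising process, and the reverse sampling loop that draws actions from a noise-prediction network. This is a generic illustration under stated assumptions, not the repository's DDPM implementation. The function names are hypothetical, and the denoiser is a stand-in callable `denoiser(obs, a, t)`. The repository's Fourier features may also differ, for example by using learnable frequencies.

```python
import numpy as np


def linear_beta_schedule(T, beta_start=1e-4, beta_end=2e-2):
    """Noise variances beta_1..beta_T spaced linearly, as in the original DDPM paper."""
    return np.linspace(beta_start, beta_end, T)


def cosine_beta_schedule(T, s=0.008):
    """Betas derived from a cosine-shaped cumulative product alpha_bar (Nichol & Dhariwal, 2021)."""
    steps = np.arange(T + 1) / T
    f = np.cos((steps + s) / (1 + s) * np.pi / 2) ** 2
    alpha_bar = f / f[0]
    betas = 1.0 - alpha_bar[1:] / alpha_bar[:-1]
    return np.clip(betas, 0.0, 0.999)


def fourier_features(t, dim=8, max_period=10000.0):
    """Fixed sinusoidal embedding of timesteps: shape (batch,) -> (batch, dim)."""
    half = dim // 2
    freqs = np.exp(-np.log(max_period) * np.arange(half) / half)
    args = np.asarray(t, dtype=np.float64)[:, None] * freqs[None, :]
    return np.concatenate([np.cos(args), np.sin(args)], axis=-1)


def q_sample(a0, t, eps, alpha_bar):
    """Forward process: a_t = sqrt(alpha_bar_t) * a_0 + sqrt(1 - alpha_bar_t) * eps."""
    ab = alpha_bar[t][:, None]
    return np.sqrt(ab) * a0 + np.sqrt(1.0 - ab) * eps


def ddpm_sample(denoiser, obs, action_dim, betas, rng):
    """Reverse process: start from N(0, I), apply T denoising steps, then clip to [-1, 1]."""
    alphas = 1.0 - betas
    alpha_bar = np.cumprod(alphas)
    a = rng.standard_normal((obs.shape[0], action_dim))
    for t in reversed(range(len(betas))):
        eps_hat = denoiser(obs, a, np.full(obs.shape[0], t))
        # Mean of p(a_{t-1} | a_t) under the epsilon parameterization.
        mean = (a - betas[t] / np.sqrt(1.0 - alpha_bar[t]) * eps_hat) / np.sqrt(alphas[t])
        noise = rng.standard_normal(a.shape) if t > 0 else np.zeros_like(a)
        a = mean + np.sqrt(betas[t]) * noise
    return np.clip(a, -1.0, 1.0)


def zero_denoiser(obs, a, t):
    """Hypothetical placeholder; a trained noise-prediction network would go here."""
    return np.zeros_like(a)


rng = np.random.default_rng(0)
betas = linear_beta_schedule(10)
alpha_bar = np.cumprod(1.0 - betas)
actions = ddpm_sample(zero_denoiser, np.zeros((5, 3)), action_dim=2, betas=betas, rng=rng)
print(actions.shape)  # (5, 2)
```

The final clip keeps sampled actions inside the rescaled action bounds that the training loop expects.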
The main training script gives a minimal example of launching a QSM agent in an environment. Below is a slightly stripped-down version that illustrates how to use jaxrl5 and the QSM learner:
```python
import os

import gym
import jax
import tqdm
from absl import app, flags
from ml_collections import config_flags

from jaxrl5.agents import ScoreMatchingLearner
from import ReplayBuffer
from jaxrl5.evaluation import evaluate
from jaxrl5.wrappers import wrap_gym

FLAGS = flags.FLAGS

# QSM assumes a continuous (Box) action space; pass a continuous-control task via --env_name.
flags.DEFINE_string("env_name", "CartPole-v1", "Environment name.")
flags.DEFINE_integer("seed", 42, "Random seed.")
flags.DEFINE_integer("eval_episodes", 1, "Number of episodes used for evaluation.")
flags.DEFINE_integer("log_interval", 1000, "Logging interval.")
flags.DEFINE_integer("eval_interval", 10000, "Eval interval.")
flags.DEFINE_integer("batch_size", 256, "Mini batch size.")
flags.DEFINE_integer("max_steps", int(1e6), "Number of training steps.")
flags.DEFINE_integer("start_training", int(1e4), "Number of steps to start training.")
config_flags.DEFINE_config_file("config", "examples/states/configs/", "Training configuration file.")


def main(_):
    # Separate training and evaluation environments, with rescaled action bounds.
    env = gym.make(FLAGS.env_name)
    env = wrap_gym(env, rescale_actions=True)
    eval_env = gym.make(FLAGS.env_name)
    eval_env = wrap_gym(eval_env, rescale_actions=True)
    eval_env.seed(FLAGS.seed + 42)

    # Hyperparameters from the config file are passed to the learner as keyword arguments.
    config = FLAGS.config
    kwargs = dict(config)
    agent = ScoreMatchingLearner.create(FLAGS.seed, env.observation_space, env.action_space, **kwargs)

    replay_buffer = ReplayBuffer(env.observation_space, env.action_space, FLAGS.max_steps)

    observation, done = env.reset(), False
    for step in tqdm.tqdm(range(1, FLAGS.max_steps + 1), smoothing=0.1):
        # Act uniformly at random during warm-up, then sample from the diffusion policy.
        if step < FLAGS.start_training:
            action = env.action_space.sample()
        else:
            action, agent = agent.sample_actions(observation)
        next_observation, reward, done, info = env.step(action)

        # mask = 0 only for true terminations; time-limit truncations still bootstrap.
        if not done or "TimeLimit.truncated" in info:
            mask = 1.0
        else:
            mask = 0.0

        replay_buffer.insert(
            {
                "observations": observation,
                "actions": action,
                "rewards": reward,
                "masks": mask,
                "dones": done,
                "next_observations": next_observation,
            }
        )
        observation = next_observation

        if done:
            observation, done = env.reset(), False

        # Gradient updates on the critic and actor start after the warm-up phase.
        if step >= FLAGS.start_training:
            batch = replay_buffer.sample(FLAGS.batch_size)
            agent, _ = agent.update(batch)

        if step % FLAGS.log_interval == 0:
            print(f"Step: {step}")

        if step % FLAGS.eval_interval == 0:
            eval_info = evaluate(agent, eval_env, num_episodes=FLAGS.eval_episodes)
            print(f"Evaluation at step {step}: {eval_info}")


if __name__ == "__main__":
    app.run(main)
```
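The `mask` computed in the training loop above decides whether the critic bootstraps from the next state. Below is a minimal sketch of how such a mask typically enters a one-step TD target. `transition_mask` mirrors the rule in the example. `td_target` and its discount value are illustrative assumptions, not the learner's actual critic update, which may also use target networks or multiple critics.

```python
import numpy as np


def transition_mask(done, info):
    """Same rule as the training loop: 0.0 only when the episode ended for a reason other than a time limit."""
    if not done or "TimeLimit.truncated" in info:
        return 1.0
    return 0.0


def td_target(rewards, masks, next_q, discount=0.99):
    """One-step Bellman target r + discount * mask * Q(s', a'), computed element-wise over a batch."""
    return rewards + discount * masks * next_q


# Three transitions: ongoing, true termination, and time-limit truncation.
masks = np.array([
    transition_mask(False, {}),
    transition_mask(True, {}),
    transition_mask(True, {"TimeLimit.truncated": True}),
])
rewards = np.ones(3)
next_q = np.full(3, 10.0)  # hypothetical critic values at the next states
targets = td_target(rewards, masks, next_q)
print(masks)  # [1. 0. 1.]
print(targets)
```

Treating truncation as non-terminal prevents the critic from learning that states near the time limit are worth only their immediate reward.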
We welcome contributions to enhance the repository. If you encounter any issues or have suggestions, feel free to open an issue or a pull request.
If you use this code or our QSM algorithm in your research, please cite our paper:
title={Learning a Diffusion Model Policy from Rewards via Q-Score Matching},
author={Psenka, Michael and Escontrela, Alejandro and Abbeel, Pieter and Ma, Yi},
booktitle={Proceedings of the 41st International Conference on Machine Learning},
This project is licensed under the MIT License. See the LICENSE file for details.