Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat/rainbow #86

Merged
merged 27 commits into from
Jun 20, 2024
Merged

Feat/rainbow #86

merged 27 commits into from
Jun 20, 2024

Conversation

RPegoud
Copy link
Contributor

@RPegoud RPegoud commented Jun 3, 2024

What?

Implement Rainbow DQN as defined in #81.

Progress:

  • NoisyLinear layer module
  • NoisyMLPTorso module

Comments

NoisyDQN generates a Gaussian noise matrix at inference time, the NoisyLinear module currently uses a "noise" rng stream for this purpose:

row_noise = jax.random.normal(self.make_rng("noise"), (n_rows,))
col_noise = jax.random.normal(self.make_rng("noise"), (n_cols,))
noise_matrix = jnp.outer(row_noise, col_noise)

This requires using the following syntax:

net = NoisyLinear(128)
rngs = {"params": jax.random.PRNGKey(0), "noise": jax.random.PRNGKey(1)}
params = net.init(rngs, x)
...
new_rng_key = jax.random.split(rngs["noise"])
y = net.apply(params, x, rngs={"params": ..., "noise": new_rng_key})

Is this the correct way to proceed? If so, what would be the easiest way to pass this additional key through the training script?

@EdanToledo
Copy link
Owner

What?

Implement Rainbow DQN as defined in #81.

Progress:

  • NoisyLinear layer module
  • NoisyMLPTorso module

Comments

NoisyDQN generates a Gaussian noise matrix at inference time, the NoisyLinear module currently uses a "noise" rng stream for this purpose:

row_noise = jax.random.normal(self.make_rng("noise"), (n_rows,))
col_noise = jax.random.normal(self.make_rng("noise"), (n_cols,))
noise_matrix = jnp.outer(row_noise, col_noise)

This requires using the following syntax:

net = NoisyLinear(128)
rngs = {"params": jax.random.PRNGKey(0), "noise": jax.random.PRNGKey(1)}
params = net.init(rngs, x)
...
new_rng_key = jax.random.split(rngs["noise"])
y = net.apply(params, x, rngs={"params": ..., "noise": new_rng_key})

Is this the correct way to proceed? If so, what would be the easiest way to pass this additional key through the training script?

I've left some comments, main thing im curious about is whether or not noise is turned off during evaluation. If so then you need to add that into the noisy linear layer.

@RPegoud
Copy link
Contributor Author

RPegoud commented Jun 3, 2024

Got it, thanks for the feedback. The original paper does not mention deactivating noise during evaluation, however, the Rainbow paper gives some more insights:

Evaluation Methodology.

...
DQN starts with an exploration ε of 1, corresponding to acting uniformly at random; it anneals the amount of exploration over the first 4M frames, to a final value of 0.1 (lowered to 0.01 in later variants). Whenever using Noisy Nets, we acted fully greedily (ε = 0), with a value of 0.5 for the σ0 hyper-parameter used to initialize the weights in the noisy stream1. For agents without Noisy Nets, we used ε-greedy but decreased the exploration rate faster than was previously used, annealing ε to 0.01 in the first 250K frames.
...
In terms of median performance, the agent performed better when Noisy Nets were included; when these are removed and exploration is delegated to the traditional ε-greedy mechanism, performance was worse in aggregate (red line in Figure 3). While the removal of Noisy Nets produced a large drop in performance for several games, it also provided small increases in other games

I'm not completely sure but this seems to suggest that they maintain some exploration mechanism during evaluation, either ε-greedy with annealing or noisy nets.

@RPegoud
Copy link
Contributor Author

RPegoud commented Jun 4, 2024

Hey, I'm getting this error when running make run example="stoix/systems/q_learning/ff_noisy_dqn.py", any ideas on how to fix this? Looks like it could be related to the Flashbax version?

==========
== CUDA ==
==========

CUDA Version 11.8.0

Container image Copyright (c) 2016-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.

This container image and its contents are governed by the NVIDIA Deep Learning Container License.
By pulling and using the container, you accept the terms and conditions of this license:
https://developer.nvidia.com/ngc/nvidia-deep-learning-container-license

A copy of this license is made available in this container at /NGC-DL-CONTAINER-LICENSE for your convenience.

Traceback (most recent call last):
  File "/home/app/stoix/stoix/systems/q_learning/ff_noisy_dqn.py", line 6, in <module>
    import flashbax as fbx
  File "/stoix/lib/python3.10/site-packages/flashbax/__init__.py", line 16, in <module>
    from flashbax.buffers import (
  File "/stoix/lib/python3.10/site-packages/flashbax/buffers/__init__.py", line 16, in <module>
    from flashbax.buffers.prioritised_flat_buffer import make_prioritised_flat_buffer
  File "/stoix/lib/python3.10/site-packages/flashbax/buffers/prioritised_flat_buffer.py", line 25, in <module>
    from flashbax.buffers.prioritised_trajectory_buffer import (
  File "/stoix/lib/python3.10/site-packages/flashbax/buffers/prioritised_trajectory_buffer.py", line 39, in <module>
    from flashbax.buffers import sum_tree, trajectory_buffer
  File "/stoix/lib/python3.10/site-packages/flashbax/buffers/sum_tree.py", line 33, in <module>
    from flax.struct import dataclass
  File "/stoix/lib/python3.10/site-packages/flax/__init__.py", line 24, in <module>
    from flax import core
  File "/stoix/lib/python3.10/site-packages/flax/core/__init__.py", line 15, in <module>
    from .axes_scan import broadcast as broadcast
  File "/stoix/lib/python3.10/site-packages/flax/core/axes_scan.py", line 23, in <module>
    from jax.extend import linear_util as lu
ImportError: cannot import name 'linear_util' from 'jax.extend' (/stoix/lib/python3.10/site-packages/jax/extend/__init__.py)
make: *** [Makefile:29: run] Error 1

@EdanToledo
Copy link
Owner

Hey, I'm getting this error when running make run example="stoix/systems/q_learning/ff_noisy_dqn.py", any ideas on how to fix this? Looks like it could be related to the Flashbax version?


==========

== CUDA ==

==========



CUDA Version 11.8.0



Container image Copyright (c) 2016-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.



This container image and its contents are governed by the NVIDIA Deep Learning Container License.

By pulling and using the container, you accept the terms and conditions of this license:

https://developer.nvidia.com/ngc/nvidia-deep-learning-container-license



A copy of this license is made available in this container at /NGC-DL-CONTAINER-LICENSE for your convenience.



Traceback (most recent call last):

  File "/home/app/stoix/stoix/systems/q_learning/ff_noisy_dqn.py", line 6, in <module>

    import flashbax as fbx

  File "/stoix/lib/python3.10/site-packages/flashbax/__init__.py", line 16, in <module>

    from flashbax.buffers import (

  File "/stoix/lib/python3.10/site-packages/flashbax/buffers/__init__.py", line 16, in <module>

    from flashbax.buffers.prioritised_flat_buffer import make_prioritised_flat_buffer

  File "/stoix/lib/python3.10/site-packages/flashbax/buffers/prioritised_flat_buffer.py", line 25, in <module>

    from flashbax.buffers.prioritised_trajectory_buffer import (

  File "/stoix/lib/python3.10/site-packages/flashbax/buffers/prioritised_trajectory_buffer.py", line 39, in <module>

    from flashbax.buffers import sum_tree, trajectory_buffer

  File "/stoix/lib/python3.10/site-packages/flashbax/buffers/sum_tree.py", line 33, in <module>

    from flax.struct import dataclass

  File "/stoix/lib/python3.10/site-packages/flax/__init__.py", line 24, in <module>

    from flax import core

  File "/stoix/lib/python3.10/site-packages/flax/core/__init__.py", line 15, in <module>

    from .axes_scan import broadcast as broadcast

  File "/stoix/lib/python3.10/site-packages/flax/core/axes_scan.py", line 23, in <module>

    from jax.extend import linear_util as lu

ImportError: cannot import name 'linear_util' from 'jax.extend' (/stoix/lib/python3.10/site-packages/jax/extend/__init__.py)

make: *** [Makefile:29: run] Error 1

Aah weird, it seems to most likely be a jax versioning issue. To be honest, I haven't done thorough testing on the docker file and using makefile since I always simply clone and install in conda env.

@RPegoud
Copy link
Contributor Author

RPegoud commented Jun 5, 2024

I made some minor changes to the docker config and the evaluation function and completed a first run with NoisyDQN on CartPole. Here are the metrics I obtained after a million timesteps and 125 000 updates:

ABSOLUTE - 
Steps per second: 83988.091 
Episode length mean: 466.862 
Episode length std: 73.619 
Episode length min: 203.000 
Episode length max: 500.000 
Episode return mean: 466.862 
Episode return std: 73.619 
Episode return min: 203.000 
Episode return max: 500.000

Does it compare with what you generally observe with vanilla DQN?

@EdanToledo
Copy link
Owner

Okay great, so I can't confidently say it's similar without running it again (I don't really remember) but it's great to see its learning. Let's maybe test it only a slightly harder env, maybe one of the minatar ones or Jumanji snake 6x6. If we are happy with that then let's move on to the other components of rainbow such as n step and c51 head.

@RPegoud
Copy link
Contributor Author

RPegoud commented Jun 5, 2024

Sounds good, I'm still struggling to run the training script with GPU support but I should still be able to get a few runs tomorrow.

@RPegoud
Copy link
Contributor Author

RPegoud commented Jun 6, 2024

I ran it on Jumanji/snake (probably 12x12 by default?) for 125000 updates and it seems to have learned as well:

ABSOLUTE
Steps per second: 58488.787 
Episode length mean: 1299.893 
Episode length std: 1784.384 
Episode length min: 32.000 
Episode length max: 4000.000 
Episode return mean: 14.586 
Episode return std: 2.967 
Episode return min: 0.000 
Episode return max: 21.000

@RPegoud
Copy link
Contributor Author

RPegoud commented Jun 6, 2024

As far as I can tell for n-step C51, we'd need to:

  • Replace the replay buffer by an item buffer
  • Change rollout_length to n (3 in the rainbow paper)
  • Define the multi-step variant of the loss as (so mainly replacing the reward by the discounted return):

$$ \begin{align} &d^n_t = R_t^{(n)} + \gamma^nz, p_{\bar \theta}(S_{t+n}, a^\star_{t+n}) \ \\ & \text{with } R_t^{(n)} = \sum_{k=0}^{n-1} \gamma^{k}R_{t+k+1} \end{align} $$

How do you view the API? Would this new version be a new standalone system or should we rather modify the current c51 system to support n_step returns?

@EdanToledo
Copy link
Owner

Firstly, I think we should use a trajectory buffer for the n-step returns, this will make our lives much easier and allow us more control over the n steps. Secondly, we should make a new file that will ultimately become rainbow instead of modifying the current c51 imo.

@EdanToledo
Copy link
Owner

I ran it on Jumanji/snake (probably 12x12 by default?) for 125000 updates and it seems to have learned as well:


ABSOLUTE

Steps per second: 58488.787 

Episode length mean: 1299.893 

Episode length std: 1784.384 

Episode length min: 32.000 

Episode length max: 4000.000 

Episode return mean: 14.586 

Episode return std: 2.967 

Episode return min: 0.000 

Episode return max: 21.000

Great, it's 6x6 in my default configs. Getting 14 is okay. Anything above 32+ is good performance but I imagine this is simply due to hyperparams. Getting 14 is definitely learning so happy with that.

@RPegoud
Copy link
Contributor Author

RPegoud commented Jun 14, 2024

I see, this is a lot cleaner indeed. C51 seems to have similar issues on Snake as well:

--- Rainbow ---
ABSOLUTE
Episode length mean: 3959.477
Episode length std: 400.058
Episode length min: 4.000
Episode length max: 4000.000
Episode return mean: 0.109
Episode return std: 0.347
Episode return min: 0.000
Episode return max: 3.000

--- C51 ---
ABSOLUTE
Episode length mean: 3857.073
Episode length std: 740.275
Episode length min: 7.000
Episode length max: 4000.000
Episode return mean: 0.245
Episode return std: 0.628
Episode return min: 0.000
Episode return max: 6.000

I'll try to take a closer look

@EdanToledo
Copy link
Owner

@RPegoud I played around with configs and current rainbow system is working on snake and getting reward.

@EdanToledo
Copy link
Owner

@RPegoud okay so i've added prioritised replay and distributional dueling head. Please look over it and see if you have any opinions on how its been done and try verify for correctness. I think we are almost finished and ready for verification and clean up. Last thing to do is to use noisy layers in the dueling network i think and then i believe we are finished with rainbow then we just have to sanity check.

@RPegoud
Copy link
Contributor Author

RPegoud commented Jun 15, 2024

Great! I should have time to take a look tomorrow or Monday at worst and I'll add the noisy layers 👍

@RPegoud
Copy link
Contributor Author

RPegoud commented Jun 17, 2024

Looks like we're really close! Here's what I've noted so far:

  • The importance sampling weight should be linearly annealed to 1
  • The dueling architecture has a "shared representation" i.e. a CNN layer common to the value and advantage streams
    I'll add a commit in the afternoon including these changes and the noisy layers and we can continue from there.

@EdanToledo
Copy link
Owner

EdanToledo commented Jun 17, 2024

Great! For the scheduling use an optax scheduler just so we are consistent and allow an easy change of the type of scheduler and for the dueling arch I was thinking we need to change the current way it's constructed. The general design paradigm for other arch's is to have a torso and a head. Let's make the concept of a dueling architecture just a head instead of the torso. That way it's independent of any torso someone wants to use and allows for the shared torso easily.

@RPegoud
Copy link
Contributor Author

RPegoud commented Jun 17, 2024

Yes I figured optax.linear_schedule would do the job but I'm not quite sure how to access the global timestep to get the current learning rate. Should we add a counter to the info dictionary or is there a better way to keep track of it?
I agree, it makes more sense to keep the shared layers separated from the dueling head.

@RPegoud
Copy link
Contributor Author

RPegoud commented Jun 17, 2024

Here's a first commit with the Rainbow dueling head. I've also included the hyperparameters from the paper but can't run a full training locally. Let me know when you have some time to run an experiment, if those parameters don't work well enough, we can probably revert to the previous ones.

PS: as we use CNNs by default with Rainbow, is there a way to override env.flatten_observation=False in the config? So that the script runs out of the box without the additional flag:
python stoix/systems/q_learning/ff_rainbow.py env=jumanji/snake env.flatten_observation=False

@EdanToledo
Copy link
Owner

Amazing, will try run some tests later and do some review. Regarding env flatten, we can just turn it off by default i dont mind. Lastly, i'm thinking we remove noisy dqn as a system as it isn't really relevant for someone. I imagine someone will either use rainbow or some other variation of dqn that doesnt use noisy layers.

@RPegoud
Copy link
Contributor Author

RPegoud commented Jun 17, 2024

Cool! Yes that seems reasonable, the training script is just a copy of standard dqn with some added keys, it doesn't justify an additional system.

@EdanToledo
Copy link
Owner

@RPegoud I've done some small tests, it seems to work (with some slight changes to hyperparams but thats to be expected since rainbow hyperparameters are hyper optimised for atari). I've done some clean up so I'd say we are pretty much done after we do a scheduler for the importance weight. I realised i didnt answer your one question, regarding tracking the number of training steps, i think we will need to either add an explicit variable to the learner state or somehow get it from the optimizer state.

@RPegoud
Copy link
Contributor Author

RPegoud commented Jun 18, 2024

@EdanToledo I just added the scheduler, I ended up using:

importance_sampling_exponent_scheduler: Callable = optax.linear_schedule(
        init_value=config.system.importance_sampling_exponent,
        end_value=1,
        transition_steps=config.arch.num_updates * config.system.epochs,
        transition_begin=0,
    )

With:

step_count = optax.tree_utils.tree_get(opt_states, "count")
importance_sampling_exponent = scheduler_fn(step_count)

With this setup, the exponent should reach 1 at the end of training but I haven't been able to run it entirely to check. Let me know if this works as expected.

@EdanToledo
Copy link
Owner

okay so im happy with the code, now the last thing we need to do is make it pass the pre-commit linters and typing tests.

@RPegoud
Copy link
Contributor Author

RPegoud commented Jun 18, 2024

Awesome, I'll look into it tomorrow

@RPegoud
Copy link
Contributor Author

RPegoud commented Jun 19, 2024

@EdanToledo I've been looking into the typing tests and there seems to be a problem with actor_apply, as Rainbow uses an additional rngs argument. I modified the type to:

ActorApply = Callable[
    [FrozenDict, Observation, Optional[dict[str, chex.PRNGKey]]], DistributionLike
]

But this causes errors like this one across the codebase whenever actor_apply is used:

stoix\systems\mpo\ff_mpo_continuous.py:227: error: Too few arguments

Alternatively, we can use two separate types to distinguish the cases where we use rngs:

ActorApply = Callable[[FrozenDict, Observation], DistributionLike]
ActorRngApply = Callable[
    [
        FrozenDict,
        Observation,
        NamedArg(dict[str, chex.PRNGKey], "rngs"),
    ],
    DistributionLike,
]

This resolves most problems except in the evaluator function get_distribution_act_fn:

def get_distribution_act_fn(
    config: DictConfig,
    actor_apply: Union[ActorApply, ActorRngApply],
    rngs: Optional[Dict[str, chex.PRNGKey]] = None,
) -> ActFn:
    """Get the act_fn for a network that returns a distribution."""

    def act_fn(params: FrozenDict, observation: chex.Array, key: chex.PRNGKey) -> chex.Array:
        """Get the action from the distribution."""
        if rngs is None:
            pi = actor_apply(params, observation)
        else:
            pi = actor_apply(params, observation, rngs=rngs)
        if config.arch.evaluation_greedy:
            action = pi.mode()
        else:
            action = pi.sample(seed=key)
        return action

    return act_fn

It seems that Union[ActorApply, ActorRngApply] is not recognized properly, which leads to:

stoix\evaluator.py:51: error: Missing named argument "rngs"  [call-arg]
stoix\evaluator.py:53: error: Unexpected keyword argument "rngs"  [call-arg]

I'm not super familiar with the best practices in this case, if you have any idea I'd be interested to know!

@EdanToledo
Copy link
Owner

@RPegoud I just made the typing simpler - its not perfect as it doesnt specify that it requires the params and observation but honestly for now im happy with it. Technically, its fine to be general like this - I'll try think of better typing in the future. With that i think the code is done. I'd love us both to just do a final review of all the changes and if we are happy we can merge.

@EdanToledo EdanToledo linked an issue Jun 19, 2024 that may be closed by this pull request
8 tasks
@EdanToledo EdanToledo merged commit a2b453c into EdanToledo:main Jun 20, 2024
1 check passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[FEATURE] Implement Rainbow DQN
2 participants