Feat/rainbow #86
Conversation
I've left some comments. The main thing I'm curious about is whether or not noise is turned off during evaluation. If so, then you need to add that into the noisy linear layer.
Got it, thanks for the feedback. The original paper does not mention deactivating noise during evaluation; however, the Rainbow paper gives some more insight:
I'm not completely sure, but this seems to suggest that they maintain some exploration mechanism during evaluation, either ε-greedy with annealing or noisy nets.
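(For context, a minimal sketch of what such a switch could look like, assuming a Flax factorised NoisyLinear layer; this is hypothetical and not the PR's implementation. The `eval_mode` flag simply skips the noise so the layer reduces to a plain dense layer at evaluation time.)

```python
import jax
import jax.numpy as jnp
import flax.linen as nn


class NoisyLinear(nn.Module):
    """Factorised NoisyNet-style linear layer (illustrative sketch only)."""

    features: int
    sigma_init: float = 0.5

    @nn.compact
    def __call__(self, x: jnp.ndarray, eval_mode: bool = False) -> jnp.ndarray:
        in_features = x.shape[-1]
        bound = 1.0 / jnp.sqrt(in_features)
        mu_init = lambda key, shape: jax.random.uniform(key, shape, minval=-bound, maxval=bound)
        sigma_init_fn = nn.initializers.constant(self.sigma_init / jnp.sqrt(in_features))

        w_mu = self.param("w_mu", mu_init, (in_features, self.features))
        b_mu = self.param("b_mu", mu_init, (self.features,))
        w_sigma = self.param("w_sigma", sigma_init_fn, (in_features, self.features))
        b_sigma = self.param("b_sigma", sigma_init_fn, (self.features,))

        if eval_mode:
            # Noise disabled at evaluation time: plain affine transform.
            return x @ w_mu + b_mu

        # Factorised Gaussian noise drawn from a dedicated "noise" rng stream.
        f = lambda e: jnp.sign(e) * jnp.sqrt(jnp.abs(e))
        eps_in = f(jax.random.normal(self.make_rng("noise"), (in_features,)))
        eps_out = f(jax.random.normal(self.make_rng("noise"), (self.features,)))
        w = w_mu + w_sigma * jnp.outer(eps_in, eps_out)
        b = b_mu + b_sigma * eps_out
        return x @ w + b
```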
Hey, I'm getting this error when running
Aah weird, it most likely seems to be a JAX versioning issue. To be honest, I haven't done thorough testing on the Dockerfile or the Makefile since I always simply clone and install in a conda env.
I made some minor changes to the docker config and the evaluation function and completed a first run with NoisyDQN on CartPole. Here are the metrics I obtained after a million timesteps and 125,000 updates:
Does it compare with what you generally observe with vanilla DQN?
Okay great, so I can't confidently say it's similar without running it again (I don't really remember), but it's great to see it's learning. Let's maybe test it on a slightly harder env, maybe one of the MinAtar ones or Jumanji Snake 6x6. If we are happy with that, then let's move on to the other components of Rainbow, such as n-step and the C51 head.
Sounds good, I'm still struggling to run the training script with GPU support but I should still be able to get a few runs tomorrow. |
I ran it on Jumanji/snake (probably 12x12 by default?) for 125,000 updates and it seems to have learned as well:
As far as I can tell for n-step C51, we'd need to:
How do you view the API? Would this new version be a new standalone system, or should we rather modify the current C51 system to support n-step returns?
Firstly, I think we should use a trajectory buffer for the n-step returns; this will make our lives much easier and allow us more control over the n steps. Secondly, we should make a new file that will ultimately become Rainbow instead of modifying the current C51, imo.
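(To make the trajectory-buffer idea concrete: a rough sketch of how the n-step return and the accumulated discount for the bootstrapped target could be computed from a sampled trajectory. The helper name and shapes are hypothetical, not the repo's code.)

```python
from typing import Tuple

import jax.numpy as jnp


def n_step_targets(
    rewards: jnp.ndarray,    # [B, T] rewards r_t, ..., r_{t+T-1}
    discounts: jnp.ndarray,  # [B, T] per-step gamma * (1 - done)
    n: int,
) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """N-step return for the first transition of each trajectory, plus the
    product of the first n discounts used to scale the bootstrap target."""
    g = jnp.zeros_like(rewards[:, 0])
    disc = jnp.ones_like(rewards[:, 0])
    for i in range(n):
        g = g + disc * rewards[:, i]
        disc = disc * discounts[:, i]
    return g, disc
```

For C51, `g` shifts the target support and `disc` scales it before the bootstrapped distribution is projected back onto the fixed atoms.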
Great, it's 6x6 in my default configs. Getting 14 is okay; anything above 32 is good performance, but I imagine this is simply due to hyperparams. Getting 14 is definitely learning, so I'm happy with that.
I see, this is a lot cleaner indeed. C51 seems to have similar issues on Snake as well:
I'll try to take a closer look.
@RPegoud I played around with the configs and the current Rainbow system is working on Snake and getting reward.
@RPegoud Okay, so I've added prioritised replay and the distributional dueling head. Please look over it, see if you have any opinions on how it's been done, and try to verify it for correctness. I think we are almost finished and ready for verification and clean-up. The last thing to do is to use noisy layers in the dueling network, I think, and then I believe we are finished with Rainbow; then we just have to sanity check.
Great! I should have time to take a look tomorrow, or Monday at worst, and I'll add the noisy layers 👍
Looks like we're really close! Here's what I've noted so far:
Great! For the scheduling, use an optax scheduler just so we are consistent and allow an easy change of the type of scheduler. For the dueling arch, I was thinking we need to change the current way it's constructed. The general design paradigm for the other archs is to have a torso and a head. Let's make the concept of a dueling architecture just a head instead of the torso. That way it's independent of any torso someone wants to use and easily allows for a shared torso.
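(As an illustration of the head/torso split being described: a minimal sketch of a dueling head that only consumes a torso embedding; a hypothetical module, not the repo's exact implementation. For the distributional variant, the two Dense layers would output `num_atoms` and `action_dim * num_atoms` logits instead, but the split stays the same.)

```python
import flax.linen as nn
import jax.numpy as jnp


class DuelingQHead(nn.Module):
    """Dueling architecture expressed purely as a head: it only sees the
    torso's embedding, so any torso (MLP, CNN, noisy MLP, ...) can be shared."""

    action_dim: int

    @nn.compact
    def __call__(self, embedding: jnp.ndarray) -> jnp.ndarray:
        value = nn.Dense(1)(embedding)                    # V(s): [B, 1]
        advantage = nn.Dense(self.action_dim)(embedding)  # A(s, a): [B, A]
        # Q(s, a) = V(s) + A(s, a) - mean_a A(s, a)
        return value + advantage - advantage.mean(axis=-1, keepdims=True)
```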
Yes I figured |
Here's a first commit with the Rainbow dueling head. I've also included the hyperparameters from the paper but can't run a full training locally. Let me know when you have some time to run an experiment; if those parameters don't work well enough, we can probably revert to the previous ones. PS: as we use CNNs by default with Rainbow, is there a way to override
Amazing, will try to run some tests later and do some review. Regarding env flatten, we can just turn it off by default, I don't mind. Lastly, I'm thinking we remove NoisyDQN as a system, as it isn't really relevant for anyone. I imagine someone will either use Rainbow or some other variation of DQN that doesn't use noisy layers.
Cool! Yes, that seems reasonable; the training script is just a copy of standard DQN with some added keys, so it doesn't justify an additional system.
@RPegoud I've done some small tests and it seems to work (with some slight changes to hyperparams, but that's to be expected since the Rainbow hyperparameters are hyper-optimised for Atari). I've done some clean-up, so I'd say we are pretty much done after we add a scheduler for the importance weight. I realised I didn't answer your one question: regarding tracking the number of training steps, I think we will need to either add an explicit variable to the learner state or somehow get it from the optimizer state.
@EdanToledo I just added the scheduler; I ended up using:

```python
importance_sampling_exponent_scheduler: Callable = optax.linear_schedule(
    init_value=config.system.importance_sampling_exponent,
    end_value=1,
    transition_steps=config.arch.num_updates * config.system.epochs,
    transition_begin=0,
)
```

With:

```python
step_count = optax.tree_utils.tree_get(opt_states, "count")
importance_sampling_exponent = scheduler_fn(step_count)
```

With this setup, the exponent should reach 1 at the end of training, but I haven't been able to run it entirely to check. Let me know if this works as expected.
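(For reference, the scheduled exponent usually enters the loss through the standard PER importance weights; a minimal sketch of that computation follows, with hypothetical names and not necessarily how the buffer library exposes it.)

```python
import jax.numpy as jnp


def importance_weights(
    priorities: jnp.ndarray, buffer_size: int, exponent: float
) -> jnp.ndarray:
    """w_i = (N * P(i))^(-beta), normalised by the max weight for stability."""
    probs = priorities / priorities.sum()
    weights = (buffer_size * probs) ** (-exponent)
    return weights / weights.max()
```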
Okay, so I'm happy with the code; now the last thing we need to do is make it pass the pre-commit linters and typing tests.
Awesome, I'll look into it tomorrow.
@EdanToledo I've been looking into the typing tests and there seems to be a problem with:

```python
ActorApply = Callable[
    [FrozenDict, Observation, Optional[dict[str, chex.PRNGKey]]], DistributionLike
]
```

But this causes errors like this one across the codebase whenever

Alternatively, we can use two separate types to distinguish the cases where we use rngs:

```python
ActorApply = Callable[[FrozenDict, Observation], DistributionLike]
ActorRngApply = Callable[
    [
        FrozenDict,
        Observation,
        NamedArg(dict[str, chex.PRNGKey], "rngs"),
    ],
    DistributionLike,
]
```

This resolves most problems except in the evaluator function:

```python
def get_distribution_act_fn(
    config: DictConfig,
    actor_apply: Union[ActorApply, ActorRngApply],
    rngs: Optional[Dict[str, chex.PRNGKey]] = None,
) -> ActFn:
    """Get the act_fn for a network that returns a distribution."""

    def act_fn(params: FrozenDict, observation: chex.Array, key: chex.PRNGKey) -> chex.Array:
        """Get the action from the distribution."""
        if rngs is None:
            pi = actor_apply(params, observation)
        else:
            pi = actor_apply(params, observation, rngs=rngs)
        if config.arch.evaluation_greedy:
            action = pi.mode()
        else:
            action = pi.sample(seed=key)
        return action

    return act_fn
```

It seems that
I'm not super familiar with the best practices in this case; if you have any ideas, I'd be interested to know!
@RPegoud I just made the typing simpler - it's not perfect as it doesn't specify that it requires the params and observation, but honestly, for now, I'm happy with it. Technically, it's fine to be general like this - I'll try to think of better typing in the future. With that, I think the code is done. I'd love us both to just do a final review of all the changes, and if we are happy, we can merge.
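(One hedged alternative for this kind of optional-keyword typing problem is a callable `Protocol`; a rough sketch follows, not what was merged. `Observation` is replaced by `chex.Array` and the return type by `Any` to keep the snippet self-contained. Since Flax's `Module.apply` already accepts an optional `rngs` keyword, both noisy and non-noisy apply functions fit this shape.)

```python
from typing import Any, Dict, Optional, Protocol

import chex
from flax.core.frozen_dict import FrozenDict


class ActorApplyProtocol(Protocol):
    """A single callable type covering apply functions that may or may not
    consume an optional, keyword-only `rngs` argument."""

    def __call__(
        self,
        params: FrozenDict,
        observation: chex.Array,
        *,
        rngs: Optional[Dict[str, chex.PRNGKey]] = None,
    ) -> Any:
        ...
```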
What?
Implement Rainbow DQN as defined in #81.
Progress:
- `NoisyLinear` layer module
- `NoisyMLPTorso` module

Comments
NoisyDQN generates a Gaussian noise matrix at inference time; the `NoisyLinear` module currently uses a "noise" rng stream for this purpose:

This requires using the following syntax:
Is this the correct way to proceed? If so, what would be the easiest way to pass this additional key through the training script?
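(For illustration only, since the original snippets are not reproduced above: the Flax rng-stream plumbing in question generally looks like the following, with a toy module standing in for a network containing `NoisyLinear` layers.)

```python
import jax
import jax.numpy as jnp
import flax.linen as nn


class TinyNoisyNet(nn.Module):
    """Toy stand-in for a network whose layers call `self.make_rng("noise")`."""

    @nn.compact
    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        noise = jax.random.normal(self.make_rng("noise"), x.shape)
        return nn.Dense(4)(x + noise)


key = jax.random.PRNGKey(0)
key, init_key, noise_key = jax.random.split(key, 3)
dummy_obs = jnp.zeros((1, 8))

# The "noise" stream has to be supplied at init time...
params = TinyNoisyNet().init({"params": init_key, "noise": noise_key}, dummy_obs)

# ...and again at every apply, via the `rngs` argument.
key, noise_key = jax.random.split(key)
out = TinyNoisyNet().apply(params, dummy_obs, rngs={"noise": noise_key})
```

So the training and evaluation scripts would need to carry an extra PRNG key (split from the main key each step) just for this stream.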