Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FEATURE] Vsys feature: massively parallel domain randomization #458

Open
wants to merge 20 commits into
base: main
Choose a base branch
from

Conversation

Velythyl
Copy link

@Velythyl Velythyl commented Feb 15, 2024

Hello!

For an unrelated research project, I needed a massively parallel RL environment with domain randomization capabilities. Isaac Sim/Gym/Omniverse fit the bill, but I also needed the simulator to be differentiable w.r.t. each domain randomization parameters.

So I set out to implement DR in brax. This is research code, so it's obviously a little janky and ad-hoc. But I thought maybe the brax community could find this interesting, and perhaps (with a lot of tuning) even merge it into brax main.

Special thanks to this github issue from which I stole some code ;) here

Note that this domain randomization method is more powerful than this. With this code, we can randomize every single simulation step, if we so wish.

The summary of the implementation is simple: we just augment the simulation state to contain sys, thereby allowing every single parallel environment access to its own separate sys. Also, this enables us to resample sys according to some rule (for example, "resample every 50 steps").


Features:

The vsys wrapper allows for a vectorized sys variable that might contain different domain randomization values for each vectorized env
Domain randomization is controlled via a simple yaml file format that describes the path to a domain randomization target. Example:
link:
  inertia:
    mass:
      base: [r, r, r, r, r, r, r]
      min: [-0.5, -0.5,-0.5,-0.5,-0.5,-0.5,-0.5]
      max: [0.5, 0.5,0.5,0.5,0.5,0.5,0.5]
  constraint_ang_damping:
    min: [-1,-1,-1,1,1,1,1]
    max: [2,2,2,1.5,1,1,1]

This randomizes over the 7 links of the robot. For the mass, the base is "r", so the value is "read" from the default value defined in the URDF file. The min-max ranges are both relative to the base, so the current setup randomizes from [r-0.5, r+0.5]. For the damping, no base is given, which defaults to "r". One could also set the base to a float value. Another possible value for the base is "n" ("none"), which disables randomization for this index.

Domain randomization is differentiable (!)

For example, running a simple optax optimizer, we can obtain the true domain randomization parameters in play for a specific timestep.

Known issues:


Again, I don't expect this to be merged as-is. But perhaps the implementation might be interesting to the community, hence the reason for this PR.

@Velythyl Velythyl closed this Feb 15, 2024
@lebrice
Copy link
Contributor

lebrice commented Feb 16, 2024

Hey @Velythyl I've been looking forward to this feature for a while now, thanks a lot for sharing this!
I'm just curious, why did you close the PR?

@Velythyl Velythyl reopened this Feb 16, 2024
@Velythyl
Copy link
Author

@lebrice Hey! Sorry, I realized I had some cleanup to do, and it was way past 5pm so I wanted to go home. I reopened it now.

Copy link
Contributor

@lebrice lebrice left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(I'm not a maintainer, this is just a fix for some spacing typos, this is very clean!)

@@ -173,12 +173,12 @@ def reset(self, rng: jax.Array) -> State:
'reward_ctrl': zero,
'reward_run': zero,
}
return State(pipeline_state, obs, reward, done, metrics)
return State(pipeline_state, obs, reward, done,sys, metrics)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
return State(pipeline_state, obs, reward, done,sys, metrics)
return State(pipeline_state, obs, reward, done, sys, metrics)

@@ -218,12 +218,12 @@ def reset(self, rng: jax.Array) -> State:
'x_position': zero,
'x_velocity': zero,
}
return State(pipeline_state, obs, reward, done, metrics)
return State(pipeline_state, obs, reward, done,sys, metrics)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
return State(pipeline_state, obs, reward, done,sys, metrics)
return State(pipeline_state, obs, reward, done, sys, metrics)

@btaba
Copy link
Collaborator

btaba commented Apr 30, 2024

Thanks @Velythyl ! The recent comment made me just realize that maintainers hadn't commented on the PR. There were a few design decisions that went into DomainRandomizationVmapWrapper:

  1. We saw better performance when sys was not added as part of State
  2. We wanted the user to fully define the randomization strategy rather than have a schema. At HEAD, this can be done via the randomization_fn.

The cons of the impl at HEAD are that:

  1. The reset is static and stored in the wrapper, as addressed in this PR.
  2. Simple randomization strategies still require the user to write a randomization_fn

What I think would make sense to merge, is to add a wrapper with the same API as DomainRandomizationVmapWrapper, that passes in_axes and the randomized Sys PyTree values in the State, as discussed in this thread: #446 .

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants