Many randomized environments in parallel #338

Closed
jonasrothfuss opened this issue Apr 20, 2023 · 5 comments

Comments

@jonasrothfuss

I would like to use some form of domain randomization with brax environments. Ideally, I would like to re-sample the environment parameters in every iteration, or even after every step. So far I have only managed to sample random parameters and substitute them for the default parameters in the environment XML. However, loading the 'system' from XML takes a long time, which makes this approach very inefficient.

Is there a clean and efficient way to change relevant physics parameters (e.g. body mass, size, friction coef, etc.) without the significant overhead of reconstructing the 'system' from scratch?

@btaba
Collaborator

btaba commented Apr 21, 2023

Hi @jonasrothfuss

Yes, there is a way to sample multiple systems, but we have not polished/released such a thing in v2. Being able to trace over the System (e.g. with jit/vmap) is a high-level goal of the new brax design. One way to get your project off the ground is to create a vectorized System like so:

```python
import copy
import functools
from typing import Dict

import jax
from jax import numpy as jp


def _write_sys(sys, attr, val):
  """Replaces attributes in sys with val."""
  if not attr:
    return sys
  if len(attr) == 2 and attr[0] == 'geoms':
    geoms = copy.deepcopy(sys.geoms)
    if not hasattr(val, '__iter__'):
      # Non-iterable value: set it on every geom that has this attribute.
      for i, g in enumerate(geoms):
        if not hasattr(g, attr[1]):
          continue
        geoms[i] = g.replace(**{attr[1]: val})
    else:
      # Iterable value: slice it across geoms by each geom's leading dimension.
      sizes = [g.transform.pos.shape[0] for g in geoms]
      g_idx = 0
      for i, g in enumerate(geoms):
        if not hasattr(g, attr[1]):
          continue
        size = sizes[i]
        geoms[i] = g.replace(**{attr[1]: val[g_idx:g_idx + size].T})
        g_idx += size
    return sys.replace(geoms=geoms)
  if len(attr) == 1:
    return sys.replace(**{attr[0]: val})
  # Recurse into nested fields, e.g. 'link.inertia.mass'.
  return sys.replace(**{attr[0]:
                        _write_sys(getattr(sys, attr[0]), attr[1:], val)})


def set_sys(sys, params: Dict[str, jp.ndarray]):
  """Sets params in the System, keyed by dotted attribute paths."""
  for k in params.keys():
    sys = _write_sys(sys, k.split('.'), params[k])
  return sys


def randomize(sys, rng):
  """Adds U[0, 1) noise to every link mass."""
  return set_sys(sys, {'link.inertia.mass': sys.link.inertia.mass + jax.random.uniform(rng, shape=(sys.num_links(),))})


# Assumes `sys` (a brax v2 System) and `batch_size` are already defined.
rng = jax.random.PRNGKey(0)
keys = jax.random.split(rng, batch_size)  # one key per environment
# Bind sys positionally so vmap maps only over the keys.
sys_v = jax.vmap(functools.partial(randomize, sys))(keys)
```

I have not fully tested the code, but the intent and rough idea should be there: [1] create a batch of RNG keys, [2] randomize the system over that batch of keys, [3] use the vectorized system sys_v downstream.

We're planning to add a wrapper that does this for you at some point, but that's still TBD.
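The three steps above can be illustrated without brax. This is a minimal sketch, assuming a hypothetical `ToySystem` NamedTuple in place of `base.System` (NamedTuples are JAX pytrees, so `jax.vmap` can batch over them) and a made-up `step` function that is not brax physics. It passes `in_axes=(None, 0)` so the unbatched system is shared while only the keys are mapped:

```python
from typing import NamedTuple

import jax
import jax.numpy as jp


class ToySystem(NamedTuple):
  """Hypothetical stand-in for brax's System: a few per-link arrays."""
  mass: jax.Array      # shape (num_links,)
  friction: jax.Array  # shape (num_links,)


def randomize(sys: ToySystem, rng: jax.Array) -> ToySystem:
  """Adds U[0, 1) noise to each link mass."""
  noise = jax.random.uniform(rng, shape=sys.mass.shape)
  return sys._replace(mass=sys.mass + noise)


def step(sys: ToySystem, vel: jax.Array, force: jax.Array) -> jax.Array:
  """Made-up dynamics: explicit Euler update v' = v + dt * F / m."""
  dt = 0.01
  return vel + dt * force / sys.mass


num_links, batch_size = 3, 4
sys = ToySystem(mass=jp.ones(num_links), friction=jp.full(num_links, 0.5))

# [1] One key per environment.
keys = jax.random.split(jax.random.PRNGKey(0), batch_size)
# [2] Batch of randomized systems; every output leaf gets a leading batch
# axis (unchanged leaves such as friction are broadcast by vmap).
sys_v = jax.vmap(randomize, in_axes=(None, 0))(sys, keys)
# [3] Step all environments at once, each with its own System.
vel = jp.zeros((batch_size, num_links))
force = jp.ones((batch_size, num_links))
vel_next = jax.jit(jax.vmap(step))(sys_v, vel, force)
```

Because `sys_v` is just a pytree of arrays, re-sampling it every iteration only reruns `randomize`; nothing is re-parsed from XML.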

@jonasrothfuss
Author

jonasrothfuss commented Apr 26, 2023

Thanks a lot! The code snippet really helped me get vectorized random systems working!

While the masses, friction, and elasticity can easily be changed with the set_sys method, changing the length of links seems to cause problems and alters the system dynamics in unintended ways. I think this is because the transforms (some form of rotation matrices?) depend on the link lengths and are computed when loading the system XML. Are you aware of a clean workaround to change both the length and the transform of a link correctly within the system object, without reloading the XML?

@btaba
Collaborator

btaba commented Apr 26, 2023

Hi @jonasrothfuss, sure, here's a code snippet for capsules that I've been using:

```python
import jax
from jax import numpy as jp
from brax import base
from brax import math

# set_sys is the helper from the previous snippet.


def set_sys_capsules(sys, lengths, radii):
  """Sets the system with new capsule lengths/radii."""
  sys2 = set_sys(sys, {'geoms.length': lengths})
  sys2 = set_sys(sys2, {'geoms.radius': radii})

  # we assume inertia.transform.pos is (0,0,0), as is often the case for
  # capsules

  # get the new joint transform
  cur_len = sys.geoms[1].length[:, None]
  # math.normalize returns (unit vector, norm); [0] keeps the unit vectors.
  joint_dir = jax.vmap(math.normalize)(sys.link.joint.pos)[0]
  joint_dist = sys.link.joint.pos - 0.5 * cur_len * joint_dir
  joint_transform = 0.5 * lengths[:, None] * joint_dir + joint_dist
  sys2 = set_sys(sys2, {'link.joint.pos': joint_transform})

  # get the new link transform
  parent_idx = jp.array([sys.link_parents])
  sys2 = set_sys(
      sys2,
      {
          'link.transform.pos': -(
              sys2.link.joint.pos
              + joint_dist
              + 0.5 * lengths[parent_idx].T * joint_dir
          )
      },
  )
  return sys2


@jax.jit
def randomize_sys_capsules(
    rng: jp.ndarray,
    sys: base.System,
    min_length: float = 0.0,
    max_length: float = 0.0,
    min_radius: float = 0.0,
    max_radius: float = 0.0,
):
  """Randomizes joint offsets; assumes capsule geoms appear in geoms[1]."""
  rng, key1, key2 = jax.random.split(rng, 3)
  length_u = jax.random.uniform(
      key1, shape=(sys.num_links(),), minval=min_length, maxval=max_length
  )
  radius_u = jax.random.uniform(
      key2, shape=(sys.num_links(),), minval=min_radius, maxval=max_radius
  )
  length = length_u + sys.geoms[1].length  # pytype: disable=attribute-error
  radius = radius_u + sys.geoms[1].radius  # pytype: disable=attribute-error
  return set_sys_capsules(sys, length, radius)
```
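The joint update in `set_sys_capsules` follows a simple rule, assuming (as the snippet's comment does) that each capsule is centered on its link origin and the joint lies along the capsule axis: the joint offset is split into half the capsule length along the axis plus a residual vector, and only the half-length term is rescaled. The standalone sketch below isolates just that arithmetic with made-up numbers; `rescale_joint_pos` is a hypothetical helper, and it uses a plain norm instead of `brax.math.normalize`:

```python
import jax.numpy as jp


def rescale_joint_pos(joint_pos, cur_len, new_len):
  """Moves each joint along its own direction so it stays at the capsule tip.

  joint_pos: (n, 3) joint offsets from the capsule center (assumed nonzero).
  cur_len, new_len: (n,) old and new capsule lengths.
  """
  joint_dir = joint_pos / jp.linalg.norm(joint_pos, axis=-1, keepdims=True)
  # Residual offset not explained by the half-length (e.g. a small gap).
  residual = joint_pos - 0.5 * cur_len[:, None] * joint_dir
  return 0.5 * new_len[:, None] * joint_dir + residual


# Link 0: joint 0.1 beyond the tip of a length-1.0 capsule, stretched to 1.4.
# Link 1: joint exactly at the tip of a length-0.6 capsule, shrunk to 0.2.
joint_pos = jp.array([[0.0, 0.0, 0.6], [0.3, 0.0, 0.0]])
cur_len = jp.array([1.0, 0.6])
new_len = jp.array([1.4, 0.2])
new_pos = rescale_joint_pos(joint_pos, cur_len, new_len)
```

For link 0 the joint moves from z = 0.6 to 0.5 * 1.4 + 0.1 = 0.8; for link 1 it moves from x = 0.3 to 0.5 * 0.2 = 0.1.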

@jc-bao

jc-bao commented Sep 5, 2023

Does anyone have any ideas on how to effectively implement domain randomization in the new v2 environment?

@btaba
Collaborator

btaba commented Sep 7, 2023

Hi @jc-bao, you can check out the test here for an example of domain randomization in v2:

```python
def testTrainDomainRandomize(self):
  """Test PPO with domain randomization."""

  def rand_fn(sys, rng):
    @jax.vmap
    def get_offset(rng):
      offset = jax.random.uniform(rng, shape=(3,), minval=-0.1, maxval=0.1)
      pos = sys.link.transform.pos.at[0].set(offset)
      return pos

    sys_v = sys.tree_replace({'link.inertia.transform.pos': get_offset(rng)})
    in_axes = jax.tree_map(lambda x: None, sys)
    in_axes = in_axes.tree_replace({'link.inertia.transform.pos': 0})
    return sys_v, in_axes

  _, _, _ = ppo.train(
      envs.get_environment('inverted_pendulum', backend='spring'),
      num_timesteps=2**15,
      episode_length=1000,
      num_envs=64,
      learning_rate=3e-4,
      entropy_cost=1e-2,
      discounting=0.95,
      unroll_length=5,
      batch_size=64,
      num_minibatches=8,
      num_updates_per_batch=4,
      normalize_observations=True,
      seed=2,
      reward_scaling=10,
      normalize_advantage=False,
      randomization_fn=rand_fn,
  )
```

Notice that `rand_fn` randomizes the System, and the `ppo.train` routine takes advantage of this function to wrap the environment:

```python
env = DomainRandomizationVmapWrapper(env, randomization_fn)
```
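`rand_fn` returns two trees with the same structure as the System: the batched system, and an `in_axes` tree that is `0` for the randomized leaf and `None` everywhere else. The sketch below shows how such a pair can drive `jax.vmap`, assuming a hypothetical `ToySystem` NamedTuple and a made-up `step` function; it is an illustration of the idea, not the actual `DomainRandomizationVmapWrapper` implementation:

```python
from typing import NamedTuple

import jax
import jax.numpy as jp


class ToySystem(NamedTuple):
  """Hypothetical stand-in for base.System."""
  gravity: jax.Array      # shared by all envs, shape (3,)
  inertia_pos: jax.Array  # randomized per env, shape (3,)


def rand_fn(sys, rng):
  """Mirrors the test's rand_fn; rng is a batch of keys, one per env."""
  offsets = jax.vmap(
      lambda k: jax.random.uniform(k, (3,), minval=-0.1, maxval=0.1))(rng)
  sys_v = sys._replace(inertia_pos=offsets)
  # None = broadcast this leaf to every env; 0 = map over its leading axis.
  in_axes = jax.tree_util.tree_map(lambda _: None, sys)
  in_axes = in_axes._replace(inertia_pos=0)
  return sys_v, in_axes


def step(sys, pos):
  """Made-up dynamics: drift by gravity plus the per-env inertia offset."""
  return pos + 0.01 * sys.gravity + sys.inertia_pos


batch_size = 8
sys = ToySystem(gravity=jp.array([0.0, 0.0, -9.81]), inertia_pos=jp.zeros(3))
keys = jax.random.split(jax.random.PRNGKey(2), batch_size)
sys_v, in_axes = rand_fn(sys, keys)
pos = jp.zeros((batch_size, 3))
# Batch over the system (per-leaf axes from in_axes) and over the state.
next_pos = jax.vmap(step, in_axes=(in_axes, 0))(sys_v, pos)
```

Only the randomized leaf carries a batch axis, so the rest of the System is stored once and shared across all environments rather than copied `batch_size` times.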


3 participants