Many randomized environments in parallel #338
Yes, there is a way to sample multiple systems, but we have not polished or released such a feature in v2 yet. Being able to trace over the `System` is one of the high-level goals of the new Brax design. One way to get your project off the ground is to create a vectorized `System` like so:

```python
import copy
import functools
from typing import Dict

import jax
from jax import numpy as jp


def _write_sys(sys, attr, val):
  """Replaces the (possibly nested) attribute path `attr` in sys with val."""
  if not attr:
    return sys
  if len(attr) == 2 and attr[0] == 'geoms':
    geoms = copy.deepcopy(sys.geoms)
    if not hasattr(val, '__iter__'):
      # Non-iterable value (e.g. a Python float): apply it to every geom
      # that has this attribute.
      for i, g in enumerate(geoms):
        if not hasattr(g, attr[1]):
          continue
        geoms[i] = g.replace(**{attr[1]: val})
    else:
      # Array value: give each geom that has the attribute its own
      # consecutive slice of val.
      sizes = [g.transform.pos.shape[0] for g in geoms]
      g_idx = 0
      for i, g in enumerate(geoms):
        if not hasattr(g, attr[1]):
          continue
        size = sizes[i]
        geoms[i] = g.replace(**{attr[1]: val[g_idx:g_idx + size].T})
        g_idx += size
    return sys.replace(geoms=geoms)
  if len(attr) == 1:
    return sys.replace(**{attr[0]: val})
  # Recurse into the nested field, e.g. 'link' -> 'inertia' -> 'mass'.
  return sys.replace(**{attr[0]:
                        _write_sys(getattr(sys, attr[0]), attr[1:], val)})


def set_sys(sys, params: Dict[str, jp.ndarray]):
  """Sets params, keyed by dotted paths like 'link.inertia.mass', in the System."""
  for k in params.keys():
    sys = _write_sys(sys, k.split('.'), params[k])
  return sys


def randomize(sys, rng):
  """Adds U[0, 1) noise to the mass of every link."""
  return set_sys(sys, {'link.inertia.mass': sys.link.inertia.mass +
                       jax.random.uniform(rng, shape=(sys.num_links(),))})


# Assumes `sys` is an already-loaded brax System and `batch_size` is the
# number of parallel environments.
rng = jax.random.PRNGKey(0)
rng, *key = jax.random.split(rng, batch_size + 1)
key = jp.reshape(jp.stack(key), (batch_size, 2))
# Bind sys positionally so that vmap maps only over the rng keys.
sys_v = jax.vmap(functools.partial(randomize, sys))(key)
```

I have not fully tested the code, but the intent and rough idea should be there:

1. Create a batch of rng keys.
2. Randomize the system over the batch of rng keys.
3. Use the vectorized system.

We're planning to add a wrapper that does this for you at some point, but it's still TBD.
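The same three steps can be tried in isolation without Brax. The sketch below uses a hypothetical `Params` NamedTuple as a stand-in for `System` (any JAX pytree of arrays behaves the same way under `jax.vmap`), perturbs a per-link mass array, and produces one randomized copy per key.

```python
from typing import NamedTuple

import jax
from jax import numpy as jp


class Params(NamedTuple):
  """Hypothetical stand-in for brax's System: any pytree of arrays works."""
  mass: jp.ndarray      # one mass per link, shape (num_links,)
  friction: jp.ndarray  # scalar friction coefficient


def randomize(params: Params, rng: jp.ndarray) -> Params:
  """Adds U[0, 1) noise to each link mass; friction is left untouched."""
  noise = jax.random.uniform(rng, shape=params.mass.shape)
  return params._replace(mass=params.mass + noise)


batch_size = 8
params = Params(mass=jp.ones(3), friction=jp.array(0.5))

# [1] One key per environment, shape (batch_size, 2).
keys = jax.random.split(jax.random.PRNGKey(0), batch_size)
# [2] Map the randomization over the keys; params is closed over, not mapped.
params_v = jax.vmap(lambda k: randomize(params, k))(keys)
# [3] params_v.mass now has shape (batch_size, 3).
```

Note that even the untouched `friction` leaf comes back with shape `(batch_size,)`, because `jax.vmap` broadcasts unbatched outputs along the batch axis. If memory matters, it can help to track which leaves actually vary.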
Thanks a lot! The code snippet helped me a lot in getting vectorized random systems to work! While the masses, friction and elasticity can easily be changed with the `set_sys` method, changing the length of links seems to cause problems and alters the system dynamics in unintended ways. I think this is because the transforms (some form of rotation matrices?) depend on the length of the links and are computed when loading the system XML. Are you aware of any clean workaround to change both the length and the transform of a link correctly within the `System` object without reloading the XML?
Hi @jonasrothfuss, sure, here's a code snippet for capsules which I've been using:

```python
import jax
from jax import numpy as jp
from brax import base, math

# set_sys is the helper defined in the earlier comment in this thread.


def set_sys_capsules(sys, lengths, radii):
  """Sets the system with new capsule lengths/radii."""
  sys2 = set_sys(sys, {'geoms.length': lengths})
  sys2 = set_sys(sys2, {'geoms.radius': radii})
  # We assume inertia.transform.pos is (0, 0, 0), as is often the case for
  # capsules.

  # Get the new joint transform: keep the joint's offset beyond the capsule
  # end, but move it along with the new half-length.
  cur_len = sys.geoms[1].length[:, None]
  joint_dir = jax.vmap(math.normalize)(sys.link.joint.pos)[0]
  joint_dist = sys.link.joint.pos - 0.5 * cur_len * joint_dir
  joint_transform = 0.5 * lengths[:, None] * joint_dir + joint_dist
  sys2 = set_sys(sys2, {'link.joint.pos': joint_transform})

  # Get the new link transform.
  parent_idx = jp.array([sys.link_parents])
  sys2 = set_sys(
      sys2,
      {
          'link.transform.pos': -(
              sys2.link.joint.pos
              + joint_dist
              + 0.5 * lengths[parent_idx].T * joint_dir
          )
      },
  )
  return sys2


@jax.jit
def randomize_sys_capsules(
    rng: jp.ndarray,
    sys: base.System,
    min_length: float = 0.0,
    max_length: float = 0.0,
    min_radius: float = 0.0,
    max_radius: float = 0.0,
):
  """Randomizes capsule lengths/radii and the resulting joint offsets.

  Assumes the capsule geoms appear in geoms[1].
  """
  rng, key1, key2 = jax.random.split(rng, 3)
  length_u = jax.random.uniform(
      key1, shape=(sys.num_links(),), minval=min_length, maxval=max_length
  )
  radius_u = jax.random.uniform(
      key2, shape=(sys.num_links(),), minval=min_radius, maxval=max_radius
  )
  length = length_u + sys.geoms[1].length  # pytype: disable=attribute-error
  radius = radius_u + sys.geoms[1].radius  # pytype: disable=attribute-error
  return set_sys_capsules(sys, length, radius)
```
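The joint-offset arithmetic in `set_sys_capsules` can be checked on a single link. The sketch below is a hypothetical standalone helper, `rescale_joint_pos`, that assumes (as the snippet does) that the capsule is centred at the link origin, the joint lies on the capsule axis, and the joint position is non-zero. It strips the old half-length along the joint direction and then adds the new half-length back.

```python
import jax.numpy as jp


def rescale_joint_pos(joint_pos, cur_len, new_len):
  """Moves a joint anchored near a capsule end to the end of a resized capsule.

  joint_pos: (3,) joint position in the link frame (capsule centre at origin,
    assumed non-zero so it has a direction).
  cur_len, new_len: current and new capsule lengths.
  """
  joint_dir = joint_pos / jp.linalg.norm(joint_pos)   # unit vector along the axis
  joint_dist = joint_pos - 0.5 * cur_len * joint_dir  # offset beyond the old end
  return 0.5 * new_len * joint_dir + joint_dist       # same offset, new end
```

For example, a joint 0.1 beyond the end of a length-1 capsule stays 0.1 beyond the end after the capsule grows to length 2: `rescale_joint_pos(jp.array([0.0, 0.0, 0.6]), 1.0, 2.0)` gives approximately `[0, 0, 1.1]`.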
Does anyone have any ideas on how to implement domain randomization effectively in the new
Hi @jc-bao, you can check out the test here for an example of domain randomization in v2: `brax/brax/training/agents/ppo/train_test.py`, lines 100 to 132 (commit 9e14acb).
Notice that `brax/brax/envs/wrappers/training.py`, line 53 (commit 9e14acb)
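The general idea of running one environment per randomized system can be sketched in plain JAX: keep a batched copy of the parameters, build a matching `in_axes` pytree with `0` for leaves that carry a batch dimension and `None` for shared leaves, and pass it to `jax.vmap`. This is an illustration under those assumptions, not the code of the Brax wrapper referenced above; `Params` and `step` are hypothetical.

```python
from typing import NamedTuple

import jax
from jax import numpy as jp


class Params(NamedTuple):
  """Hypothetical stand-in for a System with one randomized field."""
  mass: jp.ndarray     # batched: shape (batch_size,)
  gravity: jp.ndarray  # shared across envs: scalar


def step(params: Params, vel: jp.ndarray, force: jp.ndarray) -> jp.ndarray:
  """Toy 1-D dynamics: one explicit Euler step with dt = 0.1."""
  acc = force / params.mass - params.gravity
  return vel + 0.1 * acc


batch_size = 4
params_v = Params(mass=jp.arange(1.0, batch_size + 1.0), gravity=jp.array(9.81))
# 0 = map over the leading axis, None = reuse the same value in every env.
in_axes = Params(mass=0, gravity=None)

vel = jp.zeros(batch_size)
force = jp.full(batch_size, 10.0)
new_vel = jax.vmap(step, in_axes=(in_axes, 0, 0))(params_v, vel, force)
```

Only `mass` needs a batch dimension here; `gravity` is stored once and broadcast, which saves memory when most parameters are not randomized.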
I would like to use some form of domain randomization with Brax environments. Ideally, I would like to re-sample the parameters of the environments in every iteration or after every step. So far I have only managed to sample random parameters and replace the default parameters in the environment XML. However, loading the `System` from XML takes a very long time, which makes this approach very inefficient.
Is there a clean and efficient way to change relevant physics parameters (e.g. body mass, size, friction coefficient) without the significant overhead of reconstructing the `System` from scratch?
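On the efficiency concern: once a jitted function has been compiled, calling it again with arguments of the same shapes and dtypes reuses the compiled version, so re-sampling parameter values inside the jitted step avoids both XML parsing and retracing. Below is a minimal sketch with a hypothetical `step_with_resample` function and a toy mass parameter; the Python-side counter only increments while JAX traces the function.

```python
import jax
from jax import numpy as jp

trace_count = 0  # incremented only when JAX traces (compiles) the function


@jax.jit
def step_with_resample(rng, vel, nominal_mass):
  global trace_count
  trace_count += 1  # Python side effect: runs at trace time only
  rng, key = jax.random.split(rng)
  # Fresh mass in [0.5, 1.5) * nominal on every call: only array values change.
  mass = nominal_mass * jax.random.uniform(key, minval=0.5, maxval=1.5)
  vel = vel + 0.1 * (1.0 / mass)  # toy dynamics: unit force, dt = 0.1
  return rng, vel


rng = jax.random.PRNGKey(0)
vel = jp.zeros(())  # float32 scalar, same dtype as the returned velocity
nominal_mass = jp.array(2.0)
for _ in range(5):
  rng, vel = step_with_resample(rng, vel, nominal_mass)
```

Here the function is traced once and then reused for every re-sampled step. A change in argument shapes, dtypes, or pytree structure would trigger a new trace, which is why swapping values inside a fixed-structure `System` is cheap compared with rebuilding it.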