Skip to content

Commit

Permalink
model reorder
Browse files Browse the repository at this point in the history
  • Loading branch information
Nishad Gothoskar committed Oct 25, 2023
1 parent 5815b6e commit d4dca88
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions bayes3d/genjax/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@ def model(array, possible_object_indices, pose_bounds, contact_bounds, all_box_d
faces_child = jnp.array([], dtype=jnp.int32)
parents = jnp.array([], dtype=jnp.int32)
for i in range(array.shape[0]):
parent_obj = uniform_discrete(jnp.arange(-1,array.shape[0] - 1)) @ f"parent_{i}"
parent_face = uniform_discrete(jnp.arange(0,6)) @ f"face_parent_{i}"
child_face = uniform_discrete(jnp.arange(0,6)) @ f"face_child_{i}"
index = uniform_discrete(possible_object_indices) @ f"id_{i}"

pose = uniform_pose(
Expand All @@ -35,9 +38,6 @@ def model(array, possible_object_indices, pose_bounds, contact_bounds, all_box_d
contact_bounds[1]
) @ f"contact_params_{i}"

parent_obj = uniform_discrete(jnp.arange(-1,array.shape[0] - 1)) @ f"parent_{i}"
parent_face = uniform_discrete(jnp.arange(0,6)) @ f"face_parent_{i}"
child_face = uniform_discrete(jnp.arange(0,6)) @ f"face_child_{i}"

indices = jnp.concatenate([indices, jnp.array([index])])
root_poses = jnp.concatenate([root_poses, pose.reshape(1,4,4)])
Expand Down

0 comments on commit d4dca88

Please sign in to comment.