Skip to content

Commit

Permalink
feat: pass shape to reset, transition and termination functions
Browse files (browse the repository at this point in the history)
  • Loading branch information
RuanJohn committed Jan 16, 2024
1 parent 7b0acb6 commit a372c91
Showing 1 changed file with 8 additions and 6 deletions.
14 changes: 8 additions & 6 deletions matrax/env.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -99,7 +99,7 @@ def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep[Observation]]:
agent_obs=agent_obs,
step_count=state.step_count,
)
- timestep = restart(observation=observation)
+ timestep = restart(observation=observation, shape=self.num_agents)
return state, timestep

def step(
Expand All @@ -122,7 +122,7 @@ def compute_reward(
actions: chex.Array, payoff_matrix_per_agent: chex.Array
) -> chex.Array:
reward_idx = tuple(actions)
- return payoff_matrix_per_agent[reward_idx]
+ return payoff_matrix_per_agent[reward_idx].astype(float)

rewards = jax.vmap(functools.partial(compute_reward, actions))(
self.payoff_matrix
Expand All @@ -143,10 +143,12 @@ def compute_reward(

timestep = jax.lax.cond(
done,
- termination,
- transition,
- rewards,
- next_observation,
+ lambda: termination(
+ reward=rewards, observation=next_observation, shape=self.num_agents
+ ),
+ lambda: transition(
+ reward=rewards, observation=next_observation, shape=self.num_agents
+ ),
)

# create environment state
Expand Down

0 comments on commit a372c91

Please sign in to comment.