Add initial epoch multiplier as a parameter to the PC script.
ernestum committed Jan 11, 2024
1 parent 55aa6eb commit 78553c9
Showing 2 changed files with 7 additions and 0 deletions.
src/imitation/scripts/config/train_preference_comparisons.py (2 additions, 0 deletions)
@@ -42,6 +42,8 @@ def train_defaults():
     transition_oversampling = 1
     # fraction of total_comparisons that will be sampled right at the beginning
     initial_comparison_frac = 0.1
+    # factor by which to oversample the number of epochs in the first iteration
+    initial_epoch_multiplier = 200.0
     # fraction of sampled trajectories that will include some random actions
     exploration_frac = 0.0
     preference_model_kwargs = {}
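
Since the new default lives in the Sacred train_defaults config, it can be overridden per run like any other config value. Below is a minimal sketch of such an override through Sacred's Python API; the experiment object name train_preference_comparisons_ex and its importability from this config module are assumptions, not something shown in the diff.

    # Minimal sketch: overriding the new default for one run (experiment name assumed).
    from imitation.scripts.config.train_preference_comparisons import (
        train_preference_comparisons_ex,  # assumed name of the Sacred Experiment
    )

    # Sacred merges config_updates on top of the values set in train_defaults(),
    # so this run would use a smaller pretraining multiplier than the 200.0 default.
    # Any other required config (e.g. the environment) is assumed to come from the
    # script's defaults or named configs.
    run = train_preference_comparisons_ex.run(
        config_updates={"initial_epoch_multiplier": 100.0},
    )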
src/imitation/scripts/train_preference_comparisons.py (5 additions, 0 deletions)
@@ -68,6 +68,7 @@ def train_preference_comparisons(
     fragment_length: int,
     transition_oversampling: float,
     initial_comparison_frac: float,
+    initial_epoch_multiplier: float,
     exploration_frac: float,
     trajectory_path: Optional[str],
     trajectory_generator_kwargs: Mapping[str, Any],
@@ -106,6 +107,9 @@ def train_preference_comparisons(
             sampled before the rest of training begins (using the randomly initialized
             agent). This can be used to pretrain the reward model before the agent
             is trained on the learned reward.
+        initial_epoch_multiplier: before agent training begins, train the reward
+            model for this many more epochs than usual (on fragments sampled from a
+            random agent).
         exploration_frac: fraction of trajectory samples that will be created using
             partially random actions, rather than the current policy. Might be helpful
             if the learned policy explores too little and gets stuck with a wrong
@@ -258,6 +262,7 @@ def train_preference_comparisons(
         fragment_length=fragment_length,
         transition_oversampling=transition_oversampling,
         initial_comparison_frac=initial_comparison_frac,
+        initial_epoch_multiplier=initial_epoch_multiplier,
         custom_logger=custom_logger,
         allow_variable_horizon=allow_variable_horizon,
         query_schedule=query_schedule,
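
The docstring addition above describes the intended behaviour: in the very first iteration, before any agent training, the reward model is trained for many more epochs than usual on fragments from a random agent. The sketch below only illustrates that scaling rule; the helper function and the base_epochs parameter are hypothetical and not part of the imitation API.

    # Illustration of the semantics only; this helper is hypothetical.
    def reward_epochs_for_iteration(
        iteration: int,
        base_epochs: int,
        initial_epoch_multiplier: float = 200.0,
    ) -> int:
        """Reward-model epochs to run in a given preference-comparisons iteration.

        In iteration 0 (before agent training, on fragments sampled from a random
        agent) the epoch count is scaled by initial_epoch_multiplier, giving the
        reward model a long pretraining phase; later iterations use base_epochs.
        """
        if iteration == 0:
            return int(base_epochs * initial_epoch_multiplier)
        return base_epochs


    # With 3 base epochs and the default multiplier of 200.0, the first iteration
    # trains the reward model for 600 epochs; later iterations train for 3.
    assert reward_epochs_for_iteration(0, base_epochs=3) == 600
    assert reward_epochs_for_iteration(5, base_epochs=3) == 3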
