Skip to content

Commit

Permalink
Pass max_length_time_axis instead of max_size
Browse files Browse the repository at this point in the history
Makes it so that the warning:

```
Setting max_size dynamically sets the `max_length_time_axis` to be `max_size`//`add_batch_size = .*`
```

will no longer be triggered by legitimate use of `create_flat_buffer` and
`make_prioritised_flat_buffer`.
  • Loading branch information
mickvangelderen committed Nov 8, 2024
1 parent 1baa1b7 commit 89cd3a3
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 39 deletions.
27 changes: 9 additions & 18 deletions flashbax/buffers/flat_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,24 +113,15 @@ def create_flat_buffer(
add_batch_size=add_batch_size,
)

with warnings.catch_warnings():
warnings.filterwarnings(
"ignore",
message="Setting max_size dynamically sets the `max_length_time_axis` to "
f"be `max_size`//`add_batch_size = {max_length // add_batch_size}`."
"This allows one to control exactly how many transitions are stored in the buffer."
"Note that this overrides the `max_length_time_axis` argument.",
)

buffer = make_trajectory_buffer(
max_length_time_axis=None, # Unused because max_size is specified
min_length_time_axis=min_length // add_batch_size + 1,
add_batch_size=add_batch_size,
sample_batch_size=sample_batch_size,
sample_sequence_length=2,
period=1,
max_size=max_length,
)
buffer = make_trajectory_buffer(
max_length_time_axis=max_length // add_batch_size,
min_length_time_axis=min_length // add_batch_size + 1,
add_batch_size=add_batch_size,
sample_batch_size=sample_batch_size,
sample_sequence_length=2,
period=1,
max_size=None,
)

add_fn = buffer.add

Expand Down
31 changes: 11 additions & 20 deletions flashbax/buffers/prioritised_flat_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,26 +100,17 @@ def make_prioritised_flat_buffer(
if not validate_device(device):
device = "cpu"

with warnings.catch_warnings():
warnings.filterwarnings(
"ignore",
message="Setting max_size dynamically sets the `max_length_time_axis` to "
f"be `max_size`//`add_batch_size = {max_length // add_batch_size}`."
"This allows one to control exactly how many transitions are stored in the buffer."
"Note that this overrides the `max_length_time_axis` argument.",
)

buffer = make_prioritised_trajectory_buffer(
max_length_time_axis=None, # Unused because max_size is specified
min_length_time_axis=min_length // add_batch_size + 1,
add_batch_size=add_batch_size,
sample_batch_size=sample_batch_size,
sample_sequence_length=2,
period=1,
max_size=max_length,
priority_exponent=priority_exponent,
device=device,
)
buffer = make_prioritised_trajectory_buffer(
max_length_time_axis=max_length // add_batch_size,
min_length_time_axis=min_length // add_batch_size + 1,
add_batch_size=add_batch_size,
sample_batch_size=sample_batch_size,
sample_sequence_length=2,
period=1,
max_size=None,
priority_exponent=priority_exponent,
device=device,
)

add_fn = buffer.add

Expand Down
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ filterwarnings = [
"error",
"ignore:`sample_sequence_length` greater than `min_length_time_axis`:UserWarning:flashbax",
"ignore:Setting period greater than sample_sequence_length will result in no overlap betweentrajectories:UserWarning:flashbax",
"ignore:Setting max_size dynamically sets the `max_length_time_axis` to be `max_size`//`add_batch_size = .*`:UserWarning:flashbax",
"ignore:jax.tree_map is deprecated:DeprecationWarning:flashbax",
]

Expand Down

0 comments on commit 89cd3a3

Please sign in to comment.