Skip to content

Commit

Permalink
Pass max_length_time_axis instead of max_size
Browse files Browse the repository at this point in the history
Makes it so that the warning:

```
Setting max_size dynamically sets the `max_length_time_axis` to be `max_size`//`add_batch_size = .*`
```

will no longer be triggered by legitimate use of `create_flat_buffer` and
`make_prioritised_flat_buffer`.
  • Loading branch information
mickvangelderen committed Nov 8, 2024
1 parent 1baa1b7 commit f3cad12
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 5 deletions.
4 changes: 2 additions & 2 deletions flashbax/buffers/flat_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,13 +123,13 @@ def create_flat_buffer(
)

buffer = make_trajectory_buffer(
max_length_time_axis=None, # Unused because max_size is specified
max_length_time_axis=max_length // add_batch_size,
min_length_time_axis=min_length // add_batch_size + 1,
add_batch_size=add_batch_size,
sample_batch_size=sample_batch_size,
sample_sequence_length=2,
period=1,
max_size=max_length,
max_size=None,
)

add_fn = buffer.add
Expand Down
4 changes: 2 additions & 2 deletions flashbax/buffers/prioritised_flat_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,13 +110,13 @@ def make_prioritised_flat_buffer(
)

buffer = make_prioritised_trajectory_buffer(
max_length_time_axis=None, # Unused because max_size is specified
max_length_time_axis=max_length // add_batch_size,
min_length_time_axis=min_length // add_batch_size + 1,
add_batch_size=add_batch_size,
sample_batch_size=sample_batch_size,
sample_sequence_length=2,
period=1,
max_size=max_length,
max_size=None,
priority_exponent=priority_exponent,
device=device,
)
Expand Down
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ filterwarnings = [
"error",
"ignore:`sample_sequence_length` greater than `min_length_time_axis`:UserWarning:flashbax",
"ignore:Setting period greater than sample_sequence_length will result in no overlap betweentrajectories:UserWarning:flashbax",
"ignore:Setting max_size dynamically sets the `max_length_time_axis` to be `max_size`//`add_batch_size = .*`:UserWarning:flashbax",
"ignore:jax.tree_map is deprecated:DeprecationWarning:flashbax",
]

Expand Down

0 comments on commit f3cad12

Please sign in to comment.