diff --git a/flashbax/__init__.py b/flashbax/__init__.py index f685c37..576f501 100644 --- a/flashbax/__init__.py +++ b/flashbax/__init__.py @@ -19,10 +19,12 @@ make_flat_buffer, make_item_buffer, make_prioritised_flat_buffer, + make_prioritised_item_buffer, make_prioritised_trajectory_buffer, make_trajectory_buffer, make_trajectory_queue, prioritised_flat_buffer, + prioritised_item_buffer, prioritised_trajectory_buffer, trajectory_buffer, trajectory_queue, diff --git a/flashbax/buffers/__init__.py b/flashbax/buffers/__init__.py index a458d89..1e0096f 100644 --- a/flashbax/buffers/__init__.py +++ b/flashbax/buffers/__init__.py @@ -14,6 +14,7 @@ from flashbax.buffers.flat_buffer import make_flat_buffer from flashbax.buffers.item_buffer import make_item_buffer from flashbax.buffers.prioritised_flat_buffer import make_prioritised_flat_buffer +from flashbax.buffers.prioritised_item_buffer import make_prioritised_item_buffer from flashbax.buffers.prioritised_trajectory_buffer import ( make_prioritised_trajectory_buffer, ) diff --git a/flashbax/buffers/prioritised_item_buffer.py b/flashbax/buffers/prioritised_item_buffer.py new file mode 100644 index 0000000..a78e304 --- /dev/null +++ b/flashbax/buffers/prioritised_item_buffer.py @@ -0,0 +1,154 @@ +# Copyright 2023 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import jax +from chex import PRNGKey + +from flashbax import utils +from flashbax.buffers.item_buffer import validate_item_buffer_args +from flashbax.buffers.prioritised_flat_buffer import validate_priority_exponent +from flashbax.buffers.prioritised_trajectory_buffer import ( + PrioritisedTrajectoryBuffer, + PrioritisedTrajectoryBufferSample, + PrioritisedTrajectoryBufferState, + make_prioritised_trajectory_buffer, + validate_device, +) +from flashbax.buffers.trajectory_buffer import Experience +from flashbax.utils import add_dim_to_args + + +def create_prioritised_item_buffer( + max_length: int, + min_length: int, + sample_batch_size: int, + add_sequences: bool, + add_batches: bool, + priority_exponent: float, + device: str, +) -> PrioritisedTrajectoryBuffer: + """Creates a prioritised trajectory buffer that acts as an independent item buffer. + + Args: + max_length (int): The maximum length of the buffer. + min_length (int): The minimum length of the buffer. + sample_batch_size (int): The batch size of the samples. + add_sequences (Optional[bool], optional): Whether data is being added in sequences + to the buffer. If False, single items are being added each time add + is called. Defaults to False. + add_batches: (Optional[bool], optional): Whether adding data in batches to the buffer. + If False, single items (or single sequences of items) are being added each time add + is called. Defaults to False. + priority_exponent: Priority exponent for sampling. Equivalent to \alpha in the PER paper. + device: "tpu", "gpu" or "cpu". Depending on chosen device, more optimal functions will be + used to perform the buffer operations. + + Returns: + The buffer.""" + + validate_item_buffer_args( + max_length=max_length, + min_length=min_length, + sample_batch_size=sample_batch_size, + ) + + validate_priority_exponent(priority_exponent) + if not validate_device(device): + device = "cpu" + + buffer = make_prioritised_trajectory_buffer( + max_length_time_axis=max_length, + min_length_time_axis=min_length, + add_batch_size=1, + sample_batch_size=sample_batch_size, + sample_sequence_length=1, + period=1, + priority_exponent=priority_exponent, + device=device, + ) + + def add_fn( + state: PrioritisedTrajectoryBufferState, batch: Experience + ) -> PrioritisedTrajectoryBufferState[Experience]: + """Flattens a batch to add items along single time axis.""" + batch_size, seq_len = utils.get_tree_shape_prefix(batch, n_axes=2) + flattened_batch = jax.tree_map( + lambda x: x.reshape((1, batch_size * seq_len, *x.shape[2:])), batch + ) + return buffer.add(state, flattened_batch) + + if not add_batches: + add_fn = add_dim_to_args( + add_fn, axis=0, starting_arg_index=1, ending_arg_index=2 + ) + + if not add_sequences: + axis = 1 - int(not add_batches) # 1 if add_batches else 0 + add_fn = add_dim_to_args( + add_fn, axis=axis, starting_arg_index=1, ending_arg_index=2 + ) + + def sample_fn( + state: PrioritisedTrajectoryBufferState, rng_key: PRNGKey + ) -> PrioritisedTrajectoryBufferSample[Experience]: + """Samples a batch of items from the buffer.""" + sampled_batch = buffer.sample(state, rng_key) + priorities = sampled_batch.priorities + indices = sampled_batch.indices + sampled_batch = sampled_batch.experience + sampled_batch = jax.tree_map(lambda x: x.squeeze(axis=1), sampled_batch) + return PrioritisedTrajectoryBufferSample( + experience=sampled_batch, indices=indices, priorities=priorities + ) + + return buffer.replace(add=add_fn, sample=sample_fn) # type: ignore + + +def make_prioritised_item_buffer( + max_length: int, + min_length: int, + sample_batch_size: int, + add_sequences: bool = False, + add_batches: bool = False, + priority_exponent: float = 0.6, + device: str = "cpu", +) -> PrioritisedTrajectoryBuffer: + """Makes a prioritised trajectory buffer act as a independent item buffer. + + Args: + max_length (int): The maximum length of the buffer. + min_length (int): The minimum length of the buffer. + sample_batch_size (int): The batch size of the samples. + add_sequences (Optional[bool], optional): Whether data is being added in sequences + to the buffer. If False, single items are being added each time add + is called. Defaults to False. + add_batches: (Optional[bool], optional): Whether adding data in batches to the buffer. + If False, single transitions or single sequences are being added each time add + is called. Defaults to False. + priority_exponent: Priority exponent for sampling. Equivalent to \alpha in the PER paper. + device: "tpu", "gpu" or "cpu". Depending on chosen device, more optimal functions will be + used to perform the buffer operations. + + Returns: + The buffer.""" + + return create_prioritised_item_buffer( + max_length=max_length, + min_length=min_length, + sample_batch_size=sample_batch_size, + add_sequences=add_sequences, + add_batches=add_batches, + priority_exponent=priority_exponent, + device=device, + ) diff --git a/flashbax/buffers/prioritised_item_buffer_test.py b/flashbax/buffers/prioritised_item_buffer_test.py new file mode 100644 index 0000000..c7a4207 --- /dev/null +++ b/flashbax/buffers/prioritised_item_buffer_test.py @@ -0,0 +1,317 @@ +# Copyright 2023 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from copy import deepcopy + +import chex +import jax +import jax.numpy as jnp +import pytest + +from flashbax.buffers import prioritised_item_buffer, sum_tree +from flashbax.buffers.conftest import get_fake_batch +from flashbax.conftest import _DEVICE_COUNT_MOCK + + +@pytest.fixture() +def priority_exponent() -> float: + return 0.6 + + +def test_add_and_can_sample( + fake_transition: chex.ArrayTree, + min_length: int, + max_length: int, + add_batch_size: int, +) -> None: + """Check the `add` function by filling the buffer all + the way to the max_length and checking that it produces the expected behaviour . + """ + fake_batch = get_fake_batch(fake_transition, add_batch_size) + + buffer = prioritised_item_buffer.make_prioritised_item_buffer( + max_length, min_length, 4, False, True + ) + state = buffer.init(fake_transition) + + init_state = deepcopy(state) # Save for later checks. + + n_batches_to_fill = max_length // add_batch_size + n_batches_to_sample = min_length // add_batch_size + + for i in range(n_batches_to_fill): + assert not state.is_full + state = buffer.add(state, fake_batch) + assert state.current_index == (((i + 1) * add_batch_size) % max_length) + + # Check that the `can_sample` function behavior is correct. + is_ready_to_sample = buffer.can_sample(state) + if i < (n_batches_to_sample): + assert not is_ready_to_sample + else: + assert is_ready_to_sample + + assert state.is_full + + # Check that the transitions have been updated. + with pytest.raises(AssertionError): + chex.assert_trees_all_close(state.experience, init_state.experience) + + +def test_sample( + fake_transition: chex.ArrayTree, + max_length: int, + rng_key: chex.PRNGKey, +) -> None: + """Test the random sampling from the buffer.""" + + min_length = 40 + sample_batch_size = 20 + rng_key1, rng_key2 = jax.random.split(rng_key) + + # Fill buffer to the point that we can sample + fake_batch = get_fake_batch(fake_transition, sample_batch_size) + + buffer = prioritised_item_buffer.make_prioritised_item_buffer( + max_length, min_length, sample_batch_size, False, True + ) + state = buffer.init(fake_transition) + + # Add two batches of items. + state = buffer.add(state, fake_batch) + assert not buffer.can_sample(state) + state = buffer.add(state, fake_batch) + assert buffer.can_sample(state) + + # Sample from the buffer with different keys and check it gives us different batches. + batch1 = buffer.sample(state, rng_key1) + batch2 = buffer.sample(state, rng_key2) + + # Check that all the corresponding priorities are greater than 0. + assert (batch1.priorities > 0).all() + assert (batch2.priorities > 0).all() + + # Check that the transitions have been updated. + with pytest.raises(AssertionError): + chex.assert_trees_all_close(batch1.experience, batch2.experience) + + # Check dtypes are correct. + chex.assert_trees_all_equal_dtypes( + fake_transition, + batch1.experience, + batch2.experience, + ) + + +def test_adjust_priorities( + fake_transition: chex.ArrayTree, + min_length: int, + max_length: int, + rng_key: chex.PRNGKey, + sample_batch_size: int, + priority_exponent: float, +) -> None: + """Test the adjustment of priorities in the buffer.""" + rng_key1, rng_key2 = jax.random.split(rng_key) + + add_batch_size = int(min_length + 10) + # Fill buffer to the point that we can sample. + fake_batch = get_fake_batch(fake_transition, add_batch_size) + buffer = buffer = prioritised_item_buffer.make_prioritised_item_buffer( + max_length, min_length, sample_batch_size, False, True, priority_exponent + ) + state = buffer.init(fake_transition) + + state = buffer.add(state, fake_batch) + + # Sample from the buffer. + batch = buffer.sample(state, rng_key1) + + # Create fake new priorities, and apply the adjustment. + new_priorities = jnp.ones_like(batch.priorities) + 10007 + state = buffer.set_priorities(state, batch.indices, new_priorities) + + # Check that this results in the correct changes to the state. + assert ( + state.priority_state.max_recorded_priority + == jnp.max(new_priorities) ** priority_exponent + ) + assert ( + sum_tree.get(state.priority_state, batch.indices) + == new_priorities**priority_exponent + ).all() + + +def test_add_batch_size_none( + fake_transition: chex.ArrayTree, + min_length: int, + max_length: int, +): + # create a fake batch and ensure there is no batch dimension + fake_batch = jax.tree_map( + lambda x: jnp.squeeze(x, 0), get_fake_batch(fake_transition, 1) + ) + + buffer = prioritised_item_buffer.make_prioritised_item_buffer( + max_length, min_length, 4, False, False + ) + state = buffer.init(fake_transition) + + init_state = deepcopy(state) # Save for later checks. + + n_batches_to_fill = max_length + n_batches_to_sample = min_length + + for i in range(n_batches_to_fill): + assert not state.is_full + state = buffer.add(state, fake_batch) + assert state.current_index == ((i + 1) % (max_length)) + + # Check that the `can_sample` function behavior is correct. + is_ready_to_sample = buffer.can_sample(state) + if i < (n_batches_to_sample - 1): + assert not is_ready_to_sample + else: + assert is_ready_to_sample + + assert state.is_full + + # Check that the transitions have been updated. + with pytest.raises(AssertionError): + chex.assert_trees_all_close(state.experience, init_state.experience) + chex.assert_trees_all_close(state.indices, init_state.indices) + + +def test_add_sequences( + fake_transition: chex.ArrayTree, + min_length: int, + max_length: int, +): + add_sequence_size = 5 + # create a fake sequence and ensure there is no batch dimension + fake_batch = jax.tree_map( + lambda x: x.repeat(add_sequence_size, axis=0), + get_fake_batch(fake_transition, 1), + ) + assert fake_batch["obs"].shape[0] == add_sequence_size + + buffer = prioritised_item_buffer.make_prioritised_item_buffer( + max_length, min_length, 4, add_sequences=True, add_batches=False + ) + state = buffer.init(fake_transition) + + init_state = deepcopy(state) # Save for later checks. + + n_sequences_to_fill = (max_length // add_sequence_size) + 1 + + for i in range(n_sequences_to_fill): + assert not state.is_full + state = buffer.add(state, fake_batch) + assert state.current_index == (((i + 1) * add_sequence_size) % (max_length)) + + assert state.is_full + + # Check that the transitions have been updated. + with pytest.raises(AssertionError): + chex.assert_trees_all_close(state.experience, init_state.experience) + + +def test_add_sequences_and_batches( + fake_transition: chex.ArrayTree, + min_length: int, + max_length: int, + add_batch_size: int, +): + add_sequence_size = 5 + # create a fake batch and sequence + fake_batch = jax.tree_map( + lambda x: x[:, jnp.newaxis].repeat(add_sequence_size, axis=1), + get_fake_batch(fake_transition, add_batch_size), + ) + assert fake_batch["obs"].shape[:2] == (add_batch_size, add_sequence_size) + + buffer = prioritised_item_buffer.make_prioritised_item_buffer( + max_length, min_length, 4, add_sequences=True, add_batches=True + ) + state = buffer.init(fake_transition) + + init_state = deepcopy(state) # Save for later checks. + + n_sequences_to_fill = (max_length // (add_batch_size * add_sequence_size)) + 1 + + for i in range(n_sequences_to_fill): + assert not state.is_full + state = buffer.add(state, fake_batch) + assert state.current_index == ( + ((i + 1) * add_sequence_size * add_batch_size) % max_length + ) + + assert state.is_full + + # Check that the transitions have been updated. + with pytest.raises(AssertionError): + chex.assert_trees_all_close(state.experience, init_state.experience) + + +def test_item_replay_buffer_does_not_smoke( + fake_transition: chex.ArrayTree, + min_length: int, + max_length: int, + rng_key: chex.PRNGKey, + sample_batch_size: int, + priority_exponent: float, +): + """Create the itemBuffer NamedTuple, and check that it is pmap-able and does not smoke.""" + + add_batch_size = int(min_length + 5) + buffer = prioritised_item_buffer.make_prioritised_item_buffer( + max_length, min_length, sample_batch_size, False, True, priority_exponent + ) + + # Initialise the buffer's state. + fake_transition_per_device = jax.tree_map( + lambda x: jnp.stack([x + i for i in range(_DEVICE_COUNT_MOCK)]), fake_transition + ) + state = jax.pmap(buffer.init)(fake_transition_per_device) + + # Now fill the buffer above its minimum length. + + fake_batch = jax.pmap(get_fake_batch, static_broadcasted_argnums=1)( + fake_transition_per_device, add_batch_size + ) + # Add two items thereby giving a single transition. + state = jax.pmap(buffer.add)(state, fake_batch) + state = jax.pmap(buffer.add)(state, fake_batch) + assert buffer.can_sample(state).all() + + # Sample from the buffer. + # Sample from the buffer. + rng_key_per_device = jax.random.split(rng_key, _DEVICE_COUNT_MOCK) + batch = jax.pmap(buffer.sample)(state, rng_key_per_device) + chex.assert_tree_shape_prefix(batch, (_DEVICE_COUNT_MOCK, sample_batch_size)) + + # Adjust priorities. + new_priorities = jax.pmap(jnp.ones_like)(batch.priorities) + 10007 + state = jax.pmap(buffer.set_priorities)(state, batch.indices, new_priorities) + + # Check that the priority adjustment produces the correct changes to the state. + assert ( + state.priority_state.max_recorded_priority + == jnp.max(new_priorities) ** priority_exponent + ).all() + assert ( + jax.pmap(sum_tree.get)(state.priority_state, batch.indices) + == new_priorities**priority_exponent + ).all()