Skip to content

Commit

Permalink
BC: total_timesteps -> total_steps & remove deprecated methods (#188)
Browse files Browse the repository at this point in the history
* total_timesteps -> total_steps, remove deprecated

* enh dataset standards

* timestep -> step

* remove deprecated func

* fix tests
  • Loading branch information
younik authored Feb 1, 2024
1 parent 28e75bb commit 26fb98e
Show file tree
Hide file tree
Showing 17 changed files with 51 additions and 147 deletions.
4 changes: 1 addition & 3 deletions docs/api/data_collector.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,6 @@ data_collector_callbacks/episode_metadata_callback
```{eval-rst}
.. autofunction:: minari.DataCollector.step
.. autofunction:: minari.DataCollector.reset
.. autofunction:: minari.DataCollector.close
.. autofunction:: minari.DataCollector.create_dataset
.. autofunction:: minari.DataCollector.clear_buffer_to_tmp_file
.. autofunction::minari.DataCollector._add_to_episode_buffer
.. autofunction:: minari.DataCollector.close
```
8 changes: 4 additions & 4 deletions docs/api/minari_dataset/episode_data.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,18 @@
The seed used to reset this episode in the Gymnasium API.
.. autoattribute:: minari.EpisodeData.total_timesteps
.. autoattribute:: minari.EpisodeData.total_steps
The number of timesteps contained in this episode.
The number of steps contained in this episode.
.. autoattribute:: minari.EpisodeData.observations
The observations of the environment. The initial and final observations are included meaning that the number
of observations will be increased by one compared to the number of timesteps
of observations will be increased by one compared to the number of steps.
.. autoattribute:: minari.EpisodeData.actions
The actions taken in each episode timestep.
The actions taken in each episode step.
.. autoattribute:: minari.EpisodeData.terminations
Expand Down
16 changes: 8 additions & 8 deletions docs/content/dataset_standards.md
Original file line number Diff line number Diff line change
Expand Up @@ -548,18 +548,18 @@ The `sampled_episodes` variable will be a list of 10 `EpisodeData` elements, eac
| ----------------- | ------------------------------------ | ------------------------------------------------------------- |
| `id` | `np.int64` | ID of the episode. |
| `seed` | `np.int64` | Seed used to reset the episode. |
| `total_timesteps` | `np.int64` | Number of timesteps in the episode. |
| `observations` | `np.ndarray`, `list`, `tuple`, `dict` | Observations for each timestep including initial observation. |
| `actions` | `np.ndarray`, `list`, `tuple`, `dict` | Actions for each timestep. |
| `rewards` | `np.ndarray` | Rewards for each timestep. |
| `terminations` | `np.ndarray` | Terminations for each timestep. |
| `truncations` | `np.ndarray` | Truncations for each timestep. |
| `total_steps` | `np.int64` | Number of steps in the episode. |
| `observations` | `np.ndarray`, `list`, `tuple`, `dict` | Observations for each step including initial observation. |
| `actions` | `np.ndarray`, `list`, `tuple`, `dict` | Actions for each step. |
| `rewards` | `np.ndarray` | Rewards for each step. |
| `terminations` | `np.ndarray` | Terminations for each step. |
| `truncations` | `np.ndarray` | Truncations for each step. |
| `infos` | `dict` | A dictionary containing additional information. |

As mentioned in the `Supported Spaces` section, many different observation and action spaces are supported so the data type for these fields are dependent on the environment being used.

## Additional Information Formatting

When creating a dataset with `DataCollector`, if the `DataCollector` is initialized with `record_infos=True`, an info dict must be provided from every call to the environment's `step` and `reset` function. The structure of the info dictionary must be the same across timesteps.
When creating a dataset with `DataCollector`, if the `DataCollector` is initialized with `record_infos=True`, an info dict must be provided from every call to the environment's `step` and `reset` function. The structure of the info dictionary must be the same across steps.

Given that it is not guaranteed that all Gymnasium environments provide infos at every timestep, we provide the `StepDataCallback` which can modify the infos from a non-compliant environment so they have the same structure at every timestep. An example of this pattern is available in our test `test_data_collector_step_data_callback_info_correction` in test_step_data_callback.py.
Given that it is not guaranteed that all Gymnasium environments provide infos at every step, we provide the `StepDataCallback` which can modify the infos from a non-compliant environment so they have the same structure at every step. An example of this pattern is available in our test `test_data_collector_step_data_callback_info_correction` in test_step_data_callback.py.
4 changes: 2 additions & 2 deletions docs/tutorials/using_datasets/behavioral_cloning.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
# %%%
# We present here how to perform behavioral cloning on a Minari dataset using `PyTorch <https://pytorch.org/>`_.
# We will start generating the dataset of the expert policy for the `CartPole-v1 <https://gymnasium.farama.org/environments/classic_control/cart_pole/>`_ environment, which is a classic control problem.
# The objective is to balance the pole on the cart, and we receive a reward of +1 for each successful timestep.
# The objective is to balance the pole on the cart, and we receive a reward of +1 for each successful step.

# %%
# Imports
Expand Down Expand Up @@ -108,7 +108,7 @@ def collate_fn(batch):
return {
"id": torch.Tensor([x.id for x in batch]),
"seed": torch.Tensor([x.seed for x in batch]),
"total_timesteps": torch.Tensor([x.total_timesteps for x in batch]),
"total_steps": torch.Tensor([x.total_steps for x in batch]),
"observations": torch.nn.utils.rnn.pad_sequence(
[torch.as_tensor(x.observations) for x in batch],
batch_first=True
Expand Down
10 changes: 0 additions & 10 deletions minari/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from minari.utils import (
combine_datasets,
create_dataset_from_buffers,
create_dataset_from_collector_env,
get_normalized_score,
split_dataset,
)
Expand All @@ -33,17 +32,8 @@
"load_dataset",
"combine_datasets",
"create_dataset_from_buffers",
"create_dataset_from_collector_env",
"split_dataset",
"get_normalized_score",
]

__version__ = "0.4.3"


def __getattr__(name):
if name == "DataCollectorV0":
from minari.data_collector import DataCollectorV0
return DataCollectorV0
else:
raise ImportError(f"cannot import name '{name}' from '{__name__}' ({__file__})")
8 changes: 0 additions & 8 deletions minari/data_collector/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,3 @@


__all__ = ["DataCollector"]


def __getattr__(name):
if name == "DataCollectorV0":
from minari.data_collector.data_collector import DataCollectorV0
return DataCollectorV0
else:
raise ImportError(f"cannot import name '{name}' from '{__name__}' ({__file__})")
21 changes: 3 additions & 18 deletions minari/data_collector/data_collector.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
from __future__ import annotations

import copy
import inspect
import os
import secrets
import shutil
import tempfile
import warnings
from typing import Any, Callable, Dict, List, Optional, SupportsFloat, Type, Union

import gymnasium as gym
Expand All @@ -22,6 +20,7 @@
)
from minari.dataset.minari_dataset import MinariDataset
from minari.dataset.minari_storage import MinariStorage
from minari.utils import _generate_dataset_metadata, _generate_dataset_path


# H5Py supports ints up to uint64
Expand All @@ -30,17 +29,6 @@
EpisodeBuffer = Dict[str, Any] # TODO: narrow this down


def __getattr__(name):
if name == "DataCollectorV0":
stacklevel = len(inspect.stack(0))
warnings.warn("DataCollectorV0 is deprecated and will be removed. Use DataCollector instead.", DeprecationWarning, stacklevel=stacklevel)
return DataCollector
elif name == "__path__":
return False # see https://stackoverflow.com/a/60803436
else:
raise ImportError(f"cannot import name '{name}' from '{__name__}' ({__file__})")


class DataCollector(gym.Wrapper):
r"""Gymnasium environment wrapper that collects step data.
Expand Down Expand Up @@ -357,8 +345,6 @@ def create_dataset(
Returns:
MinariDataset
"""
# TODO: move the import to top of the file after removing minari.create_dataset_from_collector_env() in 0.5.0
from minari.utils import _generate_dataset_metadata, _generate_dataset_path
dataset_path = _generate_dataset_path(dataset_id)
metadata: Dict[str, Any] = _generate_dataset_metadata(
dataset_id,
Expand All @@ -375,7 +361,7 @@ def create_dataset(
minari_version,
)

self.save_to_disk(dataset_path, metadata)
self._save_to_disk(dataset_path, metadata)

# will be able to calculate dataset size only after saving the disk, so updating the dataset metadata post `save_to_disk` method

Expand All @@ -384,7 +370,7 @@ def create_dataset(
dataset.storage.update_metadata(metadata)
return dataset

def save_to_disk(
def _save_to_disk(
self, path: str | os.PathLike, dataset_metadata: Dict[str, Any] = {}
):
"""Save all in-memory buffer data and move temporary files to a permanent location in disk.
Expand All @@ -393,7 +379,6 @@ def save_to_disk(
path (str): path to store the dataset, e.g.: '/home/foo/datasets/data'
dataset_metadata (Dict, optional): additional metadata to add to the dataset file. Defaults to {}.
"""
warnings.warn("This method is deprecated and will become private in v0.5.0.", DeprecationWarning, stacklevel=2)
self._validate_buffer()
self._storage.update_episodes(self._buffer)
self._buffer.clear()
Expand Down
4 changes: 2 additions & 2 deletions minari/dataset/episode_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ class EpisodeData:

id: int
seed: Optional[int]
total_timesteps: int
total_steps: int
observations: Any
actions: Any
rewards: np.ndarray
Expand All @@ -26,7 +26,7 @@ def __repr__(self) -> str:
"EpisodeData("
f"id={repr(self.id)}, "
f"seed={repr(self.seed)}, "
f"total_timesteps={self.total_timesteps}, "
f"total_steps={self.total_steps}, "
f"observations={EpisodeData._repr_space_values(self.observations)}, "
f"actions={EpisodeData._repr_space_values(self.actions)}, "
f"rewards=ndarray of {len(self.rewards)} floats, "
Expand Down
2 changes: 1 addition & 1 deletion minari/dataset/minari_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,7 @@ def total_steps(self) -> int:
else:
self._total_steps = sum(
self.storage.apply(
lambda episode: episode["total_timesteps"],
lambda episode: episode["total_steps"],
episode_indices=self.episode_indices,
)
)
Expand Down
2 changes: 1 addition & 1 deletion minari/dataset/minari_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ def get_episodes(self, episode_indices: Iterable[int]) -> List[dict]:

ep_dict = {
"id": ep_group.attrs.get("id"),
"total_timesteps": ep_group.attrs.get("total_steps"),
"total_steps": ep_group.attrs.get("total_steps"),
"seed": seed,
"observations": self._decode_space(
ep_group["observations"], self.observation_space
Expand Down
60 changes: 1 addition & 59 deletions minari/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
from packaging.specifiers import InvalidSpecifier, SpecifierSet
from packaging.version import Version

from minari import DataCollector
from minari.dataset.minari_dataset import MinariDataset
from minari.dataset.minari_storage import MinariStorage
from minari.storage.datasets_root_dir import get_dataset_path
Expand Down Expand Up @@ -566,63 +565,6 @@ def create_dataset_from_buffers(
return MinariDataset(storage)


def create_dataset_from_collector_env(
dataset_id: str,
collector_env: DataCollector,
eval_env: Optional[str | gym.Env | EnvSpec] = None,
algorithm_name: Optional[str] = None,
author: Optional[str] = None,
author_email: Optional[str] = None,
code_permalink: Optional[str] = None,
ref_min_score: Optional[float] = None,
ref_max_score: Optional[float] = None,
expert_policy: Optional[Callable[[ObsType], ActType]] = None,
num_episodes_average_score: int = 100,
minari_version: Optional[str] = None,
):
"""Create a Minari dataset using the data collected from stepping with a Gymnasium environment wrapped with a `DataCollector` Minari wrapper.
The ``dataset_id`` parameter corresponds to the name of the dataset, with the syntax as follows:
``(env_name-)(dataset_name)(-v(version))`` where ``env_name`` identifies the name of the environment used to generate the dataset ``dataset_name``.
This ``dataset_id`` is used to load the Minari datasets with :meth:`minari.load_dataset`.
Args:
dataset_id (str): name id to identify Minari dataset
collector_env (DataCollector): Gymnasium environment used to collect the buffer data
buffer (list[Dict[str, Union[list, Dict]]]): list of episode dictionaries with data
eval_env (Optional[str|gym.Env|EnvSpec]): Gymnasium environment(gym.Env)/environment id(str)/environment spec(EnvSpec) to use for evaluation with the dataset. After loading the dataset, the environment can be recovered as follows: `MinariDataset.recover_environment(eval_env=True).
If None the `env` used to collect the buffer data should be used for evaluation.
algorithm_name (Optional[str], optional): name of the algorithm used to collect the data. Defaults to None.
author (Optional[str], optional): author that generated the dataset. Defaults to None.
author_email (Optional[str], optional): email of the author that generated the dataset. Defaults to None.
code_permalink (Optional[str], optional): link to relevant code used to generate the dataset. Defaults to None.
ref_min_score( Optional[float], optional): minimum reference score from the average returns of a random policy. This value is later used to normalize a score with :meth:`minari.get_normalized_score`. If default None the value will be estimated with a default random policy.
ref_max_score (Optional[float], optional: maximum reference score from the average returns of a hypothetical expert policy. This value is used in :meth:`minari.get_normalized_score`. Default None.
expert_policy (Optional[Callable[[ObsType], ActType], optional): policy to compute `ref_max_score` by averaging the returns over a number of episodes equal to `num_episodes_average_score`.
`ref_max_score` and `expert_policy` can't be passed at the same time. Default to None
num_episodes_average_score (int): number of episodes to average over the returns to compute `ref_min_score` and `ref_max_score`. Default to 100.
minari_version (Optional[str], optional): Minari version specifier compatible with the dataset. If None (default) use the installed Minari version.
Returns:
MinariDataset
"""
warnings.warn("This function is deprecated and will be removed in v0.5.0. Please use DataCollector.create_dataset() instead.", DeprecationWarning, stacklevel=2)
dataset = collector_env.create_dataset(
dataset_id=dataset_id,
eval_env=eval_env,
algorithm_name=algorithm_name,
author=author,
author_email=author_email,
code_permalink=code_permalink,
ref_min_score=ref_min_score,
ref_max_score=ref_max_score,
expert_policy=expert_policy,
num_episodes_average_score=num_episodes_average_score,
minari_version=minari_version,
)
return dataset


def get_normalized_score(dataset: MinariDataset, returns: np.ndarray) -> np.ndarray:
r"""Normalize undiscounted return of an episode.
Expand Down Expand Up @@ -699,7 +641,7 @@ def get_dataset_spec_dict(
version += f" ({__version__} installed)"

md_dict = {
"Total Timesteps": dataset_spec["total_steps"],
"Total steps": dataset_spec["total_steps"],
"Total Episodes": dataset_spec["total_episodes"],
"Dataset Observation Space": f"`{dataset_observation_space}`",
"Dataset Action Space": f"`{dataset_action_space}`",
Expand Down
30 changes: 15 additions & 15 deletions tests/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -559,26 +559,26 @@ def check_data_integrity(data: MinariStorage, episode_indices: Iterable[int]):

# verify the actions and observations are in the appropriate action space and observation space, and that the episode lengths are correct
for episode in episodes:
total_steps += episode["total_timesteps"]
total_steps += episode["total_steps"]
_check_space_elem(
episode["observations"],
observation_space,
episode["total_timesteps"] + 1,
episode["total_steps"] + 1,
)
_check_space_elem(episode["actions"], action_space, episode["total_timesteps"])
_check_space_elem(episode["actions"], action_space, episode["total_steps"])

for i in range(episode["total_timesteps"] + 1):
for i in range(episode["total_steps"] + 1):
obs = _reconstuct_obs_or_action_at_index_recursive(
episode["observations"], i
)
assert observation_space.contains(obs)
for i in range(episode["total_timesteps"]):
for i in range(episode["total_steps"]):
action = _reconstuct_obs_or_action_at_index_recursive(episode["actions"], i)
assert action_space.contains(action)

assert episode["total_timesteps"] == len(episode["rewards"])
assert episode["total_timesteps"] == len(episode["terminations"])
assert episode["total_timesteps"] == len(episode["truncations"])
assert episode["total_steps"] == len(episode["rewards"])
assert episode["total_steps"] == len(episode["terminations"])
assert episode["total_steps"] == len(episode["truncations"])
assert total_steps == data.total_steps


Expand Down Expand Up @@ -707,11 +707,11 @@ def check_episode_data_integrity(
_check_space_elem(
episode.observations,
observation_space,
episode.total_timesteps + 1,
episode.total_steps + 1,
)
_check_space_elem(episode.actions, action_space, episode.total_timesteps)
_check_space_elem(episode.actions, action_space, episode.total_steps)

for i in range(episode.total_timesteps + 1):
for i in range(episode.total_steps + 1):
obs = _reconstuct_obs_or_action_at_index_recursive(episode.observations, i)
if info_sample is not None:
assert check_infos_equal(
Expand All @@ -721,13 +721,13 @@ def check_episode_data_integrity(

assert observation_space.contains(obs)

for i in range(episode.total_timesteps):
for i in range(episode.total_steps):
action = _reconstuct_obs_or_action_at_index_recursive(episode.actions, i)
assert action_space.contains(action)

assert episode.total_timesteps == len(episode.rewards)
assert episode.total_timesteps == len(episode.terminations)
assert episode.total_timesteps == len(episode.truncations)
assert episode.total_steps == len(episode.rewards)
assert episode.total_steps == len(episode.terminations)
assert episode.total_steps == len(episode.truncations)


def check_infos_equal(info_1: Dict, info_2: Dict) -> bool:
Expand Down
Loading

0 comments on commit 26fb98e

Please sign in to comment.