Commit bf99117

Add helpers and instructions on how to share datasets on the huggingface hub (#705)

* Add trajectories_to_dataset helper.

* Add instructions on how to upload/download trajectories to/from the HuggingFace Datasets Hub.

* Add example on how to use a HuggingFace dataset as a Sequence of Trajectories.

* Fix some issues in the expert model documentation.

* Fix some issues in the example code of the trajectories documentation.

* Remove unused make_dict_from_trajectory function.

* Add no-cover pragmas for rare edge-cases of trajectory loading.

* Add more thorough tests for HuggingFace Datasets backed trajectory sequences.

* Fix bug with sliced access to a TrajectoryDatasetSequence.

* Fix bug with storing the dataset of a TrajectoryDatasetSequence.

* Stop using functools LRU cache because it messes up pickleability.

* Add test for accessing slices of info dicts.
ernestum authored May 25, 2023
1 parent 22758c7 commit bf99117
Showing 6 changed files with 353 additions and 74 deletions.
13 changes: 7 additions & 6 deletions docs/experts/loading-experts.rst
@@ -26,11 +26,12 @@ corresponding policy loader.

 .. code-block:: python
 
+    import numpy as np
     from imitation.policies.serialize import load_policy
     from imitation.util import util
 
-    venv = util.make_vec_env("your-env", n_envs=4)
-    local_policy = load_policy("ppo", venv, loader_kwargs={"path": "path/to/model.zip"})
+    venv = util.make_vec_env("your-env", n_envs=4, rng=np.random.default_rng())
+    local_policy = load_policy("ppo", venv, path="path/to/model.zip")
 
 To load a policy from disk, use either `ppo` or `sac` as the policy type.
 The path is specified by `path` in the `loader_kwargs` and it should either point
@@ -61,15 +62,15 @@ When using the Python API, you also have to specify the environment name as `env

 .. code-block:: python
 
+    import numpy as np
     from imitation.policies.serialize import load_policy
     from imitation.util import util
 
-    venv = util.make_vec_env("your-env", n_envs=4)
+    venv = util.make_vec_env("your-env", n_envs=4, rng=np.random.default_rng())
     remote_policy = load_policy(
         "ppo-huggingface",
-        loader_kwargs=dict(
-            organization="your-org",
-            env_name="your-env"
-        )
+        organization="your-org",
+        env_name="your-env"
     )
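
For a quick smoke test of a policy loaded this way, the standard Stable Baselines3 `predict` API applies. A minimal sketch, assuming `venv` and `local_policy` from the snippet above:

.. code-block:: python

    # Query the loaded policy for one batch of actions (a sketch, not part of the diff).
    obs = venv.reset()
    actions, _states = local_policy.predict(obs, deterministic=True)
    assert actions.shape[0] == venv.num_envs
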
46 changes: 43 additions & 3 deletions docs/tutorials/trajectories.rst
@@ -46,13 +46,14 @@ To generate trajectories from a given policy, run the following command:
 .. code-block:: python
 
     import numpy as np
-    from imitation.data.rollout import rollout
+    import imitation.data.rollout as rollout
 
-    your_trajectories = rollout(
+    your_trajectories = rollout.rollout(
         your_policy,
         your_env,
-        sample_until=make_sample_until(min_episodes=10),
+        sample_until=rollout.make_sample_until(min_episodes=10),
         rng=np.random.default_rng(),
+        unwrap=False,
     )
Storing/Loading Trajectories
@@ -76,3 +77,42 @@ In the same way you can load trajectories from a HuggingFace Dataset:
Note that some older, now deprecated, trajectory formats are supported by :func:`this loader <imitation.data.serialize.load>`,
but not by the :func:`saver <imitation.data.serialize.save>`.

Sharing Trajectories with the HuggingFace Dataset Hub
-----------------------------------------------------

To share your trajectories with the HuggingFace Dataset Hub, you need to create a HuggingFace account and log in with the HuggingFace CLI:

.. code-block:: bash

    $ huggingface-cli login

Then you can upload your trajectories to the HuggingFace Dataset Hub:

.. code-block:: python

    from imitation.data.huggingface_utils import trajectories_to_dataset

    trajectories_to_dataset(your_trajectories).push_to_hub("your_hf_name/your_dataset_name")

To use a public dataset from the HuggingFace Dataset Hub, you can use the following code:

.. code-block:: python

    import datasets

    from imitation.data.huggingface_utils import TrajectoryDatasetSequence

    your_dataset = datasets.load_dataset("your_hf_name/your_dataset_name")
    your_trajectories = TrajectoryDatasetSequence(your_dataset)

The :class:`TrajectoryDatasetSequence <imitation.data.huggingface_utils.TrajectoryDatasetSequence>`
wraps a HuggingFace dataset so it can be used in the same way as a list of trajectories.

For example, you can analyze the dataset with :func:`imitation.data.rollout.rollout_stats` to get the mean return:

.. code-block:: python

    from imitation.data.rollout import rollout_stats

    stats = rollout_stats(your_trajectories)
    print(stats["return_mean"])
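
Note that `datasets.load_dataset` returns a `DatasetDict` when no split is given, so selecting a split explicitly is the safer pattern when wrapping. A sketch of list-like access (the dataset name is a placeholder, and `split="train"` is an assumption about how the dataset was uploaded):

.. code-block:: python

    import datasets

    from imitation.data.huggingface_utils import TrajectoryDatasetSequence

    # Select a single split explicitly; the wrapper expects a datasets.Dataset,
    # not the DatasetDict that load_dataset returns without a split argument.
    your_dataset = datasets.load_dataset("your_hf_name/your_dataset_name", split="train")
    your_trajectories = TrajectoryDatasetSequence(your_dataset)

    first = your_trajectories[0]        # a single Trajectory
    last_ten = your_trajectories[-10:]  # slicing yields a list of trajectories
    print(len(your_trajectories), len(first.obs), len(last_ten))
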
92 changes: 33 additions & 59 deletions src/imitation/data/huggingface_utils.py
@@ -1,6 +1,5 @@
"""Helpers to convert between Trajectories and HuggingFace's datasets library."""
import functools
from typing import Any, Dict, Iterable, Sequence, cast
from typing import Any, Dict, Iterable, Optional, Sequence, cast

import datasets
import jsonpickle
@@ -10,7 +9,7 @@


 class TrajectoryDatasetSequence(Sequence[types.Trajectory]):
-    """A wrapper to present a HF dataset as a sequence of trajectories.
+    """A wrapper to present an HF dataset as a sequence of trajectories.
 
     Converts the dataset to a sequence of trajectories on the fly.
     """
@@ -34,19 +33,12 @@ def __len__(self) -> int:
     def __getitem__(self, idx):
 
         if isinstance(idx, slice):
-            dataslice = self._dataset[idx]
-
-            # Extract the trajectory kwargs from the dataset slice
-            trajectory_kwargs = [
-                {key: dataslice[key][i] for key in dataslice}
-                for i in range(len(dataslice["obs"]))
-            ]
-
-            # Ensure that the infos are decoded lazily using jsonpickle
-            for kwargs in trajectory_kwargs:
-                kwargs["infos"] = _LazyDecodedList(kwargs["infos"])
-
-            return [self._trajectory_class(**kwargs) for kwargs in trajectory_kwargs]
+            # Note: we could use self._dataset[idx] here and then convert the result
+            # of that to a series of trajectories, but if we do that, we run into
+            # trouble with the custom numpy transform that we apply in the
+            # constructor: the transform is applied to the whole slice, which might
+            # contain trajectories of different lengths, which is not supported by
+            # numpy.
+            return [self[i] for i in range(*idx.indices(len(self)))]
         else:
             # Extract the trajectory kwargs from the dataset
             kwargs = self._dataset[idx]
@@ -56,6 +48,15 @@ def __getitem__(self, idx):

         return self._trajectory_class(**kwargs)
 
+    @property
+    def dataset(self):
+        """Return the underlying HF dataset."""
+        # Note: since we apply the custom numpy transform in the constructor, we
+        # remove it again before returning the dataset. This ensures that the
+        # dataset is returned in the original format and can be saved to disk
+        # (the custom transform cannot be saved to disk, since it is not
+        # pickleable).
+        return self._dataset.with_transform(None)


class _LazyDecodedList(Sequence[Any]):
"""A wrapper to lazily decode a list of jsonpickled strings.
@@ -67,56 +68,18 @@ class _LazyDecodedList(Sequence[Any]):

     def __init__(self, encoded_list: Sequence[str]):
         self._encoded_list = encoded_list
+        self._decoded_cache: Dict[int, Any] = {}
 
     def __len__(self):
         return len(self._encoded_list)
 
-    # arbitrary cache size just to put a limit on memory usage
-    @functools.lru_cache(maxsize=100000)
     def __getitem__(self, idx):
         if isinstance(idx, slice):
-            return [jsonpickle.decode(info) for info in self._encoded_list[idx]]
+            return [self[i] for i in range(*idx.indices(len(self)))]
         else:
-            return jsonpickle.decode(self._encoded_list[idx])
-
-
-def make_dict_from_trajectory(trajectory: types.Trajectory):
-    """Convert a Trajectory to a dict.
-
-    The dict has the following fields:
-
-    * obs: The observations. Shape: (num_timesteps, obs_dim). dtype: float.
-    * acts: The actions. Shape: (num_timesteps, act_dim). dtype: float.
-    * infos: The infos. Shape: (num_timesteps, ). dtype: (jsonpickled) str.
-    * terminal: The terminal flags. Shape: (num_timesteps, ). dtype: bool.
-    * rews: The rewards. Shape: (num_timesteps, ). dtype: float, if applicable.
-
-    Args:
-        trajectory: The trajectory to convert.
-
-    Returns:
-        A dict representing the trajectory.
-    """
-    # Replace `None` values for `infos` with an array of empty dicts
-    infos = cast(
-        Sequence[Dict[str, Any]],
-        trajectory.infos if trajectory.infos is not None else [{}] * len(trajectory),
-    )
-
-    # Encode infos as jsonpickled strings
-    encoded_infos = [jsonpickle.encode(info) for info in infos]
-
-    trajectory_dict = dict(
-        obs=trajectory.obs,
-        acts=trajectory.acts,
-        infos=encoded_infos,
-        terminal=trajectory.terminal,
-    )
-
-    # Add rewards if applicable
-    if isinstance(trajectory, types.TrajectoryWithRew):
-        trajectory_dict["rews"] = trajectory.rews
-
-    return trajectory_dict
+            if idx not in self._decoded_cache:
+                self._decoded_cache[idx] = jsonpickle.decode(self._encoded_list[idx])
+            return self._decoded_cache[idx]
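
For intuition about the new cache: each info dict is decoded at most once, on first access, and slices decode item by item. A small sketch with hypothetical data (`_LazyDecodedList` is module-private, so this is illustration rather than public API):

.. code-block:: python

    import jsonpickle

    # Hypothetical encoded infos, mirroring what trajectories_to_dict produces.
    encoded = [jsonpickle.encode({"step": i}) for i in range(3)]
    lazy = _LazyDecodedList(encoded)

    assert lazy[1] == {"step": 1}                    # decoded now, then cached
    assert lazy[0:2] == [{"step": 0}, {"step": 1}]   # slice delegates to item access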


@@ -174,3 +137,14 @@ def trajectories_to_dict(
             cast(types.TrajectoryWithRew, traj).rews for traj in trajectories
         ]
     return trajectory_dict


+def trajectories_to_dataset(
+    trajectories: Sequence[types.Trajectory],
+    info: Optional[datasets.DatasetInfo] = None,
+) -> datasets.Dataset:
+    """Convert a sequence of trajectories to a HuggingFace dataset."""
+    if isinstance(trajectories, TrajectoryDatasetSequence):
+        return trajectories.dataset
+    else:
+        return datasets.Dataset.from_dict(trajectories_to_dict(trajectories), info=info)
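
Taken together, the new helper and the `dataset` property make conversion a cheap round trip. A sketch, assuming `your_trajectories` is a plain list of `types.Trajectory`:

.. code-block:: python

    import datasets

    from imitation.data.huggingface_utils import (
        TrajectoryDatasetSequence,
        trajectories_to_dataset,
    )

    # List of trajectories -> HF dataset (info dicts are jsonpickle-encoded).
    dataset = trajectories_to_dataset(your_trajectories)
    assert isinstance(dataset, datasets.Dataset)

    # HF dataset -> list-like view; trajectories are rebuilt on access.
    wrapped = TrajectoryDatasetSequence(dataset)
    assert len(wrapped) == len(your_trajectories)

    # Converting the wrapper back short-circuits through `.dataset`,
    # so nothing is re-encoded.
    assert isinstance(trajectories_to_dataset(wrapped), datasets.Dataset)
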
11 changes: 5 additions & 6 deletions src/imitation/data/serialize.py
@@ -21,23 +21,22 @@ def save(path: AnyPath, trajectories: Sequence[Trajectory]) -> None:
         trajectories: The trajectories to save.
     """
     p = util.parse_path(path)
-    d = datasets.Dataset.from_dict(huggingface_utils.trajectories_to_dict(trajectories))
-    d.save_to_disk(p)
+    huggingface_utils.trajectories_to_dataset(trajectories).save_to_disk(p)
     logging.info(f"Dumped demonstrations to {p}.")
 
 
 def load(path: AnyPath) -> Sequence[Trajectory]:
     """Loads a sequence of trajectories saved by `save()` from `path`."""
     # Interestingly, np.load will just silently load a normal pickle file when you
     # set `allow_pickle=True`. So this call should succeed for both the new compressed
-    # .npz format and the old pickle based format. To tell the difference we need to
+    # .npz format and the old pickle based format. To tell the difference, we need to
     # look at the type of the resulting object. If it's the new compressed format,
-    # it should be a Mapping that we need to decode, whereas if it's the old format
+    # it should be a Mapping that we need to decode, whereas if it's the old format,
     # it's just the sequence of trajectories, and we can return it directly.
 
     if os.path.isdir(path):  # huggingface datasets format
         dataset = datasets.load_from_disk(str(path))
-        if not isinstance(dataset, datasets.Dataset):
+        if not isinstance(dataset, datasets.Dataset):  # pragma: no cover
             raise ValueError(
                 f"Expected to load a `datasets.Dataset` but got {type(dataset)}",
             )
@@ -66,7 +65,7 @@ def load(path: AnyPath) -> Sequence[Trajectory]:
             ]
             return [TrajectoryWithRew(*args) for args in zip(*fields)]
         else:
-            return [Trajectory(*args) for args in zip(*fields)]
+            return [Trajectory(*args) for args in zip(*fields)]  # pragma: no cover
     else:  # pragma: no cover
         raise ValueError(
             f"Expected either an .npz file or a pickled sequence of trajectories; "
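
With the `save()` change above, the on-disk format is an HF dataset directory, which `load()` restores. A round-trip sketch (the path is illustrative):

.. code-block:: python

    from imitation.data import serialize

    serialize.save("/tmp/demos", your_trajectories)  # writes an HF dataset directory
    loaded = serialize.load("/tmp/demos")            # a Sequence[Trajectory]
    assert len(loaded) == len(your_trajectories)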