Skip to content

Commit

Permalink
fix: HuggingFace save_to_disk takes PathLike type which is defined …
Browse files Browse the repository at this point in the history
…as str, bytes or os.PathLike. imitation.util.parse_path always returned pathlib.Path which is not one of these types. This commit converts pathlib.Path to str before calling the HF fn.
  • Loading branch information
iwishiwasaneagle committed Jul 13, 2024
1 parent a8b079c commit f5cb8a4
Show file tree
Hide file tree
Showing 4 changed files with 126 additions and 49 deletions.
2 changes: 1 addition & 1 deletion src/imitation/data/serialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def save(path: AnyPath, trajectories: Sequence[Trajectory]) -> None:
path: Trajectories are saved to this path.
trajectories: The trajectories to save.
"""
p = util.parse_path(path)
p = str(util.parse_path(path))
huggingface_utils.trajectories_to_dataset(trajectories).save_to_disk(p)
logging.info(f"Dumped demonstrations to {p}.")

Expand Down
62 changes: 62 additions & 0 deletions tests/data/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
import gymnasium as gym
import numpy as np
import pytest

from imitation.data import types

SPACES = [
gym.spaces.Discrete(3),
gym.spaces.MultiDiscrete([3, 4]),
gym.spaces.Box(-1, 1, shape=(1,)),
gym.spaces.Box(-1, 1, shape=(2,)),
gym.spaces.Box(-np.inf, np.inf, shape=(2,)),
]
DICT_SPACE = gym.spaces.Dict(
{"a": gym.spaces.Discrete(3), "b": gym.spaces.Box(-1, 1, shape=(2,))},
)
LENGTHS = [0, 1, 2, 10]


@pytest.fixture(params=SPACES)
def act_space(request):
return request.param


@pytest.fixture(params=SPACES + [DICT_SPACE])
def obs_space(request):
return request.param


@pytest.fixture(params=LENGTHS)
def length(request):
return request.param


@pytest.fixture
def trajectory(
obs_space: gym.Space,
act_space: gym.Space,
length: int,
) -> types.Trajectory:
"""Fixture to generate trajectory of length `length` iid sampled from spaces."""
if length == 0:
pytest.skip()

raw_obs = [obs_space.sample() for _ in range(length + 1)]
if isinstance(obs_space, gym.spaces.Dict):
obs: types.Observation = types.DictObs.from_obs_list(raw_obs)
else:
obs = np.array(raw_obs)
acts = np.array([act_space.sample() for _ in range(length)])
infos = np.array([{f"key{i}": i} for i in range(length)])
return types.Trajectory(obs=obs, acts=acts, infos=infos, terminal=True)


@pytest.fixture
def trajectory_rew(trajectory: types.Trajectory) -> types.TrajectoryWithRew:
"""Like `trajectory` but with reward randomly sampled from a Gaussian."""
rews = np.random.randn(len(trajectory))
return types.TrajectoryWithRew(
**types.dataclass_quick_asdict(trajectory),
rews=rews,
)
63 changes: 63 additions & 0 deletions tests/data/test_serialize.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
"""Tests for `imitation.data.serialize`."""

import pathlib

import gymnasium as gym
import numpy as np
import pytest

from imitation.data import serialize, types
from imitation.data.types import DictObs


@pytest.fixture
def data_path(tmp_path):
return tmp_path / "data"


@pytest.mark.parametrize("path_type", [str, pathlib.Path])
def test_save_trajectory(data_path, trajectory, path_type):
if isinstance(trajectory.obs, DictObs):
pytest.skip("serialize.save does not yet support DictObs")

serialize.save(path_type(data_path), [trajectory])
assert data_path.exists()


@pytest.mark.parametrize("path_type", [str, pathlib.Path])
def test_save_trajectory_rew(data_path, trajectory_rew, path_type):
if isinstance(trajectory_rew.obs, DictObs):
pytest.skip("serialize.save does not yet support DictObs")
serialize.save(path_type(data_path), [trajectory_rew])
assert data_path.exists()


def test_save_load_trajectory(data_path, trajectory):
if isinstance(trajectory.obs, DictObs):
pytest.skip("serialize.save does not yet support DictObs")
serialize.save(data_path, [trajectory])

reconstructed = list(serialize.load(data_path))
reconstructedi = reconstructed[0]

assert len(reconstructed) == 1
assert np.allclose(reconstructedi.obs, trajectory.obs)
assert np.allclose(reconstructedi.acts, trajectory.acts)
assert np.allclose(reconstructedi.terminal, trajectory.terminal)
assert not hasattr(reconstructedi, "rews")


@pytest.mark.parametrize("load_fn", [serialize.load, serialize.load_with_rewards])
def test_save_load_trajectory_rew(data_path, trajectory_rew, load_fn):
if isinstance(trajectory_rew.obs, DictObs):
pytest.skip("serialize.save does not yet support DictObs")
serialize.save(data_path, [trajectory_rew])

reconstructed = list(load_fn(data_path))
reconstructedi = reconstructed[0]

assert len(reconstructed) == 1
assert np.allclose(reconstructedi.obs, trajectory_rew.obs)
assert np.allclose(reconstructedi.acts, trajectory_rew.acts)
assert np.allclose(reconstructedi.terminal, trajectory_rew.terminal)
assert np.allclose(reconstructedi.rews, trajectory_rew.rews)
48 changes: 0 additions & 48 deletions tests/data/test_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,58 +15,13 @@
from imitation.data import serialize, types
from imitation.util import util

SPACES = [
gym.spaces.Discrete(3),
gym.spaces.MultiDiscrete([3, 4]),
gym.spaces.Box(-1, 1, shape=(1,)),
gym.spaces.Box(-1, 1, shape=(2,)),
gym.spaces.Box(-np.inf, np.inf, shape=(2,)),
]
DICT_SPACE = gym.spaces.Dict(
{"a": gym.spaces.Discrete(3), "b": gym.spaces.Box(-1, 1, shape=(2,))},
)

OBS_SPACES = SPACES + [DICT_SPACE]
ACT_SPACES = SPACES
LENGTHS = [0, 1, 2, 10]


def _check_1d_shape(fn: Callable[[np.ndarray], Any], length: int, expected_msg: str):
for shape in [(), (length, 1), (length, 2), (length - 1,), (length + 1,)]:
with pytest.raises(ValueError, match=expected_msg):
fn(np.zeros(shape))


@pytest.fixture
def trajectory(
obs_space: gym.Space,
act_space: gym.Space,
length: int,
) -> types.Trajectory:
"""Fixture to generate trajectory of length `length` iid sampled from spaces."""
if length == 0:
pytest.skip()

raw_obs = [obs_space.sample() for _ in range(length + 1)]
if isinstance(obs_space, gym.spaces.Dict):
obs: types.Observation = types.DictObs.from_obs_list(raw_obs)
else:
obs = np.array(raw_obs)
acts = np.array([act_space.sample() for _ in range(length)])
infos = np.array([{f"key{i}": i} for i in range(length)])
return types.Trajectory(obs=obs, acts=acts, infos=infos, terminal=True)


@pytest.fixture
def trajectory_rew(trajectory: types.Trajectory) -> types.TrajectoryWithRew:
"""Like `trajectory` but with reward randomly sampled from a Gaussian."""
rews = np.random.randn(len(trajectory))
return types.TrajectoryWithRew(
**types.dataclass_quick_asdict(trajectory),
rews=rews,
)


@pytest.fixture
def transitions_min(
obs_space: gym.Space,
Expand Down Expand Up @@ -134,9 +89,6 @@ def pushd(dir_path):
os.chdir(orig_dir)


@pytest.mark.parametrize("obs_space", OBS_SPACES)
@pytest.mark.parametrize("act_space", ACT_SPACES)
@pytest.mark.parametrize("length", LENGTHS)
class TestData:
"""Tests of imitation.util.data.
Expand Down

0 comments on commit f5cb8a4

Please sign in to comment.