
Serialize Dataset Save Not Working #851

Open
alexpalms opened this issue May 21, 2024 · 1 comment
Labels
bug Something isn't working

Comments

@alexpalms

Bug description

Hi all, while trying to save a list of trajectories locally to my filesystem, I discovered that the save method of the serialize module does not work as expected, at least as presented in the docs.

When calling serialize.save("my_path", my_trajectories), the code fails with the following trace:

Downloading a pretrained expert.
Sampling expert transitions.
Traceback (most recent call last):
  File "/home/alexpalms/imitation_learning/imitation_quickstart.py", line 42, in <module>
    transitions = sample_expert_transitions()
  File "/home/alexpalms/imitation_learning/imitation_quickstart.py", line 38, in sample_expert_transitions
    serialize.save(rollouts_path, rollouts)
  File "/home/alexpalms/miniconda3/envs/imitation/lib/python3.9/site-packages/imitation/data/serialize.py", line 25, in save
    huggingface_utils.trajectories_to_dataset(trajectories).save_to_disk(p)
  File "/home/alexpalms/miniconda3/envs/imitation/lib/python3.9/site-packages/datasets/arrow_dataset.py", line 1515, in save_to_disk
    fs, _ = url_to_fs(dataset_path, **(storage_options or {}))
  File "/home/alexpalms/miniconda3/envs/imitation/lib/python3.9/site-packages/fsspec/core.py", line 383, in url_to_fs
    chain = _un_chain(url, kwargs)
  File "/home/alexpalms/miniconda3/envs/imitation/lib/python3.9/site-packages/fsspec/core.py", line 323, in _un_chain
    if "::" in path

I also found a fix (or rather a workaround), but since I just came across this library I am not sure it is the best way to handle it, as it might hide a compatibility issue with the HF datasets library. I cast the path to a string at this line of serialize.py, so from this:

huggingface_utils.trajectories_to_dataset(trajectories).save_to_disk(p)

it became this:

huggingface_utils.trajectories_to_dataset(trajectories).save_to_disk(str(p))
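
Alternatively, without patching the library, the same effect can be obtained on the caller side by casting the path before the call. A minimal sketch (assuming rollouts is the trajectory list produced in the reproduction script below):

from pathlib import Path
from imitation.data import serialize

rollouts_path = Path("./rollouts_path")
# Caller-side workaround: pass a plain string so the downstream fsspec
# path parsing never receives a Path object.
serialize.save(str(rollouts_path), rollouts)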

Steps to reproduce

To reproduce the problem, you can execute the following code, a slight customization of your quickstart example:

import numpy as np
from imitation.data import rollout, serialize
from imitation.data.wrappers import RolloutInfoWrapper
from imitation.policies.serialize import load_policy
from imitation.util.util import make_vec_env
from pathlib import Path

rng = np.random.default_rng(0)
env = make_vec_env(
    "seals:seals/CartPole-v0",
    rng=rng,
    post_wrappers=[lambda env, _: RolloutInfoWrapper(env)],  # for computing rollouts
)


def download_expert():
    print("Downloading a pretrained expert.")
    expert = load_policy(
        "ppo-huggingface",
        organization="HumanCompatibleAI",
        env_name="seals-CartPole-v0",
        venv=env,
    )
    return expert


def sample_expert_transitions():
    expert = download_expert()

    print("Sampling expert transitions.")
    rollouts = rollout.rollout(
        expert,
        env,
        rollout.make_sample_until(min_timesteps=None, min_episodes=2),
        rng=rng,
    )
    rollouts_path = Path("./rollouts_path")
    serialize.save(rollouts_path, rollouts)
    return rollout.flatten_trajectories(rollouts)


transitions = sample_expert_transitions()
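
To check that the workaround round-trips, the saved trajectories can be reloaded with serialize.load, the counterpart of serialize.save. A sketch (as noted in the follow-up comment below, load already casts its path to str internally):

# Reload the trajectories saved by sample_expert_transitions above.
loaded = serialize.load("./rollouts_path")
print(f"Reloaded {len(loaded)} trajectories.")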

Environment

  • Operating system and version: Linux Mint 20.3 Una
  • Python version: 3.9.19
  • Output of pip freeze --all:
absl-py==2.1.0
aiohttp==3.9.5
aiosignal==1.3.1
alembic==1.13.1
async-timeout==4.0.3
attrs==23.2.0
certifi==2024.2.2
charset-normalizer==3.3.2
cloudpickle==3.0.0
colorama==0.4.6
colorlog==6.8.2
contourpy==1.2.1
cycler==0.12.1
datasets==2.19.1
dill==0.3.8
docopt==0.6.2
Farama-Notifications==0.0.4
filelock==3.14.0
fonttools==4.51.0
frozenlist==1.4.1
fsspec==2024.3.1
gitdb==4.0.11
GitPython==3.1.43
greenlet==3.0.3
grpcio==1.63.0
gymnasium==0.29.1
huggingface-hub==0.23.0
huggingface-sb3==3.0
idna==3.7
imitation==1.0.0
importlib_metadata==7.1.0
importlib_resources==6.4.0
Jinja2==3.1.4
joblib==1.4.2
jsonpickle==3.0.4
kiwisolver==1.4.5
Mako==1.3.5
Markdown==3.6
markdown-it-py==3.0.0
MarkupSafe==2.1.5
matplotlib==3.8.4
mdurl==0.1.2
mpmath==1.3.0
multidict==6.0.5
multiprocess==0.70.16
munch==4.0.0
networkx==3.2.1
numpy==1.26.4
nvidia-cublas-cu12==12.1.3.1
nvidia-cuda-cupti-cu12==12.1.105
nvidia-cuda-nvrtc-cu12==12.1.105
nvidia-cuda-runtime-cu12==12.1.105
nvidia-cudnn-cu12==8.9.2.26
nvidia-cufft-cu12==11.0.2.54
nvidia-curand-cu12==10.3.2.106
nvidia-cusolver-cu12==11.4.5.107
nvidia-cusparse-cu12==12.1.0.106
nvidia-nccl-cu12==2.20.5
nvidia-nvjitlink-cu12==12.4.127
nvidia-nvtx-cu12==12.1.105
optuna==3.6.1
packaging==24.0
pandas==2.2.2
pillow==10.3.0
pip==24.0
protobuf==5.26.1
py-cpuinfo==9.0.0
pyarrow==16.1.0
pyarrow-hotfix==0.6
pygame==2.5.2
Pygments==2.18.0
pyparsing==3.1.2
PyQt5==5.15.10
PyQt5-Qt5==5.15.2
PyQt5-sip==12.13.0
python-dateutil==2.9.0.post0
pytz==2024.1
PyYAML==6.0.1
requests==2.32.1
rich==13.7.1
sacred==0.8.5
scikit-learn==1.4.2
scipy==1.13.0
seals==0.2.1
setuptools==69.5.1
six==1.16.0
smmap==5.0.1
SQLAlchemy==2.0.30
stable-baselines3==2.1.0
sympy==1.12
tensorboard==2.16.2
tensorboard-data-server==0.7.2
threadpoolctl==3.5.0
torch==2.3.0
tqdm==4.66.4
triton==2.3.0
typing_extensions==4.11.0
tzdata==2024.1
urllib3==2.2.1
wasabi==1.1.2
Werkzeug==3.0.3
wheel==0.43.0
wrapt==1.16.0
xxhash==3.4.1
yarl==1.9.4
zipp==3.18.1

Let me know if you would like me to create a PR, and whether you have suggestions for a better way to handle this.

Looking forward to receiving your feedback.

@alexpalms alexpalms added the bug Something isn't working label May 21, 2024
@alexpalms
Author

As an additional note, I just noticed that you already do the same for the load function here:

dataset = datasets.load_from_disk(str(path))

So I suppose my fix is acceptable? Let me know if you would like me to submit a PR for it.
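
For reference, the proposed change is the one-line diff below against imitation/data/serialize.py, mirroring the str() cast that load already uses:

-    huggingface_utils.trajectories_to_dataset(trajectories).save_to_disk(p)
+    huggingface_utils.trajectories_to_dataset(trajectories).save_to_disk(str(p))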
