-
Notifications
You must be signed in to change notification settings - Fork 1.3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Prerequisites for Drake Gym: Import OpenAI Gym; fake-vendor stable_ba…
…selines3 (#18440)
- Loading branch information
1 parent
b110da2
commit b6d7485
Showing
12 changed files
with
240 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
# -*- python -*- | ||
|
||
# This file exists to make our directory into a Bazel package, so that our | ||
# neighboring *.bzl file can be loaded elsewhere. | ||
|
||
load("//tools/lint:lint.bzl", "add_lint_tests") | ||
|
||
add_lint_tests() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
gym | ||
=== | ||
|
||
This imports the OpenAI Gym RL tool into our build. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
# -*- python -*- | ||
|
||
load("@drake//tools/skylark:py.bzl", "py_library") | ||
|
||
licenses(["notice"]) # MIT | ||
|
||
py_library( | ||
name = "gym_py", | ||
srcs = glob(["gym/**/*.py"]), | ||
imports = ["gym"], | ||
visibility = ["//visibility:public"], | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
# -*- python -*- | ||
|
||
load("@drake//tools/workspace:github.bzl", "github_archive") | ||
|
||
def gym_py_repository( | ||
name, | ||
mirrors = None): | ||
github_archive( | ||
name = name, | ||
repository = "openai/gym", | ||
commit = "v0.21.0", | ||
sha256 = "0efc4ca01fa0d0cd10391b37db0c52e09084fb7a3fd04cdc68e08081acbf4418", # noqa | ||
build_file = ":package.BUILD.bazel", | ||
mirrors = mirrors, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
# -*- python -*- | ||
|
||
# This file exists to make our directory into a Bazel package, so that our | ||
# neighboring *.bzl file can be loaded elsewhere. | ||
|
||
load("//tools/lint:lint.bzl", "add_lint_tests") | ||
load("//tools/skylark:drake_py.bzl", "drake_py_unittest") | ||
|
||
drake_py_unittest( | ||
name = "stable_baselines3_internal_test", | ||
deps = ["@stable_baselines3_internal//:stable_baselines3"], | ||
) | ||
|
||
add_lint_tests() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
stable_baselines3_internal | ||
========================== | ||
|
||
A local import of `stable_baselines3` that *does not* have its dependencies | ||
(such as `pytorch`). As such, code using this module will have a | ||
`ModuleNotFoundError` error if it attempts to use any functions that depend on | ||
those modules. | ||
|
||
The resulting `stable_baselines3` workalike is available via the bazel:: | ||
|
||
deps = [ | ||
"@stable_baselines3_internal//:stable_baselines3", | ||
], | ||
|
||
and then, in python:: | ||
|
||
import stable_baselines3 | ||
|
||
Code that wants to import the most featureful version of `stable_baselines3` | ||
available (i.e., to use a locally installed `stable_baselines3` if available | ||
and the Drake-provided workalike otherwise) can rely on the presence of the | ||
string `"drake_internal"` in the `__version__` attribute. For instance:: | ||
|
||
sb3_is_fully_featured = False | ||
try: | ||
import stable_baselines3 | ||
if "drake_internal" not in stable_baselines3.__version__: | ||
sb3_is_fully_featured = True | ||
else: | ||
print("stable_baselines3 found, but was drake internal") | ||
except ImportError: | ||
print("stable_baselines3 not found") |
11 changes: 11 additions & 0 deletions
11
tools/workspace/stable_baselines3_internal/connection.patch
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
diff --git a/stable_baselines3/common/vec_env/subproc_vec_env.py b/stable_baselines3/common/vec_env/subproc_vec_env.py | ||
index f723c71..43fa394 100644 | ||
--- stable_baselines3/common/vec_env/subproc_vec_env.py | ||
+++ stable_baselines3/common/vec_env/subproc_vec_env.py | ||
@@ -1,5 +1,6 @@ | ||
import multiprocessing as mp | ||
from collections import OrderedDict | ||
+from multiprocessing import connection | ||
from typing import Any, Callable, List, Optional, Sequence, Tuple, Type, Union | ||
|
||
import gym |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,96 @@ | ||
diff --git a/stable_baselines3/__init__.py b/stable_baselines3/__init__.py | ||
index d73f5f0..5797dba 100644 | ||
--- stable_baselines3/__init__.py | ||
+++ stable_baselines3/__init__.py | ||
@@ -1,13 +1,19 @@ | ||
import os | ||
|
||
-from stable_baselines3.a2c import A2C | ||
-from stable_baselines3.common.utils import get_system_info | ||
-from stable_baselines3.ddpg import DDPG | ||
-from stable_baselines3.dqn import DQN | ||
-from stable_baselines3.her.her_replay_buffer import HerReplayBuffer | ||
-from stable_baselines3.ppo import PPO | ||
-from stable_baselines3.sac import SAC | ||
-from stable_baselines3.td3 import TD3 | ||
+def not_loaded(name): | ||
+ raise ImportError(f"{name} not in Drake's internal stable_baselines3") | ||
+ | ||
+def make_not_loaded(name): | ||
+ return lambda *args, **kwargs: not_loaded(name) | ||
+ | ||
+A2C = make_not_loaded("A2C") | ||
+get_system_info = make_not_loaded("get_system_info") | ||
+DDPG = make_not_loaded("DDPG") | ||
+DQN = make_not_loaded("DQN") | ||
+HerReplayBuffer = make_not_loaded("HerReplayBuffer") | ||
+PPO = make_not_loaded("PPO") | ||
+SAC = make_not_loaded("SAC") | ||
+TD3 = make_not_loaded("TD3") | ||
|
||
# Read version from file | ||
version_file = os.path.join(os.path.dirname(__file__), "version.txt") | ||
diff --git a/stable_baselines3/common/env_checker.py b/stable_baselines3/common/env_checker.py | ||
index 3b2c502..392d9df 100644 | ||
--- stable_baselines3/common/env_checker.py | ||
+++ stable_baselines3/common/env_checker.py | ||
@@ -6,7 +6,13 @@ import numpy as np | ||
from gym import spaces | ||
|
||
from stable_baselines3.common.preprocessing import is_image_space_channels_first | ||
-from stable_baselines3.common.vec_env import DummyVecEnv, VecCheckNan | ||
+ | ||
+torch_available = False | ||
+try: | ||
+ from stable_baselines3.common.vec_env import DummyVecEnv, VecCheckNan | ||
+ torch_available = True | ||
+except ImportError: | ||
+ print("Torch not available; skipping vector env tests.") | ||
|
||
|
||
def _is_numpy_array_space(space: spaces.Space) -> bool: | ||
@@ -87,6 +93,8 @@ def _check_unsupported_spaces(env: gym.Env, observation_space: spaces.Space, act | ||
|
||
def _check_nan(env: gym.Env) -> None: | ||
"""Check for Inf and NaN using the VecWrapper.""" | ||
+ if not torch_available: | ||
+ return | ||
vec_env = VecCheckNan(DummyVecEnv([lambda: env])) | ||
for _ in range(10): | ||
action = np.array([env.action_space.sample()]) | ||
diff --git a/stable_baselines3/common/preprocessing.py b/stable_baselines3/common/preprocessing.py | ||
index 01422aa..0603891 100644 | ||
--- stable_baselines3/common/preprocessing.py | ||
+++ stable_baselines3/common/preprocessing.py | ||
@@ -1,10 +1,15 @@ | ||
+from ast import Import | ||
import warnings | ||
from typing import Dict, Tuple, Union | ||
|
||
import numpy as np | ||
-import torch as th | ||
from gym import spaces | ||
-from torch.nn import functional as F | ||
+ | ||
+try: | ||
+ import torch as th | ||
+ from torch.nn import functional as F | ||
+except ImportError: | ||
+ pass # Let any function calls that use torch fail naturally. | ||
|
||
|
||
def is_image_space_channels_first(observation_space: spaces.Box) -> bool: | ||
@@ -83,10 +88,10 @@ def maybe_transpose(observation: np.ndarray, observation_space: spaces.Space) -> | ||
|
||
|
||
def preprocess_obs( | ||
- obs: th.Tensor, | ||
+ obs, | ||
observation_space: spaces.Space, | ||
normalize_images: bool = True, | ||
-) -> Union[th.Tensor, Dict[str, th.Tensor]]: | ||
+): | ||
""" | ||
Preprocess observation to be to a neural network. | ||
For images, it normalizes the values by dividing them by 255 (to have values in [0, 1]) | ||
|
14 changes: 14 additions & 0 deletions
14
tools/workspace/stable_baselines3_internal/package.BUILD.bazel
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
# -*- python -*- | ||
|
||
load("@drake//tools/skylark:py.bzl", "py_library") | ||
|
||
licenses(["notice"]) # MIT | ||
|
||
# Internal vendoring of a few files of stable_baselines3; see README.md | ||
py_library( | ||
name = "stable_baselines3", | ||
srcs = glob(["**/*.py"]), | ||
deps = ["@gym_py"], | ||
data = glob(["**/version.txt"]), | ||
visibility = ["//visibility:public"], | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
# -*- python -*- | ||
|
||
load("@drake//tools/workspace:github.bzl", "github_archive") | ||
|
||
def stable_baselines3_internal_repository( | ||
name, | ||
mirrors = None): | ||
github_archive( | ||
name = name, | ||
repository = "DLR-RM/stable-baselines3", | ||
commit = "v1.6.0", | ||
sha256 = "f6642fb002adf7ce10087319ea8e9a331d95d26f6558067339f26c84fc588bb6", # noqa | ||
build_file = ":package.BUILD.bazel", | ||
patches = [ | ||
":connection.patch", | ||
":no_torch.patch", | ||
], | ||
mirrors = mirrors, | ||
) |
9 changes: 9 additions & 0 deletions
9
tools/workspace/stable_baselines3_internal/test/stable_baselines3_internal_test.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
import unittest | ||
|
||
|
||
class StableBaselines3InternalTest(unittest.TestCase): | ||
"""Test that the methods that we use in drake are able to be loaded | ||
in spite of Drake not having all of SB3's dependencies.""" | ||
def test_import(self): | ||
from stable_baselines3.common.env_checker import check_env | ||
help(check_env) |