Skip to content

Commit

Permalink
Prerequisites for Drake Gym: Import OpenAI Gym; fake-vendor stable_ba…
Browse files Browse the repository at this point in the history
…selines3 (#18440)
  • Loading branch information
ggould-tri authored Dec 21, 2022
1 parent b110da2 commit b6d7485
Show file tree
Hide file tree
Showing 12 changed files with 240 additions and 0 deletions.
6 changes: 6 additions & 0 deletions tools/workspace/default.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ load("@drake//tools/workspace/glx:repository.bzl", "glx_repository")
load("@drake//tools/workspace/googlebenchmark:repository.bzl", "googlebenchmark_repository") # noqa
load("@drake//tools/workspace/gtest:repository.bzl", "gtest_repository")
load("@drake//tools/workspace/gurobi:repository.bzl", "gurobi_repository")
load("@drake//tools/workspace/gym_py:repository.bzl", "gym_py_repository")
load("@drake//tools/workspace/gz_math_internal:repository.bzl", "gz_math_internal_repository") # noqa
load("@drake//tools/workspace/gz_utils_internal:repository.bzl", "gz_utils_internal_repository") # noqa
load("@drake//tools/workspace/ibex:repository.bzl", "ibex_repository")
Expand Down Expand Up @@ -84,6 +85,7 @@ load("@drake//tools/workspace/scs:repository.bzl", "scs_repository")
load("@drake//tools/workspace/sdformat_internal:repository.bzl", "sdformat_internal_repository") # noqa
load("@drake//tools/workspace/snopt:repository.bzl", "snopt_repository")
load("@drake//tools/workspace/spdlog:repository.bzl", "spdlog_repository")
load("@drake//tools/workspace/stable_baselines3_internal:repository.bzl", "stable_baselines3_internal_repository") # noqa
load("@drake//tools/workspace/statsjs:repository.bzl", "statsjs_repository")
load("@drake//tools/workspace/stduuid:repository.bzl", "stduuid_repository")
load("@drake//tools/workspace/styleguide:repository.bzl", "styleguide_repository") # noqa
Expand Down Expand Up @@ -191,6 +193,8 @@ def add_default_repositories(excludes = [], mirrors = DEFAULT_MIRRORS):
gz_math_internal_repository(name = "gz_math_internal", mirrors = mirrors) # noqa
if "gz_utils_internal" not in excludes:
gz_utils_internal_repository(name = "gz_utils_internal", mirrors = mirrors) # noqa
if "gym_py" not in excludes:
gym_py_repository(name = "gym_py", mirrors = mirrors)
if "ibex" not in excludes:
# N.B. This repository is deprecated for removal on 2023-02-01.
# For details see https://github.com/RobotLocomotion/drake/pull/18156.
Expand Down Expand Up @@ -289,6 +293,8 @@ def add_default_repositories(excludes = [], mirrors = DEFAULT_MIRRORS):
snopt_repository(name = "snopt")
if "spdlog" not in excludes:
spdlog_repository(name = "spdlog", mirrors = mirrors)
if "stable_baselines3_internal" not in excludes:
stable_baselines3_internal_repository(name = "stable_baselines3_internal", mirrors = mirrors) # noqa
if "statsjs" not in excludes:
statsjs_repository(name = "statsjs", mirrors = mirrors)
if "stduuid" not in excludes:
Expand Down
8 changes: 8 additions & 0 deletions tools/workspace/gym_py/BUILD.bazel
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# -*- python -*-

# This file exists to make our directory into a Bazel package, so that our
# neighboring *.bzl file can be loaded elsewhere.

load("//tools/lint:lint.bzl", "add_lint_tests")

add_lint_tests()
4 changes: 4 additions & 0 deletions tools/workspace/gym_py/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
gym
===

This imports the OpenAI Gym reinforcement learning (RL) toolkit into our build.
12 changes: 12 additions & 0 deletions tools/workspace/gym_py/package.BUILD.bazel
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
# -*- python -*-

load("@drake//tools/skylark:py.bzl", "py_library")

licenses(["notice"]) # MIT

# Exposes every Python source file under the archive's gym/ directory as a
# single, publicly visible library target.
py_library(
    name = "gym_py",
    srcs = glob(["gym/**/*.py"]),
    imports = ["gym"],
    visibility = ["//visibility:public"],
)
15 changes: 15 additions & 0 deletions tools/workspace/gym_py/repository.bzl
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# -*- python -*-

load("@drake//tools/workspace:github.bzl", "github_archive")

def gym_py_repository(
        name,
        mirrors = None):
    """Declares the OpenAI Gym sources as an external repository.

    The archive is fetched from the openai/gym GitHub project through
    github_archive() and is built using the neighboring package.BUILD.bazel.

    Args:
        name: Name to give the resulting external repository.
        mirrors: Passed through unchanged to github_archive().
    """
    github_archive(
        name = name,
        repository = "openai/gym",
        # Upstream version that is pulled in, together with its sha256.
        commit = "v0.21.0",
        sha256 = "0efc4ca01fa0d0cd10391b37db0c52e09084fb7a3fd04cdc68e08081acbf4418",  # noqa
        build_file = ":package.BUILD.bazel",
        mirrors = mirrors,
    )
14 changes: 14 additions & 0 deletions tools/workspace/stable_baselines3_internal/BUILD.bazel
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# -*- python -*-

# This file makes our directory into a Bazel package, so that our neighboring
# *.bzl file can be loaded elsewhere. It also declares a smoke test for the
# vendored package.

load("//tools/lint:lint.bzl", "add_lint_tests")
load("//tools/skylark:drake_py.bzl", "drake_py_unittest")

# Unit test that depends on the vendored stable_baselines3 library target, so
# that it runs against Drake's patched copy of the package.
drake_py_unittest(
    name = "stable_baselines3_internal_test",
    deps = ["@stable_baselines3_internal//:stable_baselines3"],
)

add_lint_tests()
32 changes: 32 additions & 0 deletions tools/workspace/stable_baselines3_internal/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
stable_baselines3_internal
==========================

A local import of `stable_baselines3` that *does not* have its dependencies
(such as `pytorch`). As such, code using this module will fail with an
`ImportError` (for example, a `ModuleNotFoundError`) if it attempts to use any
functions that depend on those modules.

The resulting `stable_baselines3` workalike is available via the Bazel dependency::

deps = [
"@stable_baselines3_internal//:stable_baselines3",
],

and then, in python::

import stable_baselines3

Code that wants to import the most featureful version of `stable_baselines3`
available (i.e., to use a locally installed `stable_baselines3` if available
and the Drake-provided workalike otherwise) can rely on the presence of the
string `"drake_internal"` in the `__version__` attribute. For instance::

sb3_is_fully_featured = False
try:
import stable_baselines3
if "drake_internal" not in stable_baselines3.__version__:
sb3_is_fully_featured = True
else:
print("stable_baselines3 found, but was drake internal")
except ImportError:
print("stable_baselines3 not found")
11 changes: 11 additions & 0 deletions tools/workspace/stable_baselines3_internal/connection.patch
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
diff --git a/stable_baselines3/common/vec_env/subproc_vec_env.py b/stable_baselines3/common/vec_env/subproc_vec_env.py
index f723c71..43fa394 100644
--- stable_baselines3/common/vec_env/subproc_vec_env.py
+++ stable_baselines3/common/vec_env/subproc_vec_env.py
@@ -1,5 +1,6 @@
import multiprocessing as mp
from collections import OrderedDict
+from multiprocessing import connection
from typing import Any, Callable, List, Optional, Sequence, Tuple, Type, Union

import gym
96 changes: 96 additions & 0 deletions tools/workspace/stable_baselines3_internal/no_torch.patch
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
diff --git a/stable_baselines3/__init__.py b/stable_baselines3/__init__.py
index d73f5f0..5797dba 100644
--- stable_baselines3/__init__.py
+++ stable_baselines3/__init__.py
@@ -1,13 +1,19 @@
import os

-from stable_baselines3.a2c import A2C
-from stable_baselines3.common.utils import get_system_info
-from stable_baselines3.ddpg import DDPG
-from stable_baselines3.dqn import DQN
-from stable_baselines3.her.her_replay_buffer import HerReplayBuffer
-from stable_baselines3.ppo import PPO
-from stable_baselines3.sac import SAC
-from stable_baselines3.td3 import TD3
+def not_loaded(name):
+ raise ImportError(f"{name} not in Drake's internal stable_baselines3")
+
+def make_not_loaded(name):
+ return lambda *args, **kwargs: not_loaded(name)
+
+A2C = make_not_loaded("A2C")
+get_system_info = make_not_loaded("get_system_info")
+DDPG = make_not_loaded("DDPG")
+DQN = make_not_loaded("DQN")
+HerReplayBuffer = make_not_loaded("HerReplayBuffer")
+PPO = make_not_loaded("PPO")
+SAC = make_not_loaded("SAC")
+TD3 = make_not_loaded("TD3")

# Read version from file
version_file = os.path.join(os.path.dirname(__file__), "version.txt")
diff --git a/stable_baselines3/common/env_checker.py b/stable_baselines3/common/env_checker.py
index 3b2c502..392d9df 100644
--- stable_baselines3/common/env_checker.py
+++ stable_baselines3/common/env_checker.py
@@ -6,7 +6,13 @@ import numpy as np
from gym import spaces

from stable_baselines3.common.preprocessing import is_image_space_channels_first
-from stable_baselines3.common.vec_env import DummyVecEnv, VecCheckNan
+
+torch_available = False
+try:
+ from stable_baselines3.common.vec_env import DummyVecEnv, VecCheckNan
+ torch_available = True
+except ImportError:
+ print("Torch not available; skipping vector env tests.")


def _is_numpy_array_space(space: spaces.Space) -> bool:
@@ -87,6 +93,8 @@ def _check_unsupported_spaces(env: gym.Env, observation_space: spaces.Space, act

def _check_nan(env: gym.Env) -> None:
"""Check for Inf and NaN using the VecWrapper."""
+ if not torch_available:
+ return
vec_env = VecCheckNan(DummyVecEnv([lambda: env]))
for _ in range(10):
action = np.array([env.action_space.sample()])
diff --git a/stable_baselines3/common/preprocessing.py b/stable_baselines3/common/preprocessing.py
index 01422aa..0603891 100644
--- stable_baselines3/common/preprocessing.py
+++ stable_baselines3/common/preprocessing.py
@@ -1,10 +1,15 @@
+from ast import Import
import warnings
from typing import Dict, Tuple, Union

import numpy as np
-import torch as th
from gym import spaces
-from torch.nn import functional as F
+
+try:
+ import torch as th
+ from torch.nn import functional as F
+except ImportError:
+ pass # Let any function calls that use torch fail naturally.


def is_image_space_channels_first(observation_space: spaces.Box) -> bool:
@@ -83,10 +88,10 @@ def maybe_transpose(observation: np.ndarray, observation_space: spaces.Space) ->


def preprocess_obs(
- obs: th.Tensor,
+ obs,
observation_space: spaces.Space,
normalize_images: bool = True,
-) -> Union[th.Tensor, Dict[str, th.Tensor]]:
+):
"""
Preprocess observation to be to a neural network.
For images, it normalizes the values by dividing them by 255 (to have values in [0, 1])

14 changes: 14 additions & 0 deletions tools/workspace/stable_baselines3_internal/package.BUILD.bazel
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# -*- python -*-

load("@drake//tools/skylark:py.bzl", "py_library")

licenses(["notice"]) # MIT

# Internal vendoring of a few files of stable_baselines3; see README.md
py_library(
    name = "stable_baselines3",
    srcs = glob(["**/*.py"]),
    # The vendored sources `import gym`.
    deps = ["@gym_py"],
    # stable_baselines3/__init__.py reads version.txt from its own directory
    # at import time, so the file must be shipped as runtime data.
    data = glob(["**/version.txt"]),
    visibility = ["//visibility:public"],
)
19 changes: 19 additions & 0 deletions tools/workspace/stable_baselines3_internal/repository.bzl
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# -*- python -*-

load("@drake//tools/workspace:github.bzl", "github_archive")

def stable_baselines3_internal_repository(
        name,
        mirrors = None):
    """Declares Drake's patched copy of stable_baselines3 as an external
    repository.

    The archive is fetched from the DLR-RM/stable-baselines3 GitHub project
    through github_archive(), has the local patches listed below applied,
    and is built using the neighboring package.BUILD.bazel.

    Args:
        name: Name to give the resulting external repository.
        mirrors: Passed through unchanged to github_archive().
    """
    github_archive(
        name = name,
        repository = "DLR-RM/stable-baselines3",
        # Upstream version that is pulled in, together with its sha256.
        commit = "v1.6.0",
        sha256 = "f6642fb002adf7ce10087319ea8e9a331d95d26f6558067339f26c84fc588bb6",  # noqa
        build_file = ":package.BUILD.bazel",
        patches = [
            # Adds an explicit `from multiprocessing import connection` to
            # subproc_vec_env.py.
            ":connection.patch",
            # Makes the torch-dependent imports optional, so that the package
            # can be imported without torch installed.
            ":no_torch.patch",
        ],
        mirrors = mirrors,
    )
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
import unittest


class StableBaselines3InternalTest(unittest.TestCase):
"""Test that the methods that we use in drake are able to be loaded
in spite of Drake not having all of SB3's dependencies."""
def test_import(self):
from stable_baselines3.common.env_checker import check_env
help(check_env)

0 comments on commit b6d7485

Please sign in to comment.