From 4479814a4ac1a4fcd46539adf4ebe224467aa99d Mon Sep 17 00:00:00 2001 From: Maximilian Ernestus Date: Mon, 11 Dec 2023 16:20:17 +0100 Subject: [PATCH 1/3] Remove FloatReward. Fixes #794 --- tests/algorithms/conftest.py | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/tests/algorithms/conftest.py b/tests/algorithms/conftest.py index a453f047d..d28abd823 100644 --- a/tests/algorithms/conftest.py +++ b/tests/algorithms/conftest.py @@ -113,20 +113,10 @@ def pendulum_single_venv(rng) -> VecEnv: ) -# TODO(GH#794): Remove after https://github.com/DLR-RM/stable-baselines3/pull/1676 -# merged and released. -class FloatReward(gym.RewardWrapper): - """Typecasts reward to a float.""" - - def reward(self, reward): - return float(reward) - - @pytest.fixture def multi_obs_venv() -> VecEnv: def make_env(): env = envs.SimpleMultiObsEnv(channel_last=False) - env = FloatReward(env) return RolloutInfoWrapper(env) return DummyVecEnv([make_env, make_env]) From fbcb4064f125d0cfc788cbfbfd472580aa39aa68 Mon Sep 17 00:00:00 2001 From: Maximilian Ernestus Date: Mon, 11 Dec 2023 16:23:25 +0100 Subject: [PATCH 2/3] Bump SB3 version to ensure we have the bug-fix that makes the FloatReward unneeded. --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 1c2c85af6..1aa407456 100644 --- a/setup.py +++ b/setup.py @@ -203,7 +203,7 @@ def get_local_version(version: "ScmVersion", time_format="%Y%m%d") -> str: "rich", "scikit-learn>=0.21.2", "seals~=0.2.1", - "stable-baselines3~=2.0", + "stable-baselines3~=2.2.1", "sacred>=0.8.4", "tensorboard>=1.14", "huggingface_sb3~=3.0", From a55ff9e823d1604f7af3e75102f37bc4d1827b96 Mon Sep 17 00:00:00 2001 From: Maximilian Ernestus Date: Fri, 15 Dec 2023 11:55:06 +0100 Subject: [PATCH 3/3] Remove unused import. --- tests/algorithms/conftest.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/algorithms/conftest.py b/tests/algorithms/conftest.py index d28abd823..4201a26ed 100644 --- a/tests/algorithms/conftest.py +++ b/tests/algorithms/conftest.py @@ -1,7 +1,6 @@ """Fixtures common across algorithm tests.""" from typing import Sequence -import gymnasium as gym import pytest from stable_baselines3.common import envs from stable_baselines3.common.policies import BasePolicy