
[Bug][RLlib] "Samples must consistently provide or omit max_seq_len" with PolicyServerInput/PolicyClient and RNNs #23639

Closed
Fabien-Couthouis opened this issue Mar 31, 2022 · 3 comments · Fixed by #23740
Labels
bug Something that is supposed to be working; but isn't P1 Issue that should be fixed within a few weeks rllib RLlib related issues rllib-client-server Issue related to RLlib's client/server API.


@Fabien-Couthouis
Contributor

Search before asking

  • I searched the issues and found no similar issues.

Ray Component

RLlib

Issue Severity

High: It blocks me from completing my task.

What happened + What you expected to happen

There is a recurring issue when using PolicyClient/PolicyServerInput with an RNN (LSTM) model, at least with IMPALA and R2D2. The error does not occur with PPO:

Traceback (most recent call last):
  File "C:\Users\<username>\Miniconda3\envs\bge-rllib\lib\runpy.py", line 197, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "C:\Users\<username>\Miniconda3\envs\bge-rllib\lib\runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "D:\BGE2\combatbot\scripts\cli.py", line 105, in <module>
    cli(obj=attrdict())
  File "C:\Users\<username>\Miniconda3\envs\bge-rllib\lib\site-packages\click\core.py", line 1137, in __call__
    return self.main(*args, **kwargs)
  File "C:\Users\<username>\Miniconda3\envs\bge-rllib\lib\site-packages\click\core.py", line 1062, in main
    rv = self.invoke(ctx)
  File "C:\Users\<username>\Miniconda3\envs\bge-rllib\lib\site-packages\click\core.py", line 1668, in invoke
    return _process_result(sub_ctx.command.invoke(sub_ctx))
  File "C:\Users\<username>\Miniconda3\envs\bge-rllib\lib\site-packages\click\core.py", line 1668, in invoke
    return _process_result(sub_ctx.command.invoke(sub_ctx))
  File "C:\Users\<username>\Miniconda3\envs\bge-rllib\lib\site-packages\click\core.py", line 1404, in invoke
    return ctx.invoke(self.callback, **ctx.params)
  File "C:\Users\<username>\Miniconda3\envs\bge-rllib\lib\site-packages\click\core.py", line 763, in invoke
    return __callback(*args, **kwargs)
  File "D:\BGE2\combatbot\scripts\distributed_cli_group.py", line 235, in server
    launch_server(address, port)
  File "D:\BGE2\combatbot\training\distributed\launch_server.py", line 157, in launch_server
    result = trainer.train()
  File "<site_packages_path>\ray\tune\trainable.py", line 319, in train
    result = self.step()
  File "<site_packages_path>\ray\rllib\agents\trainer.py", line 984, in step
    raise e
  File "<site_packages_path>\ray\rllib\agents\trainer.py", line 965, in step
    step_attempt_results = self.step_attempt()
  File "<site_packages_path>\ray\rllib\agents\trainer.py", line 1044, in step_attempt
    step_results = self._exec_plan_or_training_iteration_fn()
  File "<site_packages_path>\ray\rllib\agents\trainer.py", line 2032, in _exec_plan_or_training_iteration_fn
    results = next(self.train_exec_impl)
  File "<site_packages_path>\ray\util\iter.py", line 756, in __next__
    return next(self.built_iterator)
  File "<site_packages_path>\ray\util\iter.py", line 783, in apply_foreach
    for item in it:
  File "<site_packages_path>\ray\util\iter.py", line 783, in apply_foreach
    for item in it:
  File "<site_packages_path>\ray\util\iter.py", line 843, in apply_filter
    for item in it:
  File "<site_packages_path>\ray\util\iter.py", line 843, in apply_filter
    for item in it:
  File "<site_packages_path>\ray\util\iter.py", line 783, in apply_foreach
    for item in it:
  File "<site_packages_path>\ray\util\iter.py", line 843, in apply_filter
    for item in it:
  File "<site_packages_path>\ray\util\iter.py", line 1075, in build_union
    item = next(it)
  File "<site_packages_path>\ray\util\iter.py", line 756, in __next__
    return next(self.built_iterator)
  File "<site_packages_path>\ray\util\iter.py", line 783, in apply_foreach
    for item in it:
  File "<site_packages_path>\ray\util\iter.py", line 783, in apply_foreach
    for item in it:
  File "<site_packages_path>\ray\util\iter.py", line 876, in apply_flatten
    for item in it:
  File "<site_packages_path>\ray\util\iter.py", line 783, in apply_foreach
    for item in it:
  File "<site_packages_path>\ray\util\iter.py", line 876, in apply_flatten
    for item in it:
  File "<site_packages_path>\ray\util\iter.py", line 783, in apply_foreach
    for item in it:
  File "<site_packages_path>\ray\util\iter.py", line 783, in apply_foreach
    for item in it:
  File "<site_packages_path>\ray\util\iter.py", line 783, in apply_foreach
    for item in it:
  File "<site_packages_path>\ray\rllib\execution\rollout_ops.py", line 263, in sampler
    yield workers.local_worker().sample()
  File "<site_packages_path>\ray\rllib\evaluation\rollout_worker.py", line 780, in sample
    batch = batches[0].concat_samples(batches) if len(batches) > 1 else \
  File "<site_packages_path>\ray\rllib\policy\sample_batch.py", line 215, in concat_samples
    raise ValueError(
ValueError: Samples must consistently provide or omit max_seq_len

This check in ray\rllib\policy\sample_batch.py is causing the issue:

if (s.max_seq_len is None or max_seq_len is None) and s.max_seq_len != max_seq_len:
    raise ValueError(
        "Samples must consistently provide or omit max_seq_len"
    )

The check fails because max_seq_len is None while s.max_seq_len is set (it holds a valid value).
Removing the check makes the error go away, and training then runs without any other visible problem.
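To make the failure mode concrete, the quoted check can be exercised in isolation. This is a minimal pure-Python sketch, not RLlib code: `FakeBatch` and `check_max_seq_len` are hypothetical stand-ins, and taking the reference value from the first batch is an assumption about how `concat_samples` initializes `max_seq_len`.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class FakeBatch:
    """Hypothetical stand-in for a SampleBatch: only the fields the check reads."""
    count: int
    max_seq_len: Optional[int]


def check_max_seq_len(samples: List[FakeBatch]) -> Optional[int]:
    """Apply the quoted consistency check over a list of batches.

    Assumption: the reference value comes from the first batch. This only
    illustrates the check; it is not RLlib's implementation.
    """
    max_seq_len = samples[0].max_seq_len
    for s in samples:
        if s.count == 0:
            # Empty batches are skipped.
            continue
        # Either both values are None, or neither is.
        if (s.max_seq_len is None or max_seq_len is None) and s.max_seq_len != max_seq_len:
            raise ValueError("Samples must consistently provide or omit max_seq_len")
        if max_seq_len is not None:
            # Keep the largest sequence length seen so far.
            max_seq_len = max(max_seq_len, s.max_seq_len)
    return max_seq_len
```

With `[FakeBatch(8, None), FakeBatch(8, 20)]` the sketch raises the same `ValueError` as in the traceback, matching the situation described above (reference value None, `s.max_seq_len` set), while two batches that both carry `max_seq_len=20` pass and return 20.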

Versions / Dependencies

Windows 10
Python 3.9.7
ray==1.11.0 (likely also affects the master branch)
The issue appears with Ray versions later than 1.9.2.

Reproduction script

I modified the cartpole_server example script to use an LSTM model and to support other algorithms. I also changed the hyperparameters (e.g. smaller batch sizes) for faster reproduction.

cartpole_server.py

#!/usr/bin/env python
"""
Example of running an RLlib policy server, allowing connections from
external environment-running clients (a simple CartPole env in this case).
The server listens on one or more HTTP-speaking ports. See
`cartpole_client.py` in this same directory for how to start any number of
clients (after this server has been started).
This script will not create any actual env to illustrate that RLlib can
run w/o needing an internalized environment.
Setup:
1) Start this server:
    $ python cartpole_server.py --num-workers [n] --[other options]
      Use --help for help.
2) Run n policy clients:
    See `cartpole_client.py` on how to do this.
The `num-workers` setting will allow you to distribute the incoming feed over n
listen sockets (in this example, ports 9900 to 990[n-1]; worker i listens on
port 9900 + i - 1).
You may connect more than one policy client to any open listen port.
"""

import argparse
import os

import gym
import ray
from ray import tune
from ray.rllib.agents.dqn import DQNTrainer
from ray.rllib.agents.impala import ImpalaTrainer
from ray.rllib.agents.ppo import PPOTrainer
from ray.rllib.agents.registry import get_trainer_class
from ray.rllib.env.policy_server_input import PolicyServerInput
from ray.rllib.examples.custom_metrics_and_callbacks import MyCallbacks
from ray.tune.logger import pretty_print

SERVER_ADDRESS = "localhost"
# In this example, the user can run the policy server with
# n workers, opening up listen ports 9900 - 990n (n = num_workers - 1)
# to each of which different clients may connect.
SERVER_BASE_PORT = 9900  # + worker-idx - 1

CHECKPOINT_FILE = "last_checkpoint_{}.out"


def get_cli_args():
    """Create CLI parser and return parsed arguments"""
    parser = argparse.ArgumentParser()

    # Example-specific args.
    parser.add_argument(
        "--port",
        type=int,
        default=SERVER_BASE_PORT,
        help="The base-port to use (on localhost). " f"Default is {SERVER_BASE_PORT}.",
    )
    parser.add_argument(
        "--callbacks-verbose",
        action="store_true",
        help="Activates info-messages for different events on "
        "server/client (episode steps, postprocessing, etc..).",
    )
    parser.add_argument(
        "--num-workers",
        type=int,
        default=0,
        help="The number of workers to use. Each worker will create "
        "its own listening socket for incoming experiences.",
    )
    parser.add_argument(
        "--no-restore",
        action="store_true",
        help="Do not restore from a previously saved checkpoint (location of "
        "which is saved in `last_checkpoint_[algo-name].out`).",
    )

    # General args.
    parser.add_argument(
        "--run",
        default="IMPALA",
        help="The RLlib-registered algorithm to use.",
    )
    parser.add_argument("--num-cpus", type=int, default=3)
    parser.add_argument(
        "--framework",
        choices=["tf", "tf2", "tfe", "torch"],
        default="tf",
        help="The DL framework specifier.",
    )
    parser.add_argument(
        "--stop-iters", type=int, default=200, help="Number of iterations to train."
    )
    parser.add_argument(
        "--stop-timesteps",
        type=int,
        default=500000,
        help="Number of timesteps to train.",
    )
    parser.add_argument(
        "--stop-reward",
        type=float,
        default=80.0,
        help="Reward at which we stop training.",
    )
    parser.add_argument(
        "--as-test",
        action="store_true",
        help="Whether this script should be run as a test: --stop-reward must "
        "be achieved within --stop-timesteps AND --stop-iters.",
    )
    parser.add_argument(
        "--no-tune",
        action="store_true",
        help="Run without Tune using a manual train loop instead. Here, "
        "there is no TensorBoard support.",
    )
    parser.add_argument(
        "--local-mode",
        action="store_true",
        help="Init Ray in local mode for easier debugging.",
    )

    args = parser.parse_args()
    print(f"Running with following CLI args: {args}")
    return args


if __name__ == "__main__":
    args = get_cli_args()
    ray.init()

    # `InputReader` generator (returns None if no input reader is needed on
    # the respective worker).
    def _input(ioctx):
        # We are remote worker or we are local worker with num_workers=0:
        # Create a PolicyServerInput.
        if ioctx.worker_index > 0 or ioctx.worker.num_workers == 0:
            return PolicyServerInput(
                ioctx,
                SERVER_ADDRESS,
                args.port + ioctx.worker_index - (1 if ioctx.worker_index > 0 else 0),
            )
        # No InputReader (PolicyServerInput) needed.
        else:
            return None

    # Trainer config. Note that this config is sent to the client only in case
    # the client needs to create its own policy copy for local inference.
    config = {
        # Indicate that the Trainer we setup here doesn't need an actual env.
        # Allow spaces to be determined by user (see below).
        "env": None,
        # TODO: (sven) make these settings unnecessary and get the information
        #  about the env spaces from the client.
        "observation_space": gym.spaces.Box(float("-inf"), float("inf"), (4,)),
        "action_space": gym.spaces.Discrete(2),
        # Use the `PolicyServerInput` to generate experiences.
        "input": _input,
        # Use n worker processes to listen on different ports.
        "num_workers": args.num_workers,
        # Disable OPE, since the rollouts are coming from online clients.
        "input_evaluation": [],
        # Create a "chatty" client/server or not.
        "callbacks": MyCallbacks if args.callbacks_verbose else None,
        # DL framework to use.
        "framework": args.framework,
        # Set to INFO so we'll see the server's actual address:port.
        "log_level": "INFO",
        "model":{}
    }

    # DQN.
    if args.run == "DQN" or args.run == "APEX" or args.run == "R2D2":
        # Example of using DQN (supports off-policy actions).
        config.update(
            {
                "learning_starts": 100,
                "timesteps_per_iteration": 200,
                "n_step": 3,
                "rollout_fragment_length": 4,
                "train_batch_size": 8,
            }
        )
        config["model"] = {
            "fcnet_hiddens": [64],
            "fcnet_activation": "linear",
        }
    elif args.run == "IMPALA":
        config.update(
            {
                "batch_mode": "complete_episodes",
                "rollout_fragment_length": 8,
                "train_batch_size": 32,
                "minibatch_buffer_size": 10,
                "num_sgd_iter": 1,
            }
        )

    # PPO.
    else:
        # Example of using PPO (does NOT support off-policy actions).
        config.update(
            {
                "rollout_fragment_length": 4,
                "train_batch_size": 8,
                "sgd_minibatch_size": 4
            }
        )
    # Enabling the LSTM triggers the error (with IMPALA and R2D2;
    # PPO is not affected).
    config["model"]["use_lstm"] = True

    checkpoint_path = CHECKPOINT_FILE.format(args.run)
    # Attempt to restore from checkpoint, if possible.
    if not args.no_restore and os.path.exists(checkpoint_path):
        with open(checkpoint_path) as f:
            checkpoint_path = f.read()
    else:
        checkpoint_path = None

    # Manual training loop (no Ray tune).
    if args.no_tune:
        print("**Launching without Tune...")
        TrainerClass = get_trainer_class(args.run)
        trainer = TrainerClass(config=config)

        if checkpoint_path:
            print("Restoring from checkpoint path", checkpoint_path)
            trainer.restore(checkpoint_path)

        # Serving and training loop.
        ts = 0
        for _ in range(args.stop_iters):
            results = trainer.train()
            print(pretty_print(results))
            # checkpoint = trainer.save()
            # print("Last checkpoint", checkpoint)
            # with open(checkpoint_path, "w") as f:
            #     f.write(checkpoint)
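            # `timesteps_total` is already cumulative, so take it as-is
            # instead of summing it over iterations.
            ts = results["timesteps_total"]
            if (
                results["episode_reward_mean"] >= args.stop_reward
                or ts >= args.stop_timesteps
            ):
                break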

    # Run with Tune for auto env and trainer creation and TensorBoard.
    else:
        print("**Launching with Tune...")

        stop = {
            "training_iteration": args.stop_iters,
            "timesteps_total": args.stop_timesteps,
            "episode_reward_mean": args.stop_reward,
        }

        tune.run(args.run, config=config, stop=stop, verbose=2) #, restore=checkpoint_path)
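The listen-port arithmetic in `_input` decides which ports clients can connect to. Pulled out into plain functions (`listen_port` and `all_listen_ports` are hypothetical names, not RLlib API), it can be sketched as:

```python
from typing import List


def listen_port(base_port: int, worker_index: int) -> int:
    """Port opened by the PolicyServerInput on a given worker.

    Mirrors the expression in `_input`: the local worker (index 0) listens on
    base_port, and remote worker i listens on base_port + i - 1.
    """
    return base_port + worker_index - (1 if worker_index > 0 else 0)


def all_listen_ports(base_port: int, num_workers: int) -> List[int]:
    """All ports clients may connect to for a given `num_workers` setting."""
    if num_workers == 0:
        # Only the local worker creates a PolicyServerInput.
        return [listen_port(base_port, 0)]
    # Only the remote workers (indices 1..num_workers) create one.
    return [listen_port(base_port, i) for i in range(1, num_workers + 1)]
```

For example, `all_listen_ports(9900, 2)` gives `[9900, 9901]`, consistent with the comment in `cartpole_client.py` that a server with `num_workers=2` can be reached on 9900 or 9901.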

cartpole_client.py:

#!/usr/bin/env python
"""
Example of running an external simulator (a simple CartPole env
in this case) against an RLlib policy server listening on one or more
HTTP-speaking port(s). See `cartpole_server.py` in this same directory for
how to start this server.
This script will only create one single env altogether to illustrate
that RLlib can run w/o needing an internalized environment.
Setup:
1) Start the policy server:
    See `cartpole_server.py` on how to do this.
2) Run this client:
    $ python cartpole_client.py --inference-mode=local|remote --[other options]
      Use --help for help.
In "local" inference-mode, the action computations are performed
inside the PolicyClient used in this script w/o sending an HTTP request
to the server. This reduces network communication overhead, but requires
the PolicyClient to create its own RolloutWorker (+Policy) based on
the server's config. The PolicyClient will retrieve this config automatically.
You do not need to define the RLlib config dict here!
In "remote" inference mode, the PolicyClient will send action requests to the
server and not compute its own actions locally. The server then performs the
inference forward pass and returns the action to the client.
In either case, the user of PolicyClient must:
- Declare new episodes and finished episodes to the PolicyClient.
- Log rewards to the PolicyClient.
- Call `get_action` to receive an action from the PolicyClient (whether it'd be
  computed locally or remotely).
- Besides `get_action`, the user may let the PolicyClient know about
  off-policy actions having been taken via `log_action`. This can be used in
  combination with `get_action`, but only works if the connected server
  runs an off-policy RL algorithm (such as DQN, SAC, or DDPG).
"""

import argparse
import gym
import time
from ray.rllib.env.policy_client import PolicyClient

parser = argparse.ArgumentParser()
parser.add_argument(
    "--no-train", action="store_true", help="Whether to disable training."
)
parser.add_argument(
    "--inference-mode", type=str, default="local", choices=["local", "remote"]
)
parser.add_argument(
    "--off-policy",
    action="store_true",
    help="Whether to compute random actions instead of on-policy "
    "(Policy-computed) ones.",
)
parser.add_argument(
    "--stop-reward",
    type=float,
    default=9999,
    help="Stop once the specified reward is reached.",
)
parser.add_argument(
    "--port", type=int, default=9900, help="The port to use (on localhost)."
)

if __name__ == "__main__":
    args = parser.parse_args()

    # The following line is the only instance, where an actual env will
    # be created in this entire example (including the server side!).
    # This is to demonstrate that RLlib does not require you to create
    # unnecessary env objects within the PolicyClient/Server objects, but
    # that only this following env and the loop below runs the entire
    # training process.
    env = gym.make("CartPole-v0")

    # If server has n workers, all ports between 9900 and 990[n-1] should
    # be listened on. E.g. if server has num_workers=2, try 9900 or 9901.
    # Note that no config is needed in this script as it will be defined
    # on and sent from the server.
    # Try connecting to server
    connected = False
    retry = 0
    MAX_CONNECTION_RETRIES = 10
    while not connected and retry < MAX_CONNECTION_RETRIES:
        try:
            client = PolicyClient(
                f"http://localhost:{args.port}", inference_mode=args.inference_mode
            )
            connected = True
        except ConnectionError as e:
            print(f"ConnectionError: {e} (retry {retry} / {MAX_CONNECTION_RETRIES})")
            connected = False
            time.sleep(5)

        retry += 1

    # In the following, we will use our external environment (the CartPole
    # env we created above) in connection with the PolicyClient to query
    # actions (from the server if "remote"; if "local" we'll compute them
    # on this client side), and send back observations and rewards.

    # Start a new episode.
    obs = env.reset()
    eid = client.start_episode(training_enabled=not args.no_train)

    rewards = 0.0
    while True:
        # Compute an action randomly (off-policy) and log it.
        if args.off_policy:
            action = env.action_space.sample()
            client.log_action(eid, obs, action)
        # Compute an action locally or remotely (on server).
        # No need to log it here, as the action is computed by the policy
        # itself and thus already recorded.
        else:
            action = client.get_action(eid, obs)

        # Perform a step in the external simulator (env).
        obs, reward, done, info = env.step(action)
        rewards += reward

        # Log next-obs, rewards, and infos.
        client.log_returns(eid, reward, info=info)

        # Reset the episode if done.
        if done:
            print("Total reward:", rewards)
            if rewards >= args.stop_reward:
                print("Target reward achieved, exiting")
                exit(0)

            rewards = 0.0

            # End the old episode.
            client.end_episode(eid, obs)

            # Start a new episode.
            obs = env.reset()
            eid = client.start_episode(training_enabled=not args.no_train)

Run in two separate terminals:

python .\cartpole_server.py --no-restore --run IMPALA
python .\cartpole_client.py

NOTE: The issue also appears with --run R2D2; PPO works fine.

Anything else

Related issue: #20704

Are you willing to submit a PR?

  • Yes I am willing to submit a PR!
@Fabien-Couthouis Fabien-Couthouis added bug Something that is supposed to be working; but isn't triage Needs triage (eg: priority, bug/not-bug, and owning component) labels Mar 31, 2022
@Fabien-Couthouis Fabien-Couthouis changed the title [Bug][RLlib] "Samples must consistently provide or omit max_seq_len with PolicyServerInput/PolicyClient" [Bug][RLlib] "Samples must consistently provide or omit max_seq_len" with PolicyServerInput/PolicyClient and RNNs Apr 1, 2022
@krfricke krfricke added the rllib RLlib related issues label Apr 1, 2022
@sven1977 sven1977 self-assigned this Apr 4, 2022
@sven1977 sven1977 added rllib-client-server Issue related to RLlib's client/server API. P1 Issue that should be fixed within a few weeks and removed triage Needs triage (eg: priority, bug/not-bug, and owning component) labels Apr 4, 2022
@sven1977
Contributor

sven1977 commented Apr 6, 2022

Hey @Fabien-Couthouis, thanks for filing this. I can reproduce this on my end. Trying to provide a fix now ...

@sven1977
Contributor

sven1977 commented Apr 6, 2022

Ok, found the issue. Will provide a fix-it PR.

@Fabien-Couthouis
Contributor Author

Thanks a lot Sven!
