[Train] train_fashion_mnist_example fails with 1 worker #19506

Closed
Junphy-Jan opened this issue Oct 19, 2021 · 8 comments · Fixed by #19518
Labels
bug Something that is supposed to be working; but isn't P1 Issue that should be fixed within a few weeks

Comments

@Junphy-Jan

Search before asking

  • I searched the issues and found no similar issues.

Ray Component

Others

What happened + What you expected to happen

I'm new to Ray and am learning the RaySGD example here:
RaySGD-examples-train_fashion_mnist_example

But with the default parameter num-workers=1, I get an error:

Traceback (most recent call last):
  File "...ray_example/raysgd_train_mnist.py", line 173, in <module>
    train_fashion_mnist(num_workers=args.num_workers, use_gpu=args.use_gpu)
  File "...ray_example/raysgd_train_mnist.py", line 127, in train_fashion_mnist
    result = trainer.run(
  File ".../site-packages/ray/util/sgd/v2/trainer.py", line 240, in run
    for intermediate_result in iterator:
  File ".../site-packages/ray/util/sgd/v2/trainer.py", line 567, in __next__
    self._run_with_error_handling(
  File ".../site-packages/ray/util/sgd/v2/trainer.py", line 537, in _run_with_error_handling
    return func()
  File ".../site-packages/ray/util/sgd/v2/backends/backend.py", line 600, in finish_training
    results = self.get_with_failure_handling(futures)
  File ".../site-packages/ray/util/sgd/v2/backends/backend.py", line 619, in get_with_failure_handling
    success, failed_worker_indexes = check_for_failure(remote_values)
  File ".../site-packages/ray/util/sgd/v2/utils.py", line

But if I set --num-workers=2, it works fine.
Can someone help me figure out how this works?

Versions / Dependencies

ray 1.7
Python 3.8
Win10 1909
Wsl 1

Reproduction script

Just the official example

Anything else

No response

Are you willing to submit a PR?

  • Yes I am willing to submit a PR!
@Junphy-Jan Junphy-Jan added bug Something that is supposed to be working; but isn't triage Needs triage (eg: priority, bug/not-bug, and owning component) labels Oct 19, 2021
@Junphy-Jan
Author

The error information above seems incomplete; here is the full traceback:
Traceback (most recent call last):
  File "...ray_example/raysgd_train_mnist.py", line 173, in <module>
    train_fashion_mnist(num_workers=args.num_workers, use_gpu=args.use_gpu)
  File "...ray_example/raysgd_train_mnist.py", line 127, in train_fashion_mnist
    result = trainer.run(
  File ".../site-packages/ray/util/sgd/v2/trainer.py", line 240, in run
    for intermediate_result in iterator:
  File ".../site-packages/ray/util/sgd/v2/trainer.py", line 567, in __next__
    self._run_with_error_handling(
  File ".../site-packages/ray/util/sgd/v2/trainer.py", line 537, in _run_with_error_handling
    return func()
  File ".../site-packages/ray/util/sgd/v2/backends/backend.py", line 600, in finish_training
    results = self.get_with_failure_handling(futures)
  File ".../site-packages/ray/util/sgd/v2/backends/backend.py", line 619, in get_with_failure_handling
    success, failed_worker_indexes = check_for_failure(remote_values)
  File ".../site-packages/ray/util/sgd/v2/utils.py", line 34, in check_for_failure
    ray.get(object_ref)
  File ".../site-packages/ray/_private/client_mode_hook.py", line 89, in wrapper
    return func(*args, **kwargs)
  File ".../site-packages/ray/worker.py", line 1621, in get
    raise value.as_instanceof_cause()
ray.exceptions.RayTaskError(AssertionError): ray::BaseWorkerMixin._BaseWorkerMixin__execute() (pid=8455, ip=xxx, repr=<ray.util.sgd.v2.worker_group.BaseWorkerMixin object at 0x7fbec132efd0>)
  File ".../site-packages/ray/util/sgd/v2/worker_group.py", line 25, in __execute
    return func(*args, **kwargs)
  File ".../site-packages/ray/util/sgd/v2/backends/backend.py", line 574, in end_training
    output = session.finish()
  File ".../site-packages/ray/util/sgd/v2/session.py", line 84, in finish
    func_output = self.training_thread.join()
  File ".../site-packages/ray/util/sgd/v2/utils.py", line 86, in join
    raise self.exc
  File ".../site-packages/ray/util/sgd/v2/utils.py", line 79, in run
    self.ret = self._target(*self._args, **self._kwargs)
  File ".../site-packages/ray/util/sgd/v2/trainer.py", line 337, in <lambda>
    return lambda: train_func(config)
  File "...ray_example/raysgd_train_mnist.py", line 96, in train_func
    sampler=DistributedSampler(training_data))
  File ".../site-packages/torch/utils/data/distributed.py", line 54, in __init__
    num_replicas = dist.get_world_size()
  File ".../site-packages/torch/distributed/distributed_c10d.py", line 621, in get_world_size
    return _get_group_size(group)
  File ".../site-packages/torch/distributed/distributed_c10d.py", line 219, in _get_group_size
    _check_default_pg()
  File ".../site-packages/torch/distributed/distributed_c10d.py", line 209, in _check_default_pg
    assert _default_pg is not None,
AssertionError: Default process group is not initialized

@amogkam amogkam added P1 Issue that should be fixed within a few weeks and removed triage Needs triage (eg: priority, bug/not-bug, and owning component) labels Oct 19, 2021
@amogkam amogkam changed the title [Bug] an official example doesn't work [Train] train_fashion_mnist_example fails with 1 worker Oct 19, 2021
@matthewdeng
Contributor

Hey @JanJF, thanks for letting us know about this issue!

The reason for the behavior you're seeing is indeed that the training_function code is trying to execute distributed training (with DistributedSampler) when the Torch distributed process group isn't set up (because num_workers=1). I've updated the example to use num_workers=2 in #19518.
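For reference, a minimal sketch of a single-worker-safe data loader (build_loader is just a hypothetical helper, not part of the example): it only uses DistributedSampler when a default process group has actually been initialized.

import torch.distributed as dist
from torch.utils.data import DataLoader, DistributedSampler


def build_loader(dataset, batch_size):
    # Use DistributedSampler only when torch.distributed has an initialized
    # default process group (i.e. more than one training worker). Otherwise
    # fall back to a plain shuffled DataLoader so a single-worker run does
    # not hit "Default process group is not initialized".
    if dist.is_available() and dist.is_initialized():
        return DataLoader(dataset, batch_size=batch_size,
                          sampler=DistributedSampler(dataset))
    return DataLoader(dataset, batch_size=batch_size, shuffle=True)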

Also just a small heads up we've rebranded RaySGD to Ray Train, so you may notice some changes in the docs/package structure!

@Junphy-Jan
Author

I can infer that the cause is the process group not being initialized. But I'm confused about where the process group gets initialized when num_workers=2.

@matthewdeng
Contributor

This is handled within the Ray Train library code, in TorchBackend. This is done "implicitly", but if the end-user needs additional configuration they can pass in a TorchConfig to the Trainer constructor.
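For example, a minimal sketch of passing a TorchConfig explicitly (the gloo backend and the toy train_func here are just placeholders):

import ray
from ray.util.sgd.v2 import Trainer, TorchConfig


def train_func(config):
    # Per-worker training function; when num_workers > 1 the backend has
    # already set up the torch.distributed process group before this runs.
    return config["x"] * 2


ray.init()
trainer = Trainer(TorchConfig(backend="gloo"), num_workers=2)
trainer.start()
results = trainer.run(train_func, config={"x": 1})  # one result per worker
trainer.shutdown()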

@amogkam
Contributor

amogkam commented Oct 26, 2021

Note that this is only started if num_workers is >1 https://github.com/ray-project/ray/blob/master/python/ray/train/backends/torch.py#L98, hence why the error is happening with 1 worker.
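Roughly speaking (illustrative sketch only, not the actual TorchBackend source), the setup is gated like this, so with a single worker no default process group ever exists:

import torch.distributed as dist


def maybe_setup_process_group(num_workers, rank, world_size, init_method):
    # Hypothetical helper mirroring the idea that distributed setup is
    # skipped for a single worker. With num_workers == 1 nothing is
    # initialized, so anything that assumes a default group (for example
    # DistributedSampler without explicit num_replicas/rank) raises the
    # AssertionError shown above.
    if num_workers > 1:
        dist.init_process_group(backend="gloo", init_method=init_method,
                                rank=rank, world_size=world_size)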

@Junphy-Jan
Author

Note that this is only started if num_workers is >1 https://github.com/ray-project/ray/blob/master/python/ray/train/backends/torch.py#L98, hence why the error is happening with 1 worker.

Thanks! I tried several times and found that the process group is initialized when trainer.run() is called.
After running the examples, I have another question: RaySGD can train my model in two processes, and I saw that the two resulting models have the same weights.
But compared to a regular (non-parallel) program, each training process only uses half the data. I find that if I train the same model with a non-parallel program, I get a lower loss than with RaySGD under the same configuration. Concretely, after 5 epochs the non-parallel program reaches a loss of 0.7, while with RaySGD and num_workers=2 each worker ends up with a loss of 3.3.
I wonder if I have misunderstood RaySGD.
It would be very nice if you could explain this to me.
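For what it's worth, a minimal sketch (assuming two workers) of how DistributedSampler partitions the dataset indices by rank, which is why each worker only sees half the data, and takes half as many optimizer steps, per epoch:

import torch
from torch.utils.data import DistributedSampler, TensorDataset

dataset = TensorDataset(torch.arange(8).float())
for rank in range(2):
    # Passing num_replicas/rank explicitly avoids needing a process group here.
    sampler = DistributedSampler(dataset, num_replicas=2, rank=rank, shuffle=False)
    print(rank, list(sampler))
# rank 0 -> [0, 2, 4, 6], rank 1 -> [1, 3, 5, 7]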

@Junphy-Jan
Author

Here is my code if it helps.

import argparse
from typing import List, Dict

import numpy as np
import torch
import torch.nn as nn
import ray.util.sgd.v2 as sgd
from ray.util.sgd.v2 import Trainer, TorchConfig
from ray.util.sgd.v2.callbacks import JsonLoggerCallback, TBXLoggerCallback, SGDCallback
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import DistributedSampler, SequentialSampler

config = {"lr": 1e-2, "hidden_size": 1, "batch_size": 32, "epochs": 5}


class LinearDataset(torch.utils.data.Dataset):
    """y = a * x + b"""

    def __init__(self, a, b, size=1000):
        x = np.arange(0, 10, 10 / size, dtype=np.float32)
        self.x = torch.from_numpy(x)
        self.y = torch.from_numpy(a * x + b)

    def __getitem__(self, index):
        return self.x[index, None], self.y[index, None]

    def __len__(self):
        return len(self.x)


def train(dataloader, model, loss_fn, optimizer):
    loss_total = 0
    for X, y in dataloader:
        # Compute prediction error
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        loss_total += loss.cpu().detach().item()
        optimizer.step()
    result = {"period": "train", "model": model.state_dict(), "dataloader_len": len(dataloader),
              "avg_loss_per_epoch": loss_total / len(dataloader)}
    return result


def validate(dataloader, model, loss_fn):
    num_batches = len(dataloader)
    model.eval()
    loss = 0
    with torch.no_grad():
        for X, y in dataloader:
            pred = model(X)
            loss += loss_fn(pred, y).item()
    loss /= num_batches
    result = {"model": model.state_dict(), "loss": loss}
    return result


def train_func(config):
    data_size = config.get("data_size", 1000)
    val_size = config.get("val_size", 400)
    batch_size = config.get("batch_size", 32)
    hidden_size = config.get("hidden_size", 1)
    lr = config.get("lr", 1e-2)
    epochs = config.get("epochs", 5)

    train_dataset = LinearDataset(2, 5, size=data_size)
    val_dataset = LinearDataset(2, 5, size=val_size)
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=batch_size,
        sampler=DistributedSampler(train_dataset))
    validation_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=batch_size,
        sampler=DistributedSampler(val_dataset))

    model = nn.Linear(1, hidden_size)
    model = DistributedDataParallel(model)

    loss_fn = nn.MSELoss()

    optimizer = torch.optim.SGD(model.parameters(), lr=lr)

    results = []

    for _ in range(epochs):
        train_ret = train(train_loader, model, loss_fn, optimizer)
        # sgd.report(**train_ret)
        result = validate(validation_loader, model, loss_fn)
        sgd.report(**train_ret)
        results.append(train_ret)
    # results.append()
    return results


def train4local(config_local):
    data_size = config_local.get("data_size", 1000)
    val_size = config_local.get("val_size", 400)
    batch_size = config_local.get("batch_size", 32)
    hidden_size = config_local.get("hidden_size", 1)
    lr = config_local.get("lr", 1e-2)
    epochs = config_local.get("epochs", 5)

    train_dataset = LinearDataset(2, 5, size=data_size)
    val_dataset = LinearDataset(2, 5, size=val_size)
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=batch_size,
        sampler=SequentialSampler(train_dataset))
    validation_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=batch_size,
        sampler=SequentialSampler(val_dataset))

    model = nn.Linear(1, hidden_size)

    loss_fn = nn.MSELoss()

    optimizer = torch.optim.SGD(model.parameters(), lr=lr)

    results = []

    for _ in range(epochs):
        train_ret = train(train_loader, model, loss_fn, optimizer)
        # sgd.report(**train_ret)
        result = validate(validation_loader, model, loss_fn)

        results.append(train_ret)
    # results.append()
    return results


def train_linear(num_workers=2):
    trainer = Trainer(TorchConfig(backend="gloo"), num_workers=num_workers)
    trainer.start()
    results = trainer.run(
        train_func,
        config,
        callbacks=[JsonLoggerCallback(workers_to_log=None),
                   TBXLoggerCallback()])
    trainer.shutdown()

    return results


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--address",
        required=False,
        type=str,
        help="the address to use for Ray")
    parser.add_argument(
        "--num-workers",
        "-n",
        type=int,
        default=2,
        help="Sets number of workers for training.")
    parser.add_argument(
        "--smoke-test",
        action="store_true",
        default=False,
        help="Finish quickly for testing.")

    args, _ = parser.parse_known_args()
    if args.num_workers > 1:
        import ray

        if args.smoke_test:
            ray.init()
        else:
            ray.init(address='auto', _redis_password='5241590000000000')
        ret = train_linear(num_workers=args.num_workers)
    else:
        print(train4local(config))

@amogkam
Contributor

amogkam commented Oct 26, 2021

Hey @JanJF, I created a new issue for the follow up question. We can move the discussion there.
