
How to implement model parallelism using PyTorch on an HPC environment? #896

Open

Akshara211 opened this issue Jul 25, 2024 · 16 comments

@Akshara211

Hello,
I am trying to implement model parallelism using PyTorch in my HPC environment, which has 4 GPUs available. My goal is to split a neural network model across these GPUs to improve training efficiency.

Here's what I've tried so far:

  • Followed the PyTorch documentation on model parallelism
  • Implemented a basic split of the model across GPUs

However, I am encountering performance bottlenecks and underutilization of the GPUs. Can someone guide me on how to implement this in my HPC setup?

Any advice or pointers to resources would be greatly appreciated!

@qubvel (Collaborator) commented Jul 25, 2024

Hi @Akshara211, I'm not that familiar with the topic, but I'm curious: what is the motivation for doing it? What model configuration are you using?

@Akshara211 (Author)

Hi @qubvel,
I am currently working on a computer vision project involving a dataset with 32,000 images and their corresponding annotations. While using ResNet-50, I have been able to set the batch size up to 32 without any issues. However, I am encountering difficulties when attempting to train a deeper model, ResNet-152. When using ResNet-152, I am only able to set a batch size of 8. Any attempt to increase the batch size results in a memory error.

Suggestions from various sources recommend implementing model parallelism by distributing the model across different GPUs. How are you training the ResNet-152 model, and what GPU specifications are you using for ResNet-152? I have 4 GPUs available.

@qubvel (Collaborator) commented Jul 26, 2024

I would probably recommend using:

  • lower precision, e.g. float16/bfloat16
  • gradient accumulation
  • distributed data parallel (for multi-GPU training)

All of these are really easy to incorporate with PyTorch Lightning, just with a few flags passed to the Trainer class.
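For illustration, a minimal sketch of such a configuration (the flag values are illustrative, and model, train_dataloader and valid_dataloader are placeholders for your own LightningModule and dataloaders):

import pytorch_lightning as pl

# Illustrative: bf16 mixed precision, gradient accumulation over 4 batches,
# and DDP across 4 GPUs on a single node.
trainer = pl.Trainer(
    accelerator="gpu",
    devices=4,
    strategy="ddp",
    precision="bf16-mixed",      # or "16-mixed" for float16
    accumulate_grad_batches=4,   # effective batch size = 4 x per-GPU batch size
    max_epochs=50,
)
trainer.fit(model, train_dataloader, valid_dataloader)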

@Akshara211 (Author)

Thank you for the suggestions!

We are already using Distributed Data Parallel (DDP), so data parallelism is applied. However, we are still facing issues with ResNet-152. Our HPC GPU specification is 4x NVIDIA A100-SXM4-40GB. Could you please advise on the specific flags to use for enabling lower precision (float16/bfloat16) and gradient accumulation in PyTorch Lightning?

Is there anything else I could do? Any other suggestions from your side would be appreciated.

Thank you!

@qubvel (Collaborator) commented Jul 29, 2024

@plo97 commented Jul 31, 2024

How do I implement distributed data parallelism (for multi-GPU training)?

@qubvel (Collaborator) commented Jul 31, 2024

@plo97 You can use PyTorch Lightning for that too:
https://lightning.ai/docs/pytorch/stable/accelerators/gpu_intermediate.html

All you need is:

# train on 8 GPUs (same machine, i.e. node)
trainer = Trainer(accelerator="gpu", devices=8, strategy="ddp")

See how to train an SMP model with PyTorch Lightning here:
https://github.com/qubvel-org/segmentation_models.pytorch/blob/main/examples/binary_segmentation_intro.ipynb

@Akshara211 (Author)

Hi @qubvel
When using lower precision (bfloat16) and gradient accumulation, will it take more time to complete one epoch compared to normal training? In my case, the training has already taken 14 hours and only 7 epochs have been completed.

Additionally, will using these techniques cause any change in accuracy compared to normal training? Here is my trainer configuration:

trainer = pl.Trainer(
    accelerator="gpu",
    devices=4,
    strategy="ddp",
    precision="bf16-mixed",
    accumulate_grad_batches=4,
    max_epochs=epochs,
    callbacks=[checkpoint_callback, early_stopping_callback],
)

@qubvel (Collaborator) commented Aug 23, 2024

Hi @Akshara211, sorry for the late response, I probably missed the notification. Were you able to solve your problem?

When using lower precision (bfloat16) and gradient accumulation, will it take more time to complete one epoch compared to normal training?

Regarding bfloat16, it depends on your GPU: on some GPUs you will find it faster, while on others it might be slower. Gradient accumulation should not slow down your training.

Additionally, will using these techniques cause any change in accuracy compared to normal training?

It depends on careful setup, but in general it should not change the accuracy. However, you might want to enable sync batchnorm in case you have a small batch size per GPU (alternatively, you can freeze the batchnorm layers).
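In PyTorch Lightning, sync batchnorm is a single Trainer flag; a minimal sketch on top of the configuration you shared:

trainer = pl.Trainer(
    accelerator="gpu",
    devices=4,
    strategy="ddp",
    precision="bf16-mixed",
    accumulate_grad_batches=4,
    sync_batchnorm=True,  # synchronize BatchNorm statistics across all GPUs
)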

@Patataman commented Aug 28, 2024

As far as I know, no library provides automatic tools for model parallelism (FastAI, PyTorch, TensorFlow...); it's up to you to divide the model, send the appropriate layers to each GPU, and synchronize them.

For the specific case of smp, I have been able to do both data and model parallelism, using the Unet with a resnet34 encoder.
In my case I had 3 nodes with 1 GPU each (3 GPUs in total). For DP I used PyTorch's DistributedDataParallel and for MP I used PyTorch's RPC module.

I was also able to test it with 2 GPUs on the same machine, but briefly.

Both launched using torchrun.

About the execution time you mentioned: with DP, training time should be reduced; with MP it should be similar (if your GPUs are in the same machine) or higher if (as in my case) you have a distributed environment, due to the communication overhead.

@qubvel (Collaborator) commented Aug 28, 2024

Hi @Patataman!

Thanks a lot for sharing your experience. In case you have time, would it be possible to share any code examples of how this can be implemented? I would appreciate any details and contributions to make this question clearer for the community, thanks!

@Patataman

I think I can arrange a minimal example for DP and MP (the latter for when the GPUs are on the same node). On the other hand, distributed MP requires more changes and knowledge about the library (in this case RPC), but I can share an example with pseudocode that should help to understand how to do it.

@qubvel (Collaborator) commented Aug 28, 2024

Sounds great!

@Patataman commented Aug 29, 2024

This example would be for data parallelism

import torch
import numpy as np
import segmentation_models_pytorch as smp
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP  # For DP
from torch.utils.data import Dataset
from torch.utils.data.sampler import SubsetRandomSampler
from tqdm import tqdm

import os
import random
from pathlib import Path


TRAIN_PROB = 1.0
TEST_PROB = 0.0
TRAINSPLIT = 0.8
VALIDSPLIT = 0.2

if __name__ == "__main__":
    rank = int(os.environ.get('RANK', 0))
    world_size = int(os.environ.get('WORLD_SIZE', 1))
    local_rank = int(os.environ.get('LOCAL_RANK', 0))  # GPU index on this node (set by torchrun)
    DEVICE = torch.device(f"cuda:{local_rank}" if torch.cuda.is_available() else "cpu")

    dist.init_process_group("nccl", rank=rank, world_size=world_size)

    """ Preprocess paths
    """
    unet = smp.Unet(
        encoder_name="resnet34",
        in_channels=1,
        classes=1,
        activation="sigmoid"
    )
    unet = unet.to(DEVICE)  # move the model to this rank's GPU before wrapping with DDP
    unet = DDP(unet)

    lr = 0.001
    batch_size = 4
    n_epochs = 20

    opt = torch.optim.Adam(unet.parameters(), lr=lr)
    loss = smp.losses.MCCLoss()

    # Random image tensors and labels
    n_samples = 100
    image_size = 128
    n_channels = 1
    n_classes = 2

    X = torch.randn(n_samples, n_channels, image_size, image_size)
    y = torch.randint(0, 2, (n_samples, image_size, image_size))  # random binary masks

    train_data = list(zip(X,y))

    # https://stackoverflow.com/questions/50544730/how-do-i-split-a-custom-dataset-into-training-and-test-datasets
    train_subset, valid_subset = torch.utils.data.random_split(train_data, [TRAINSPLIT,VALIDSPLIT])

    train_sampler = torch.utils.data.distributed.DistributedSampler(
        train_subset, num_replicas=world_size, rank=rank
    )
    train_dataloader = torch.utils.data.DataLoader(
        train_subset, sampler=train_sampler, batch_size=batch_size
    )
    valid_sampler = SubsetRandomSampler(valid_subset.indices)
    valid_dataloader = torch.utils.data.DataLoader(
        train_data, batch_size=batch_size, sampler=valid_sampler
    )

    for epoch in tqdm(range(n_epochs), leave=False):
        train_sampler.set_epoch(epoch)  # reshuffle the shards across ranks every epoch
        # Here you should have your train and validation loop
        # I used a custom library, that's why it is not included, hehe
        pass

    dist.destroy_process_group()

In the case you want to do MP in the same node (multiple GPUs in the same machine), you can easily do that as:

import torch.distributed.rpc as rpc

os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '52355'
rpc.init_rpc('worker', rank=0, world_size=1)

[...]

unet = smp.Unet(
    encoder_name="resnet34",
    in_channels=1,
    classes=1,
    activation="sigmoid"
)

# I cannot test this right now because I don't have access to a machine with multiple GPUs, but this should be all the changes
unet.encoder.to("cuda:0")
unet.decoder.to("cuda:1")
unet.segmentation_head.to("cuda:1")

[...]
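One caveat (untested on my side): the stock Unet forward will not move intermediate tensors between devices for you, so you would also need a small wrapper that does it, something like:

class SplitUnet(torch.nn.Module):
    # Untested sketch: wraps an smp.Unet whose parts live on different GPUs
    # and moves the intermediate tensors between them in forward().
    def __init__(self, unet):
        super().__init__()
        self.encoder = unet.encoder.to("cuda:0")
        self.decoder = unet.decoder.to("cuda:1")
        self.segmentation_head = unet.segmentation_head.to("cuda:1")

    def forward(self, x):
        features = self.encoder(x.to("cuda:0"))
        features = [f.to("cuda:1") for f in features]  # hand the encoder outputs to the second GPU
        x = self.decoder(*features)
        return self.segmentation_head(x)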

And finally, MP with multiple nodes. Luckily, I have the code for splitting SMP's Unet for MP using RPC.

import os

import torch
import torch.optim as optim
import torch.distributed.rpc as rpc

import torch.distributed as dist
from torch.distributed.nn import RemoteModule
from torch.distributed.rpc import RRef, TensorPipeRpcBackendOptions

########################
from segmentation_models_pytorch.encoders import get_encoder
from segmentation_models_pytorch.decoders.unet.decoder import UnetDecoder
from segmentation_models_pytorch.base import SegmentationHead
from segmentation_models_pytorch.base.initialization import initialize_decoder, initialize_head
#####################


AMP = True  # toggle mixed-precision autocast in the remote forward passes
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


class RPCUnet(torch.nn.Module):
    # This is like "the main model" for training and inference
    def __init__(self, remote_encoder, remote_decoder):
        super().__init__()
        # Define your layers here, for example:
        self.remote_encoder = remote_encoder
        self.remote_decoder = remote_decoder
        kwargs_seghead = {  # from smp source file
            "in_channels": 16,  #kwargs_decoder["decoder_channels"][-1],
            "out_channels": 1,
            "activation": None,
            "kernel_size": 3
        }
        self.segmentation_head = SegmentationHead(**kwargs_seghead).to(DEVICE)
        initialize_head(self.segmentation_head)
        self.sigmoid = torch.nn.Sigmoid().to(DEVICE)

    def forward(self, x):
        x = x.to("cpu")  # RPC only works with cpu
        x = self.remote_encoder(x)
        x = self.remote_decoder(x)
        x = torch.stack(x).to("cuda")
        x = self.segmentation_head(x)
        x = self.sigmoid(x)
        return x

    def get_rref_parameters(self):
        rrefs = self.remote_encoder.remote_parameters()
        rrefs.extend(self.remote_decoder.remote_parameters())
        for param in self.segmentation_head.parameters():
            rrefs.append(RRef(param))
        for param in self.sigmoid.parameters():
            rrefs.append(RRef(param))

        return rrefs
        

class DecoderRPC(torch.nn.Module):
    """ Superclass to initialize Unet decoder in remote worker
    """
    def __init__(self):
        super().__init__()
        self.decoder = UnetDecoder(
            encoder_channels=(3, 64, 64, 128, 256, 512),
            decoder_channels=(256, 128, 64, 32, 16),
            n_blocks=5,
            use_batchnorm=True,
            center=False,
            attention_type=None
        )
        initialize_decoder(self.decoder)

    def forward(self, x):
        with torch.autocast(device_type='cuda', dtype=torch.float16, enabled=AMP):
            x = [_x.to(DEVICE) for _x in x]
            x = self.decoder(*x)
            
            return x

    def train(self, mode: bool = True):
        self.decoder.to("cpu")
        super().train(mode)
        self.decoder.to("cuda")


class EncoderRPC(torch.nn.Module):
    """ Superclass to initialize Unet decoder in remote worker
    """
    def __init__(self):
        super().__init__()
        self.encoder = get_encoder(
           "resnet34",
            in_channels=1,
            depth=5,  # from smp source file
            weights="imagenet"
        )

    def forward(self, x):
        with torch.autocast(device_type='cuda', dtype=torch.float16, enabled=AMP):
            x = x.to(DEVICE)
            x = self.encoder(x)
            return x

    def train(self, mode: bool = True):
        self.encoder.to("cpu")
        super().train(mode)
        self.encoder.to("cuda")


def make_modelRPC():
    # Build each Unet section individually so that RemoteModule can be used with them.
    remote_encoder = RemoteModule(
        "worker1/cuda",
        EncoderRPC,
    )
    remote_decoder = RemoteModule(
        "worker2/cuda",
        DecoderRPC,
    )
    layers = [remote_encoder, remote_decoder]

    rpc_model = RPCUnet(*layers)
    [...]


def init_worker(rank, world_size):
    rpc_backend_options = TensorPipeRpcBackendOptions(
        init_method=f"tcp://{os.environ['MASTER_ADDR']}:52355",
    )
    # Master
    if rank == 0:
        rpc.init_rpc(
            "master",
            rank=rank,
            world_size=world_size,
            rpc_backend_options=rpc_backend_options,
        )

        rpc_model = make_modelRPC()

        # Just load your data and train as when using 1 GPU
        [ ... ]

    elif rank > 0:  # in [1,2]:
        # Initialize RPC.
        worker_name = "worker{}".format(rank)
        rpc.init_rpc(
            worker_name,
            rank=rank,
            world_size=world_size,
            rpc_backend_options=rpc_backend_options,
        )
        # Worker just waits for RPCs from master.
    
    rpc.shutdown()
    print(rank, "RPC shutdown.")


if __name__ == "__main__":
    rank = int(os.environ['RANK'])
    world_size = int(os.environ['WORLD_SIZE'])
    
    init_worker(rank, world_size)

For MP in a distributed environment you also need to use distributed autograd and the distributed optimizer, as stated here: https://pytorch.org/docs/stable/rpc.html
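For reference, a rough sketch of how that training step could look (assuming the rpc_model built by make_modelRPC above; the loss and dataloader are placeholders):

import torch.distributed.autograd as dist_autograd
from torch.distributed.optim import DistributedOptimizer

# The distributed optimizer receives RRefs to the remote (and local) parameters.
dist_optim = DistributedOptimizer(
    optim.Adam,
    rpc_model.get_rref_parameters(),
    lr=1e-3,
)
loss_fn = torch.nn.BCELoss()

for images, masks in train_dataloader:
    # Forward, backward and step all happen inside a distributed autograd context.
    with dist_autograd.context() as context_id:
        preds = rpc_model(images)
        loss = loss_fn(preds, masks.float().to(preds.device))
        dist_autograd.backward(context_id, [loss])
        dist_optim.step(context_id)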

@qubvel (Collaborator) commented Aug 29, 2024

@Patataman thanks a lot for taking the time to write these code samples!
I will try to get access to a machine with multiple GPUs and conduct some experiments, then I will probably wrap everything up in a short tutorial in the docs (with credit to your snippets). In case you want to collaborate on this, you are always welcome!

@Patataman commented Sep 2, 2024

Btw, two things that might come in handy when testing:

In my case, when running MP with RPC I had to manually set the network interface of each machine by setting the TP_SOCKET_IFNAME and GLOO_SOCKET_IFNAME environment variables before torchrun, as I mention here: https://discuss.pytorch.org/t/rpc-torchrun-hangs-in-processgroupgloo/197093/2

The second thing is that I think there is no actual need to manually create each part of the model as I did in the code I shared. At the time, that was the only way I managed to get it working, but I think something like this:

class rpc_model(torch.nn.Module):
    def __init__(self, unet_part):
        super().__init__()
        self.model = unet_part

    def forward(self, x):
        x = x.to("cuda")
        return self.model(x)

unet = smp.Unet()

part1 = rpc_model(unet.encoder)
part2 = rpc_model(unet.decoder)
part3 = rpc_model(unet.segmentation_head)

I think that should work. I don't remember now why it didn't work for me at the time, but I have been able to do similar things with other models. The only limit for RPC (afaik) is that the data must be pickleable.
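For completeness, an untested sketch of how those wrappers could be plugged into the RPC setup above (RemoteModule instantiates the class on the remote worker, which is where the pickling requirement comes from; the worker names match the earlier example):

unet = smp.Unet(encoder_name="resnet34", in_channels=1, classes=1)

# Untested: each sub-module is pickled, sent over RPC, and instantiated on its worker.
remote_encoder = RemoteModule("worker1/cuda", rpc_model, args=(unet.encoder,))
remote_decoder = RemoteModule("worker2/cuda", rpc_model, args=(unet.decoder,))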
