[shardformer] update pipeline parallel document #4725

Merged 8 commits on Sep 15, 2023

217 changes: 136 additions & 81 deletions in docs/source/en/features/pipeline_parallel.md

Author: Guangyang Lu, Hongxin Liu, Yongbin Li

**Prerequisite**
- [Use Booster to Train](../basics/booster_api.md)
- [Shardformer](../features/shardformer.md)

**Example Code**
- [Fine-tune Bert with pipeline](https://github.com/hpcaitech/ColossalAI/blob/main/examples/language/bert/finetune.py)

**Related Paper**
- [Colossal-AI: A Unified Deep Learning System For Large-Scale Parallel Training](https://arxiv.org/abs/2110.14883)

## Quick introduction

In this tutorial, you will learn how to use pipeline parallelism. In Colossal-AI, we use the 1F1B pipeline introduced by NVIDIA. Since ViT and ImageNet are too large for this case, we use a BERT model and the GLUE dataset as an example.

## Table of Contents

In this tutorial we will cover:

1. Introduction of 1F1B pipeline.
2. Usage of non-interleaved and interleaved schedules.
3. Fine-tuning BERT with pipeline.

## Introduction of 1F1B pipeline

In the interleaved schedule, each device can perform computation for multiple subsets of layers (called model chunks) instead of a single contiguous set of layers.

This mode is both memory-efficient and time-efficient.
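
To get a feel for why more micro-batches and more model chunks help, the short estimate below follows the pipeline-bubble analysis from the Megatron-LM paper. The helper function and its numbers are purely illustrative and are not part of the Colossal-AI API.

```python
# Illustrative estimate of the fraction of time lost to the pipeline bubble:
# (p - 1) / (v * m), where p = pipeline stages, m = micro-batches per step and
# v = model chunks per device (v = 1 for the non-interleaved 1F1B schedule).
def bubble_fraction(num_stages: int, num_microbatches: int, num_model_chunks: int = 1) -> float:
    return (num_stages - 1) / (num_model_chunks * num_microbatches)

print(bubble_fraction(4, 16))       # 0.1875  -> non-interleaved 1F1B
print(bubble_fraction(4, 16, 2))    # 0.09375 -> interleaved schedule, 2 chunks per device
```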

## Colossal-AI's Implementation

In Colossal-AI, pipeline parallelism relies on the scheduler and `Shardformer`. We provide both a non-interleaved schedule (`OneForwardOneBackwardSchedule`) and an interleaved schedule (`InterleavedSchedule`), while `Shardformer` splits the model's layers and replaces the model's `forward` function to make it compatible with the scheduler.

The `HybridParallelPlugin` encapsulates these pipeline execution strategies. It manages the pipeline-parallel communication groups and the scheduler. When a model is boosted with this plugin, its layers are split by calling the `shardformer.optimize` function, and `execute_pipeline` then runs the model in stages using either `OneForwardOneBackwardSchedule` or `InterleavedSchedule`.

You can customize your parallel strategy by setting the parameters of the `HybridParallelPlugin`.
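
For instance, the following sketch (with illustrative values, assuming the distributed environment has already been initialized via `colossalai.launch_from_torch`) combines tensor parallelism, pipeline parallelism, and ZeRO; it uses only parameters that also appear later in this tutorial.

```python
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

colossalai.launch_from_torch(config={})   # assumes launch via `torchrun` / `colossalai run`

# A hybrid strategy: 2-way tensor parallelism combined with 2-way pipeline parallelism,
# each batch split into micro-batches of size 1, Shardformer optimizations enabled, and
# ZeRO stage 1 applied across the remaining data-parallel ranks (needs at least 4 GPUs).
plugin = HybridParallelPlugin(tp_size=2,
                              pp_size=2,
                              microbatch_size=1,
                              enable_all_optimization=True,
                              zero_stage=1,
                              precision='fp16',
                              initial_scale=1)
booster = Booster(plugin=plugin)
```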

## Fine-tune BERT with pipeline

First, we define the necessary training components, including the model, dataloader, optimizer, lr_scheduler, and criterion:
```python
import argparse
from contextlib import nullcontext
from typing import Callable, List, Union

import evaluate
import torch
import torch.distributed as dist
import torch.nn as nn
from data import GLUEDataBuilder
from torch.optim import Adam, Optimizer
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import (
    AlbertForSequenceClassification,
    AutoConfig,
    BertForSequenceClassification,
    get_linear_schedule_with_warmup,
)

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin, TorchDDPPlugin
from colossalai.cluster import DistCoordinator
from colossalai.lazy import LazyInitContext
from colossalai.nn.optimizer import HybridAdam
from colossalai.utils import get_current_device

# Define some config
NUM_EPOCHS = 3
BATCH_SIZE = 32
LEARNING_RATE = 2.4e-5
WEIGHT_DECAY = 0.01
WARMUP_FRACTION = 0.1

# Initialize the distributed environment (assuming the script is launched with
# `torchrun` or `colossalai run`) and create a coordinator for rank-aware logging.
colossalai.launch_from_torch(config={})
coordinator = DistCoordinator()


def move_to_cuda(batch):
    return {k: v.cuda() for k, v in batch.items()}


# Define the 'criterion' function: it takes the model outputs and the inputs and
# returns the loss. It will be passed to `booster.execute_pipeline` later.
def _criterion(outputs, inputs):
    return outputs.loss


# Define a dataloader. `GLUEDataBuilder` comes from the `data.py` of the example;
# `plugin` is the `HybridParallelPlugin` created in the next snippet, and the GLUE
# task name is taken from the script's command-line arguments.
model_name = "bert-base-uncased"
data_builder = GLUEDataBuilder(model_name,
                               plugin,
                               args.task,
                               train_batch_size=BATCH_SIZE,
                               eval_batch_size=BATCH_SIZE)
train_dataloader = data_builder.train_dataloader()

# Define the BERT model. We assume the data builder exposes the number of labels
# of the chosen GLUE task.
cfg = AutoConfig.from_pretrained(model_name, num_labels=data_builder.num_labels)
model = BertForSequenceClassification.from_pretrained(model_name, config=cfg).cuda()

# Define optimizer
no_decay = ["bias", "LayerNorm.weight"]
optimizer_grouped_parameters = [
    {
        "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
        "weight_decay": WEIGHT_DECAY,
    },
    {
        "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
        "weight_decay": 0.0,
    },
]

optimizer = HybridAdam(optimizer_grouped_parameters, lr=LEARNING_RATE, eps=1e-8)

# Define lr_scheduler
total_steps = len(train_dataloader) * NUM_EPOCHS
num_warmup_steps = int(WARMUP_FRACTION * total_steps)
lr_scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=num_warmup_steps,
    num_training_steps=total_steps,
)
```

Define a booster with the `HybridParallelPlugin`.
```python
plugin = HybridParallelPlugin(tp_size=1,
                              pp_size=2,
                              num_microbatches=None,
                              microbatch_size=1,
                              enable_all_optimization=True,
                              zero_stage=1,
                              precision='fp16',
                              initial_scale=1)
booster = Booster(plugin=plugin)
```

Boost these training components with the booster we just created.
```python
model, optimizer, _criterion, _, lr_scheduler = booster.boost(model,
                                                              optimizer,
                                                              criterion=_criterion,
                                                              lr_scheduler=lr_scheduler)
```

Finally, train the model.

```python
# Define a train function
def train_epoch(epoch: int, model: nn.Module, optimizer: Optimizer, _criterion: Callable, lr_scheduler: LRScheduler,
                train_dataloader: DataLoader, booster: Booster, coordinator: DistCoordinator):

    is_pp_last_stage = booster.plugin.stage_manager.is_last_stage()
    total_step = len(train_dataloader)

    model.train()
    optimizer.zero_grad()
    train_dataloader_iter = iter(train_dataloader)
    with tqdm(range(total_step),
              desc=f'Epoch [{epoch + 1}/{NUM_EPOCHS}]',
              disable=not is_pp_last_stage) as pbar:
        # Forward pass
        for _ in pbar:
            outputs = booster.execute_pipeline(train_dataloader_iter,
                                               model,
                                               _criterion,
                                               optimizer,
                                               return_loss=True,
                                               return_outputs=True)
            # Backward and optimize
            if is_pp_last_stage:
                loss = outputs['loss']
                pbar.set_postfix({'loss': loss.item()})

            optimizer.step()
            optimizer.zero_grad()
            lr_scheduler.step()


# Train model
for epoch in range(NUM_EPOCHS):
    train_epoch(epoch, model, optimizer, _criterion, lr_scheduler, train_dataloader, booster, coordinator)
```

We use `2` pipeline stages and a micro-batch size of `1`; these parameters can be adjusted to suit your setup.
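
After training, you will usually want to persist the fine-tuned weights. Below is a minimal sketch using the Booster checkpoint API; the checkpoint paths are placeholders, and sharded saving is assumed to be supported by the plugin.

```python
# Save the fine-tuned model through the booster so that weights held on different
# pipeline stages are written out correctly ("./bert_finetuned" is a placeholder path).
booster.save_model(model, "./bert_finetuned", shard=True)

# Optionally save the optimizer state as well, to be able to resume training later.
booster.save_optimizer(optimizer, "./bert_finetuned_optim", shard=True)
```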
<!-- doc-test-command: echo -->