This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

Deepspeed integration #4693

Open · jacobdanovitch wants to merge 28 commits into main from jacobdanovitch/deepspeed
Changes from 7 commits
Commits (28)
e2ac4b5
first draft of deepspeed trainer
jacobdanovitch Oct 2, 2020
619657e
delegating grad_clipping, grad_norm, grad_accumulation, etc. to deeps…
jacobdanovitch Oct 2, 2020
a329fd2
cleaning up deepspeed config interface
jacobdanovitch Oct 2, 2020
00666c2
identifying bottleneck / start simplifying model engine
jacobdanovitch Oct 7, 2020
f0da3bf
1416 LOC -> 562
jacobdanovitch Oct 10, 2020
d0e8a68
debugging memory leak
jacobdanovitch Oct 28, 2020
0a74573
functioning / cleaner prototype
jacobdanovitch Oct 31, 2020
eaf8aa5
Merge branch 'master' into jacobdanovitch/deepspeed
jacobdanovitch Nov 2, 2020
498d3a2
checkpointing works e2e
jacobdanovitch Nov 2, 2020
a211b5e
ready for review
jacobdanovitch Nov 5, 2020
3b30e21
Merge branch 'master' into jacobdanovitch/deepspeed
jacobdanovitch Nov 5, 2020
fdd888b
add new trainer/lazy changes
jacobdanovitch Nov 5, 2020
ef544c9
Merge branch 'master' into jacobdanovitch/deepspeed
jacobdanovitch Nov 9, 2020
083a6d0
dangling changes
jacobdanovitch Nov 23, 2020
0f8d5b7
Merge branch 'master' of https://github.com/allenai/allennlp into jac…
jacobdanovitch Nov 23, 2020
4e4f7d7
updating from master
jacobdanovitch Nov 30, 2020
f48ea19
typechecks passing!
jacobdanovitch Nov 30, 2020
b3328fc
init file
jacobdanovitch Jan 3, 2021
966e296
Merge remote-tracking branch 'upstream/main' into jacobdanovitch/deep…
jacobdanovitch Jan 3, 2021
2fdb7c0
save old tests in case
jacobdanovitch Jan 8, 2021
95a9e5f
tracking down dist barrier bug(s)
jacobdanovitch Jan 8, 2021
b152fe1
catch up
jacobdanovitch Jan 19, 2021
5b82534
Merge branch 'main' of https://github.com/allenai/allennlp into jacob…
jacobdanovitch Jan 19, 2021
4fb6604
moved master checks to checkpointer to accommodate deepspeed
jacobdanovitch Jan 20, 2021
e21fb1f
Merge branch 'main' of https://github.com/allenai/allennlp into jacob…
jacobdanovitch Feb 10, 2021
703843c
updating to 2.0
jacobdanovitch Feb 10, 2021
e7b8825
checking in sparse attention
jacobdanovitch Feb 18, 2021
3fc1835
merge resolution
jacobdanovitch Feb 18, 2021
7 changes: 7 additions & 0 deletions allennlp/training/__init__.py
@@ -8,3 +8,10 @@
    EpochCallback,
    TrackEpochCallback,
)
from allennlp.training.deepspeed import DeepspeedTrainer

# import warnings
# try:
#     from allennlp.training.deepspeed import DeepspeedTrainer
# except ImportError:
#     warnings.warn('Deepspeed plugin not installed. Ignoring.')
Member:
Is this leftover debug code?

Contributor Author:
Depends on how we want to include this. Based on my experience, I wouldn't recommend making deepspeed a required dependency. If we're doing the `pip install allennlp[deepspeed]` thing, this could be replaced/updated (not sure offhand how that gets handled but I can look for some examples).

Member:
If you don't mind doing the work making it optional, then let's make it optional.
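
A minimal sketch of what the optional import discussed above might look like, mirroring the commented-out guard in the diff; this is illustrative, not part of the PR:

```python
# Sketch only (not part of the diff): re-export DeepspeedTrainer when the
# optional dependency is installed, and warn instead of failing when it is not.
import warnings

try:
    from allennlp.training.deepspeed import DeepspeedTrainer  # noqa: F401
except ImportError:
    warnings.warn("Deepspeed plugin not installed. Ignoring.")
```

On the packaging side, the `pip install allennlp[deepspeed]` route would presumably map to an `extras_require` entry (e.g. `extras_require={"deepspeed": ["deepspeed"]}`) in setup.py; the PR as shown does not touch setup.py, so that part is an assumption.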

11 changes: 11 additions & 0 deletions allennlp/training/deepspeed/__init__.py
@@ -0,0 +1,11 @@
from allennlp.training.deepspeed.trainer import DeepspeedTrainer
from allennlp.training.deepspeed.optimizers import (
    FusedAdamOptimizer,
    DeepspeedCPUAdamOptimizer,
    FusedLambOptimizer,
)

try:
    from allennlp.training.deepspeed.sparse_transformer_embedder import SparseTransformerEmbedder
except ImportError:
    pass
62 changes: 62 additions & 0 deletions allennlp/training/deepspeed/config.py
@@ -0,0 +1,62 @@
from typing import Dict, Any
from enum import IntEnum
from allennlp.common import FromParams
from dataclasses import dataclass, asdict


@dataclass
class DeepspeedFP16Config(FromParams):
    enabled: bool = True
    loss_scale: float = 0.0
    initial_scale_power: int = 32
    loss_scale_window: int = 1000
    hysteresis: int = 2
    min_loss_scale: float = 1.0


@dataclass
class DeepspeedAMPConfig(FromParams):
    enabled: bool = False
    opt_level: str = "O1"
Member:
I thought AMP was dead and we now use things built directly into PyTorch?

Contributor Author:
Yeah, but it's a required install for deepspeed and you can use it there, so I thought I would keep it in for compatibility. It can be removed if need be.

Member:
Hmm. Surely in the next DeepSpeed version they will make it use PyTorch-native AMP. But if we need it for now, that's cool.


@dataclass
class DeepspeedOptimizerConfig(FromParams):
    type: str
    params: Dict[str, Any]


class DeepspeedZeROStage(IntEnum):
    DISABLED = 0
    OPTIMIZER = 1
    GRADIENT = 2


@dataclass
class DeepspeedZeROConfig(FromParams):
    stage: DeepspeedZeROStage = DeepspeedZeROStage.GRADIENT
    allgather_partitions: bool = True
    allgather_bucket_size: int = 500000000
    overlap_comm: bool = False
    reduce_scatter: bool = True
    reduce_bucket_size: int = 500000000
    contiguous_gradients: bool = False
    cpu_offload: bool = False


@dataclass
class DeepspeedConfig(FromParams):
    zero_optimization: DeepspeedZeROConfig
    fp16: DeepspeedFP16Config
    amp: DeepspeedAMPConfig = DeepspeedAMPConfig()
    optimizer: DeepspeedOptimizerConfig = None

    zero_allow_untested_optimizer: bool = True
    wall_clock_breakdown: bool = False

    def to_dict(self):
        return asdict(self)


@dataclass
class DeepspeedArgs(FromParams):
    local_rank: int
    deepspeed: bool = True
    deepspeed_mpi: bool = False
    deepspeed_config: str = None
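
An illustrative usage sketch (not from the PR) of the dataclasses above. The values are arbitrary, and exactly how the resulting dict is handed to the DeepSpeed engine (for example via `deepspeed.initialize`) is up to the trainer, so treat that last step as an assumption:

```python
# Illustrative only: build a config with ZeRO stage 1 plus fp16 and dump it to
# the plain dict form that DeepSpeed consumes.
from allennlp.training.deepspeed.config import (
    DeepspeedConfig,
    DeepspeedFP16Config,
    DeepspeedZeROConfig,
    DeepspeedZeROStage,
)

ds_config = DeepspeedConfig(
    zero_optimization=DeepspeedZeROConfig(stage=DeepspeedZeROStage.OPTIMIZER, cpu_offload=True),
    fp16=DeepspeedFP16Config(enabled=True, initial_scale_power=16),
)
config_dict = ds_config.to_dict()  # nested dict via dataclasses.asdict
```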
87 changes: 87 additions & 0 deletions allennlp/training/deepspeed/optimizers.py
@@ -0,0 +1,87 @@
from typing import List, Tuple, Dict, Any

import torch

from apex.optimizers.fused_adam import FusedAdam
from deepspeed.ops.adam import DeepSpeedCPUAdam
from deepspeed.ops.lamb import FusedLamb
from deepspeed.runtime.fp16.onebit_adam import OnebitAdam

from allennlp.training.optimizers import Optimizer, make_parameter_groups


@Optimizer.register("fused_adam")
class FusedAdamOptimizer(Optimizer, FusedAdam):
    def __init__(
        self,
        model_parameters: List[Tuple[str, torch.nn.Parameter]],
        parameter_groups: List[Tuple[List[str], Dict[str, Any]]] = None,
        lr: float = 0.001,
        betas: Tuple[float, float] = (0.9, 0.999),
        eps: float = 1e-08,
        weight_decay: float = 0.0,
        amsgrad: bool = False,
        bias_correction: bool = True,
        adam_w_mode: bool = True,
        set_grad_none: bool = True,
    ):
        super().__init__(
            params=make_parameter_groups(model_parameters, parameter_groups),
            lr=lr,
            betas=betas,
            eps=eps,
            weight_decay=weight_decay,
            amsgrad=amsgrad,
            bias_correction=bias_correction,
            adam_w_mode=adam_w_mode,
            set_grad_none=set_grad_none,
        )

# This does not currently work
Member:
Why not? If it doesn't work and is not necessary, can we remove it?

@Optimizer.register("cpu_adam")
class DeepspeedCPUAdamOptimizer(Optimizer, DeepSpeedCPUAdam):
def __init__(
self,
model_parameters: List[Tuple[str, torch.nn.Parameter]],
parameter_groups: List[Tuple[List[str], Dict[str, Any]]] = None,
lr: float = 0.001,
betas: Tuple[float, float] = (0.9, 0.999),
eps: float = 1e-08,
weight_decay: float = 0.0,
amsgrad: bool = False,
):
super().__init__(
model_params=make_parameter_groups(model_parameters, parameter_groups),
lr=lr,
betas=betas,
eps=eps,
weight_decay=weight_decay,
amsgrad=amsgrad
)

@Optimizer.register("fused_lamb")
class FusedLambOptimizer(Optimizer, FusedLamb):
def __init__(
self,
model_parameters: List[Tuple[str, torch.nn.Parameter]],
parameter_groups: List[Tuple[List[str], Dict[str, Any]]] = None,
lr: float = 0.001,
betas: Tuple[float, float] = (0.9, 0.999),
eps: float = 1e-08,
eps_inside_sqrt: bool = False,
weight_decay: float = 0.0,
amsgrad: bool = False,
max_grad_norm: float = 0.,
max_coeff: float = 10.0,
min_coeff: float = 0.01
):
super().__init__(
params=make_parameter_groups(model_parameters, parameter_groups),
lr=lr,
betas=betas,
eps=eps,
weight_decay=weight_decay,
amsgrad=amsgrad,
max_grad_norm=max_grad_norm,
max_coeff=max_coeff,
min_coeff=min_coeff,
)
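
Since these classes register themselves with AllenNLP's `Optimizer` registry, they should be selectable by name like any built-in optimizer. A hedged sketch, assuming a `model` already exists and that apex and deepspeed are installed; in a jsonnet config the equivalent would be `"type": "fused_adam"` under `trainer.optimizer`:

```python
# Illustrative only: construct the registered fused_adam optimizer from params,
# the same way AllenNLP builds its built-in optimizers.
from allennlp.common import Params
from allennlp.training.optimizers import Optimizer

optimizer = Optimizer.from_params(
    model_parameters=list(model.named_parameters()),  # `model` is assumed to exist
    params=Params({"type": "fused_adam", "lr": 1e-4}),
)
```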
10 changes: 10 additions & 0 deletions allennlp/training/deepspeed/sparse_transformer_embedder.py
@@ -0,0 +1,10 @@
from allennlp.modules.token_embedders.token_embedder import TokenEmbedder
from allennlp.modules.token_embedders.pretrained_transformer_embedder import PretrainedTransformerEmbedder

from deepspeed.ops.sparse_attention.sparse_attention_utils import SparseAttentionUtils


@TokenEmbedder.register("sparse_transformer")
class SparseTransformerEmbedder(PretrainedTransformerEmbedder):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.transformer_model = SparseAttentionUtils.replace_model_self_attention_with_sparse_self_attention(
            self.transformer_model
        )
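
Illustrative usage (not in the diff): because `**kwargs` are forwarded straight to `PretrainedTransformerEmbedder`, the embedder can be configured like the dense one, e.g. with `"type": "sparse_transformer"` in the `token_embedders` section of a config. Whether a given architecture is supported by DeepSpeed's `SparseAttentionUtils` is an assumption here, so this is only a sketch:

```python
# Illustrative only: drop-in replacement for the dense pretrained transformer
# embedder; the model name is an arbitrary example.
embedder = SparseTransformerEmbedder(model_name="bert-base-uncased")
```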