This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

Deepspeed integration #4693

Open: wants to merge 28 commits into base: main
28 commits
e2ac4b5  first draft of deepspeed trainer (jacobdanovitch, Oct 2, 2020)
619657e  delegating grad_clipping, grad_norm, grad_acculumation, etc. to deeps… (jacobdanovitch, Oct 2, 2020)
a329fd2  cleaning up deepspeed config interface (jacobdanovitch, Oct 2, 2020)
00666c2  idenifying bottleneck / start simplifying model engine (jacobdanovitch, Oct 7, 2020)
f0da3bf  1416 LOC -> 562 (jacobdanovitch, Oct 10, 2020)
d0e8a68  debugging memory leak (jacobdanovitch, Oct 28, 2020)
0a74573  functioning / cleaner prototype (jacobdanovitch, Oct 31, 2020)
eaf8aa5  Merge branch 'master' into jacobdanovitch/deepspeed (jacobdanovitch, Nov 2, 2020)
498d3a2  checkpointing works e2e (jacobdanovitch, Nov 2, 2020)
a211b5e  ready for review (jacobdanovitch, Nov 5, 2020)
3b30e21  Merge branch 'master' into jacobdanovitch/deepspeed (jacobdanovitch, Nov 5, 2020)
fdd888b  add new trainer/lazy changes (jacobdanovitch, Nov 5, 2020)
ef544c9  Merge branch 'master' into jacobdanovitch/deepspeed (jacobdanovitch, Nov 9, 2020)
083a6d0  dangling changes (jacobdanovitch, Nov 23, 2020)
0f8d5b7  Merge branch 'master' of https://github.com/allenai/allennlp into jac… (jacobdanovitch, Nov 23, 2020)
4e4f7d7  updating from master (jacobdanovitch, Nov 30, 2020)
f48ea19  typechecks passing! (jacobdanovitch, Nov 30, 2020)
b3328fc  init file (jacobdanovitch, Jan 3, 2021)
966e296  Merge remote-tracking branch 'upstream/main' into jacobdanovitch/deep… (jacobdanovitch, Jan 3, 2021)
2fdb7c0  save old tests in case (jacobdanovitch, Jan 8, 2021)
95a9e5f  tracking down dist barrier bug(s) (jacobdanovitch, Jan 8, 2021)
b152fe1  catch up (jacobdanovitch, Jan 19, 2021)
5b82534  Merge branch 'main' of https://github.com/allenai/allennlp into jacob… (jacobdanovitch, Jan 19, 2021)
4fb6604  moved master checks to checkpointer to accomodate deepspeed (jacobdanovitch, Jan 20, 2021)
e21fb1f  Merge branch 'main' of https://github.com/allenai/allennlp into jacob… (jacobdanovitch, Feb 10, 2021)
703843c  updating to 2.0 (jacobdanovitch, Feb 10, 2021)
e7b8825  checking in sparse attention (jacobdanovitch, Feb 18, 2021)
3fc1835  merge resolution (jacobdanovitch, Feb 18, 2021)
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -7,6 +7,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## Unreleased

- Added `DeepspeedTrainer` and `FusedLambOptimizer`.

### Fixed

- Ensured that `MeanAbsoluteError` always returns a `float` metric value instead of a `Tensor`.
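For context on the new CHANGELOG entry above: the trainer is exposed as an AllenNLP plugin (see the `DEFAULT_PLUGINS` change in `allennlp/common/plugins.py` below), so it should be selectable from a training config. The registered name `"deepspeed"` and the `deepspeed_config` parameter in this sketch are assumptions for illustration and are not confirmed by this diff; only the nested keys (`fp16`, `zero_optimization`) are standard DeepSpeed options.

```python
# Hypothetical sketch only: the trainer type "deepspeed" and the
# "deepspeed_config" key are assumed names, not taken from this diff.
trainer_params = {
    "type": "deepspeed",
    "optimizer": {"type": "adam", "lr": 2e-5},
    "num_epochs": 3,
    "deepspeed_config": {
        # Standard DeepSpeed options: mixed precision and ZeRO stage 2.
        "fp16": {"enabled": True},
        "zero_optimization": {"stage": 2},
    },
}
```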
4 changes: 3 additions & 1 deletion allennlp/commands/__init__.py

@@ -50,7 +50,9 @@ def add_argument(self, *args, **kwargs):
super().add_argument(*args, **kwargs)


def parse_args(prog: Optional[str] = None) -> Tuple[argparse.ArgumentParser, argparse.Namespace]:
def parse_args(
prog: Optional[str] = None,
) -> Tuple[argparse.ArgumentParser, argparse.Namespace]:
"""
Creates the argument parser for the main program and uses it to parse the args.
"""
20 changes: 15 additions & 5 deletions allennlp/commands/evaluate.py

@@ -27,17 +27,23 @@ class Evaluate(Subcommand):
def add_subparser(self, parser: argparse._SubParsersAction) -> argparse.ArgumentParser:
description = """Evaluate the specified model + dataset"""
subparser = parser.add_parser(
self.name, description=description, help="Evaluate the specified model + dataset."
self.name,
description=description,
help="Evaluate the specified model + dataset.",
)

subparser.add_argument("archive_file", type=str, help="path to an archived trained model")

subparser.add_argument(
"input_file", type=str, help="path to the file containing the evaluation data"
"input_file",
type=str,
help="path to the file containing the evaluation data",
)

subparser.add_argument(
"--output-file", type=str, help="optional path to write the metrics to as JSON"
"--output-file",
type=str,
help="optional path to write the metrics to as JSON",
)

subparser.add_argument(
@@ -47,7 +53,9 @@ def add_subparser(self, parser: argparse._SubParsersAction) -> argparse.Argument
)

subparser.add_argument(
"--weights-file", type=str, help="a path that overrides which weights file to use"
"--weights-file",
type=str,
help="a path that overrides which weights file to use",
)

cuda_device = subparser.add_mutually_exclusive_group(required=False)
@@ -68,7 +76,9 @@ def add_subparser(self, parser: argparse._SubParsersAction) -> argparse.Argument
)

subparser.add_argument(
"--batch-size", type=int, help="If non-empty, the batch size to use during evaluation."
"--batch-size",
type=int,
help="If non-empty, the batch size to use during evaluation.",
)

subparser.add_argument(
14 changes: 11 additions & 3 deletions allennlp/commands/find_learning_rate.py

@@ -39,7 +39,9 @@ def add_subparser(self, parser: argparse._SubParsersAction) -> argparse.Argument
)

subparser.add_argument(
"param_path", type=str, help="path to parameter file describing the model to be trained"
"param_path",
type=str,
help="path to parameter file describing the model to be trained",
)
subparser.add_argument(
"-s",
@@ -60,10 +62,16 @@ def add_subparser(self, parser: argparse._SubParsersAction) -> argparse.Argument
),
)
subparser.add_argument(
"--start-lr", type=float, default=1e-5, help="learning rate to start the search"
"--start-lr",
type=float,
default=1e-5,
help="learning rate to start the search",
)
subparser.add_argument(
"--end-lr", type=float, default=10, help="learning rate up to which search is done"
"--end-lr",
type=float,
default=10,
help="learning rate up to which search is done",
)
subparser.add_argument(
"--num-batches",
17 changes: 13 additions & 4 deletions allennlp/commands/predict.py

@@ -28,7 +28,9 @@ def add_subparser(self, parser: argparse._SubParsersAction) -> argparse.Argument

description = """Run the specified model against a JSON-lines input file."""
subparser = parser.add_parser(
self.name, description=description, help="Use a trained model to make predictions."
self.name,
description=description,
help="Use a trained model to make predictions.",
)

subparser.add_argument(
@@ -38,12 +40,17 @@ def add_subparser(self, parser: argparse._SubParsersAction) -> argparse.Argument

subparser.add_argument("--output-file", type=str, help="path to output file")
subparser.add_argument(
"--weights-file", type=str, help="a path that overrides which weights file to use"
"--weights-file",
type=str,
help="a path that overrides which weights file to use",
)

batch_size = subparser.add_mutually_exclusive_group(required=False)
batch_size.add_argument(
"--batch-size", type=int, default=1, help="The batch size to use for processing"
"--batch-size",
type=int,
default=1,
help="The batch size to use for processing",
)

subparser.add_argument(
@@ -86,7 +93,9 @@ def add_subparser(self, parser: argparse._SubParsersAction) -> argparse.Argument
)

subparser.add_argument(
"--predictor", type=str, help="optionally specify a specific predictor to use"
"--predictor",
type=str,
help="optionally specify a specific predictor to use",
)

subparser.add_argument(
5 changes: 4 additions & 1 deletion allennlp/commands/subcommand.py

@@ -39,7 +39,10 @@ def add_subparser(self, parser: argparse._SubParsersAction) -> argparse.Argument
@classmethod
@overrides
def register(
cls: Type[T], name: str, constructor: Optional[str] = None, exist_ok: bool = False
cls: Type[T],
name: str,
constructor: Optional[str] = None,
exist_ok: bool = False,
) -> Callable[[Type[T]], Type[T]]:
super_register_fn = super().register(name, constructor=constructor, exist_ok=exist_ok)
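The reformatted `register` signature above is the hook third-party subcommands use. A minimal usage sketch with an illustrative command name (the decorator pattern and `add_subparser` signature match the diffs in this PR; the command itself is hypothetical):

```python
import argparse

from allennlp.commands.subcommand import Subcommand


@Subcommand.register("count-instances")  # illustrative name, not part of this PR
class CountInstances(Subcommand):
    def add_subparser(self, parser: argparse._SubParsersAction) -> argparse.ArgumentParser:
        subparser = parser.add_parser(
            self.name,
            description="Count instances in a dataset.",
            help="Count instances in a dataset.",
        )
        # argparse dispatches to this callable when the subcommand is invoked.
        subparser.set_defaults(func=lambda args: print("counting..."))
        return subparser
```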

15 changes: 12 additions & 3 deletions allennlp/commands/train.py

@@ -24,7 +24,11 @@
from allennlp.common.plugins import import_plugins
from allennlp.data import DatasetReader, Vocabulary
from allennlp.data import DataLoader
from allennlp.models.archival import archive_model, CONFIG_NAME, verify_include_in_archive
from allennlp.models.archival import (
archive_model,
CONFIG_NAME,
verify_include_in_archive,
)
from allennlp.models.model import _DEFAULT_WEIGHTS, Model
from allennlp.training.trainer import Trainer
from allennlp.training import util as training_util
@@ -40,7 +44,9 @@ def add_subparser(self, parser: argparse._SubParsersAction) -> argparse.Argument
subparser = parser.add_parser(self.name, description=description, help="Train a model.")

subparser.add_argument(
"param_path", type=str, help="path to parameter file describing the model to be trained"
"param_path",
type=str,
help="path to parameter file describing the model to be trained",
)

subparser.add_argument(
@@ -80,7 +86,10 @@ def add_subparser(self, parser: argparse._SubParsersAction) -> argparse.Argument
)

subparser.add_argument(
"--node-rank", type=int, default=0, help="rank of this node in the distributed setup"
"--node-rank",
type=int,
default=0,
help="rank of this node in the distributed setup",
)

subparser.add_argument(
4 changes: 3 additions & 1 deletion allennlp/common/cached_transformers.py

@@ -66,7 +66,9 @@ def strip_prefix(s):
}
if len(valid_keys) > 0:
logger.info(
"Loading %d tensors from %s", len(valid_keys), override_weights_file
"Loading %d tensors from %s",
len(valid_keys),
override_weights_file,
)
else:
raise ValueError(
14 changes: 11 additions & 3 deletions allennlp/common/file_utils.py

@@ -652,7 +652,10 @@ class CacheFile:
"""

def __init__(
self, cache_filename: Union[PathLike, str], mode: str = "w+b", suffix: str = ".tmp"
self,
cache_filename: Union[PathLike, str],
mode: str = "w+b",
suffix: str = ".tmp",
) -> None:
self.cache_filename = (
cache_filename if isinstance(cache_filename, Path) else Path(cache_filename)
@@ -671,7 +674,9 @@ def __exit__(self, exc_type, exc_value, traceback):
if exc_value is None:
# Success.
logger.debug(
"Renaming temp file %s to cache at %s", self.temp_file.name, self.cache_filename
"Renaming temp file %s to cache at %s",
self.temp_file.name,
self.cache_filename,
)
# Rename the temp file to the actual cache filename.
os.replace(self.temp_file.name, self.cache_filename)
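A short usage sketch of the context manager above. It assumes `__enter__` hands back the temporary file handle, which is not visible in this hunk; what the hunk does show is that the temp file is atomically renamed to the cache path only on a clean exit.

```python
from allennlp.common.file_utils import CacheFile

# Assumption: the context manager yields the underlying temp file object.
with CacheFile("example.cache", mode="w+b", suffix=".tmp") as temp_file:
    temp_file.write(b"downloaded bytes")
# On success, os.replace has moved the temp file to "example.cache".
# On an exception, the rename shown above is skipped.
```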
@@ -922,7 +927,10 @@ def get_file_extension(path: str, dot=True, lower: bool = True):


def open_compressed(
filename: Union[str, PathLike], mode: str = "rt", encoding: Optional[str] = "UTF-8", **kwargs
filename: Union[str, PathLike],
mode: str = "rt",
encoding: Optional[str] = "UTF-8",
**kwargs,
):
if not isinstance(filename, str):
filename = str(filename)
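For readers unfamiliar with the helper whose signature was reformatted above, a minimal sketch of what an extension-based dispatch of this kind typically looks like. This is an illustration under that assumption, not the function's actual body.

```python
import bz2
import gzip
from os import PathLike
from typing import Optional, Union


def open_compressed_sketch(
    filename: Union[str, PathLike],
    mode: str = "rt",
    encoding: Optional[str] = "UTF-8",
    **kwargs,
):
    filename = str(filename)
    open_kwargs = dict(mode=mode, **kwargs)
    if "b" not in mode:
        # encoding only applies in text mode.
        open_kwargs["encoding"] = encoding
    if filename.endswith(".gz"):
        return gzip.open(filename, **open_kwargs)
    if filename.endswith(".bz2"):
        return bz2.open(filename, **open_kwargs)
    return open(filename, **open_kwargs)
```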
24 changes: 19 additions & 5 deletions allennlp/common/from_params.py

@@ -86,7 +86,9 @@ def is_base_registrable(cls) -> bool:
Checks whether this is a class that directly inherits from Registrable, or is a subclass of such
a class.
"""
from allennlp.common.registrable import Registrable # import here to avoid circular imports
from allennlp.common.registrable import (
Registrable,
) # import here to avoid circular imports

if not issubclass(cls, Registrable):
return False
@@ -148,7 +150,10 @@ def infer_params(
else:
super_parameters = {}

return {**super_parameters, **parameters} # Subclass parameters overwrite superclass ones
return {
**super_parameters,
**parameters,
} # Subclass parameters overwrite superclass ones


def create_kwargs(
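The comment on the merge in `infer_params` above relies on standard dict-unpacking semantics: when keys collide, the later entries win. A one-line check with illustrative values:

```python
super_parameters = {"dropout": 0.1, "hidden_dim": 128}
parameters = {"dropout": 0.5}

# Later unpacking wins, so the subclass value for "dropout" overrides the superclass one.
assert {**super_parameters, **parameters} == {"dropout": 0.5, "hidden_dim": 128}
```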
@@ -245,7 +250,12 @@ def create_extras(cls: Type[T], extras: Dict[str, Any]) -> Dict[str, Any]:


def pop_and_construct_arg(
class_name: str, argument_name: str, annotation: Type, default: Any, params: Params, **extras
class_name: str,
argument_name: str,
annotation: Type,
default: Any,
params: Params,
**extras,
) -> Any:
"""
Does the work of actually constructing an individual argument for
@@ -261,7 +271,9 @@ def pop_and_construct_arg(
`inspect.Parameter` object directly, so that we can handle `Union` types using recursion on
this method, trying the different annotation types in the union in turn.
"""
from allennlp.models.archival import load_archive # import here to avoid circular imports
from allennlp.models.archival import (
load_archive,
) # import here to avoid circular imports

# We used `argument_name` as the method argument to avoid conflicts with 'name' being a key in
# `extras`, which isn't _that_ unlikely. Now that we are inside the method, we can switch back
@@ -536,7 +548,9 @@ def from_params(
constructor (because you inspect `__init__`, but call `cls()`).
"""

from allennlp.common.registrable import Registrable # import here to avoid circular imports
from allennlp.common.registrable import (
Registrable,
) # import here to avoid circular imports

logger.debug(
f"instantiating class {cls} from params {getattr(params, 'params', params)} "
11 changes: 9 additions & 2 deletions allennlp/common/plugins.py

@@ -33,14 +33,21 @@
The global plugins file will be found here.
"""

DEFAULT_PLUGINS = ("allennlp_models", "allennlp_semparse", "allennlp_server")
DEFAULT_PLUGINS = (
"allennlp_models",
"allennlp_semparse",
"allennlp_server",
"allennlp.training.deepspeed",
)
"""
Default plugins do not need to be declared in a plugins file. They will always
be imported when they are installed in the current Python environment.
"""


def discover_file_plugins(plugins_filename: str = LOCAL_PLUGINS_FILENAME) -> Iterable[str]:
def discover_file_plugins(
plugins_filename: str = LOCAL_PLUGINS_FILENAME,
) -> Iterable[str]:
"""
Returns an iterable of the plugins found, declared within a file whose path is `plugins_filename`.
"""
22 changes: 17 additions & 5 deletions allennlp/common/testing/model_test_case.py

@@ -54,7 +54,9 @@ def set_up_model(
self.vocab = vocab
self.instances = instances
self.model = Model.from_params(
vocab=self.vocab, params=params["model"], serialization_dir=serialization_dir
vocab=self.vocab,
params=params["model"],
serialization_dir=serialization_dir,
)

# TODO(joelgrus) get rid of these
@@ -149,13 +151,17 @@ def ensure_model_can_train_save_and_load(

print("Reading with original model")
data_loader = DataLoader.from_params(
params=data_loader_params, reader=reader, data_path=params["validation_data_path"]
params=data_loader_params,
reader=reader,
data_path=params["validation_data_path"],
)
data_loader.index_with(model.vocab)

print("Reading with loaded model")
data_loader2 = DataLoader.from_params(
params=data_loader_params2, reader=reader, data_path=params["validation_data_path"]
params=data_loader_params2,
reader=reader,
data_path=params["validation_data_path"],
)
data_loader2.index_with(loaded_model.vocab)

@@ -193,7 +199,10 @@ def ensure_model_can_train_save_and_load(
# Both outputs should have the same keys and the values for these keys should be close.
for key in model_predictions.keys():
self.assert_fields_equal(
model_predictions[key], loaded_model_predictions[key], name=key, tolerance=tolerance
model_predictions[key],
loaded_model_predictions[key],
name=key,
tolerance=tolerance,
)

# Check loaded model's loss exists and we can compute gradients, for continuing training.
@@ -277,7 +286,10 @@ def assert_fields_equal(self, field1, field2, name: str, tolerance: float = 1e-6
assert field1.keys() == field2.keys()
for key in field1:
self.assert_fields_equal(
field1[key], field2[key], tolerance=tolerance, name=name + "." + str(key)
field1[key],
field2[key],
tolerance=tolerance,
name=name + "." + str(key),
)
elif isinstance(field1, (list, tuple)):
assert len(field1) == len(field2)
3 changes: 2 additions & 1 deletion allennlp/common/testing/test_case.py

@@ -23,7 +23,8 @@ class AllenNlpTestCase:

def setup_method(self):
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", level=logging.DEBUG
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
level=logging.DEBUG,
)
# Disabling some of the more verbose logging statements that typically aren't very helpful
# in tests.
5 changes: 4 additions & 1 deletion allennlp/common/util.py

@@ -466,7 +466,10 @@ def int_to_device(device: Union[int, torch.device]) -> torch.device:


def log_frozen_and_tunable_parameter_names(model: torch.nn.Module) -> None:
frozen_parameter_names, tunable_parameter_names = get_frozen_and_tunable_parameter_names(model)
(
frozen_parameter_names,
tunable_parameter_names,
) = get_frozen_and_tunable_parameter_names(model)

logger.info("The following parameters are Frozen (without gradient):")
for name in frozen_parameter_names:
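The tuple unpacking reformatted above suggests the shape of the helper it calls. A hedged sketch of how such a frozen/tunable split is usually computed; the real `get_frozen_and_tunable_parameter_names` may differ in detail.

```python
from typing import List, Tuple

import torch


def get_frozen_and_tunable_parameter_names_sketch(
    model: torch.nn.Module,
) -> Tuple[List[str], List[str]]:
    # Parameters with requires_grad=False are "frozen"; the rest are tunable.
    frozen = [name for name, p in model.named_parameters() if not p.requires_grad]
    tunable = [name for name, p in model.named_parameters() if p.requires_grad]
    return frozen, tunable
```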