From bd15413117c85d2e6acb6ed563c9c29c4f9528b9 Mon Sep 17 00:00:00 2001
From: Yasyf Mohamedali
Date: Thu, 23 Feb 2023 18:51:02 -0800
Subject: [PATCH] Remove deprecated `torch._six` imports (#2863)

* Remove deprecated `torch._six` imports

Closes #2845.

* Support older versions of PyTorch as well.

---------

Co-authored-by: Jeff Rasley
Co-authored-by: Olatunji Ruwase
---
 deepspeed/runtime/utils.py              | 6 +++++-
 deepspeed/runtime/zero/stage3.py        | 3 +--
 deepspeed/runtime/zero/stage_1_and_2.py | 2 +-
 3 files changed, 7 insertions(+), 4 deletions(-)

diff --git a/deepspeed/runtime/utils.py b/deepspeed/runtime/utils.py
index 3e2d454e6e63..30dad84b16d1 100755
--- a/deepspeed/runtime/utils.py
+++ b/deepspeed/runtime/utils.py
@@ -16,9 +16,13 @@
 from bisect import bisect_left
 
 import torch
-from torch._six import inf
 from deepspeed import comm as dist
 
+try:
+    from torch._six import inf as inf
+except ModuleNotFoundError:
+    from torch import inf as inf
+
 from deepspeed.utils import groups, logger
 from deepspeed.runtime.constants import PIPE_REPLICATED
 from numpy import prod
diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py
index d56898efa20f..084f29c27922 100644
--- a/deepspeed/runtime/zero/stage3.py
+++ b/deepspeed/runtime/zero/stage3.py
@@ -7,13 +7,12 @@
 import gc
 import collections
 from typing import Deque, Dict, Tuple
-from torch._six import inf
 
 from deepspeed.runtime import ZeROOptimizer
 from deepspeed.utils import logger
 from deepspeed.runtime.fp16.loss_scaler import LossScaler, DynamicLossScaler
 from deepspeed.runtime.comm.coalesced_collectives import reduce_scatter_coalesced
-from deepspeed.runtime.utils import get_global_norm, is_model_parallel_parameter
+from deepspeed.runtime.utils import inf, get_global_norm, is_model_parallel_parameter
 from deepspeed.runtime.zero.partition_parameters import *
 from deepspeed.runtime.zero.config import ZeroStageEnum
 from deepspeed.runtime.zero.offload_config import OffloadDeviceEnum
diff --git a/deepspeed/runtime/zero/stage_1_and_2.py b/deepspeed/runtime/zero/stage_1_and_2.py
index 0b9d9c5a6fef..8c980d6a75c1 100755
--- a/deepspeed/runtime/zero/stage_1_and_2.py
+++ b/deepspeed/runtime/zero/stage_1_and_2.py
@@ -5,7 +5,6 @@
 import torch
 import os
 from deepspeed import comm as dist
-from torch._six import inf
 from packaging import version as pkg_version
 from collections import OrderedDict
 
@@ -15,6 +14,7 @@
                                      get_global_norm,
                                      empty_cache,
                                      see_memory_usage,
+                                     inf,
                                      is_model_parallel_parameter,
                                      align_dense_tensors,
                                      all_gather_dp_groups)
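
For reference, the shim added in deepspeed/runtime/utils.py is the usual try/except fallback for an import that moved between PyTorch releases: `torch._six.inf` on older versions, `torch.inf` where `torch._six` no longer exists. The sketch below shows the same pattern in isolation; the `total_norm` helper and the sample tensors are purely illustrative (not DeepSpeed APIs) and only demonstrate how downstream code can keep using `inf` unchanged.

    # Minimal sketch of the fallback-import pattern used by this patch.
    # Assumption: on PyTorch releases where `torch._six` was removed, the
    # constant is available as `torch.inf`, so the except branch takes over.
    try:
        from torch._six import inf  # older PyTorch releases
    except ModuleNotFoundError:
        from torch import inf  # newer PyTorch releases

    import torch


    def total_norm(tensors, norm_type=2.0):
        # Hypothetical helper: treat norm_type == inf as a max-norm,
        # mirroring how gradient-norm code typically consumes `inf`.
        if norm_type == inf:
            return max(t.abs().max().item() for t in tensors)
        return sum(t.norm(norm_type).item() ** norm_type for t in tensors) ** (1.0 / norm_type)


    if __name__ == "__main__":
        # Max-norm over [1, 1, 1, 1] and [3, 3] is 3.0.
        print(total_norm([torch.ones(4), torch.full((2,), 3.0)], norm_type=inf))

Keeping the shim in deepspeed/runtime/utils.py and re-exporting `inf` from there, as the stage3.py and stage_1_and_2.py hunks do, keeps the version check in a single place instead of repeating the try/except in every consumer.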