Skip to content

Commit

Permalink
Always use the local rank zero imports (#16178)
Browse files Browse the repository at this point in the history
  • Loading branch information
carmocca authored Dec 24, 2022
1 parent 574a951 commit 612d43e
Show file tree
Hide file tree
Showing 27 changed files with 30 additions and 42 deletions.
3 changes: 2 additions & 1 deletion src/lightning_lite/lite.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,12 @@
import torch.nn as nn
from lightning_utilities.core.apply_func import apply_to_collection
from lightning_utilities.core.overrides import is_overridden
from lightning_utilities.core.rank_zero import rank_zero_warn
from torch import Tensor
from torch.optim import Optimizer
from torch.utils.data import BatchSampler, DataLoader, DistributedSampler, RandomSampler, SequentialSampler

from lightning_lite.utilities.rank_zero import rank_zero_warn

from lightning_lite.plugins import Precision # avoid circular imports: # isort: split
from lightning_lite.accelerators.accelerator import Accelerator
from lightning_lite.connector import _Connector, _PLUGIN_INPUT, _PRECISION_INPUT
Expand Down
3 changes: 1 addition & 2 deletions src/lightning_lite/plugins/environments/slurm.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,9 @@
import sys
from typing import Optional

from lightning_utilities.core.rank_zero import rank_zero_warn

from lightning_lite.plugins.environments.cluster_environment import ClusterEnvironment
from lightning_lite.utilities.imports import _IS_WINDOWS
from lightning_lite.utilities.rank_zero import rank_zero_warn
from lightning_lite.utilities.warnings import PossibleUserWarning

log = logging.getLogger(__name__)
Expand Down
3 changes: 1 addition & 2 deletions src/lightning_lite/strategies/deepspeed.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@

import torch
from lightning_utilities.core.imports import RequirementCache
from lightning_utilities.core.rank_zero import rank_zero_only
from torch.nn import Module
from torch.optim import Optimizer

Expand All @@ -33,7 +32,7 @@
from lightning_lite.strategies.strategy import _Sharded
from lightning_lite.utilities.distributed import log
from lightning_lite.utilities.enums import PrecisionType
from lightning_lite.utilities.rank_zero import rank_zero_info
from lightning_lite.utilities.rank_zero import rank_zero_info, rank_zero_only
from lightning_lite.utilities.seed import reset_seed
from lightning_lite.utilities.types import _PATH

Expand Down
2 changes: 2 additions & 0 deletions src/lightning_lite/utilities/rank_zero.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,13 @@

# note: we want to keep these indirections so the `rank_zero_only.rank` is set on import
from lightning_utilities.core.rank_zero import ( # noqa: F401
rank_prefixed_message,
rank_zero_debug,
rank_zero_deprecation,
rank_zero_info,
rank_zero_only,
rank_zero_warn,
WarningCache,
)

import lightning_lite
Expand Down
3 changes: 1 addition & 2 deletions src/lightning_lite/utilities/seed.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,8 @@

import numpy as np
import torch
from lightning_utilities.core.rank_zero import rank_prefixed_message

from lightning_lite.utilities.rank_zero import _get_rank, rank_zero_only, rank_zero_warn
from lightning_lite.utilities.rank_zero import _get_rank, rank_prefixed_message, rank_zero_only, rank_zero_warn

log = logging.getLogger(__name__)

Expand Down
2 changes: 1 addition & 1 deletion src/pytorch_lightning/accelerators/accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,11 @@
from typing import Any, Dict

import torch
from lightning_utilities.core.rank_zero import rank_zero_deprecation

import pytorch_lightning as pl
from lightning_lite.accelerators.accelerator import Accelerator as _Accelerator
from lightning_lite.utilities.types import _DEVICE
from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation


class Accelerator(_Accelerator, ABC):
Expand Down
3 changes: 1 addition & 2 deletions src/pytorch_lightning/callbacks/early_stopping.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,13 @@

import numpy as np
import torch
from lightning_utilities.core.rank_zero import rank_prefixed_message
from torch import Tensor

import pytorch_lightning as pl
from lightning_lite.utilities.rank_zero import _get_rank
from pytorch_lightning.callbacks.callback import Callback
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.rank_zero import rank_zero_warn
from pytorch_lightning.utilities.rank_zero import rank_prefixed_message, rank_zero_warn

log = logging.getLogger(__name__)

Expand Down
3 changes: 1 addition & 2 deletions src/pytorch_lightning/callbacks/model_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,15 +31,14 @@
import numpy as np
import torch
import yaml
from lightning_utilities.core.rank_zero import WarningCache
from torch import Tensor

import pytorch_lightning as pl
from lightning_lite.utilities.cloud_io import get_filesystem
from lightning_lite.utilities.types import _PATH
from pytorch_lightning.callbacks import Checkpoint
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.rank_zero import rank_zero_info, rank_zero_warn
from pytorch_lightning.utilities.rank_zero import rank_zero_info, rank_zero_warn, WarningCache
from pytorch_lightning.utilities.types import STEP_OUTPUT

log = logging.getLogger(__name__)
Expand Down
3 changes: 1 addition & 2 deletions src/pytorch_lightning/core/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@

import torch
from lightning_utilities.core.apply_func import apply_to_collection
from lightning_utilities.core.rank_zero import WarningCache
from torch import ScriptModule, Tensor
from torch.nn import Module
from torch.optim.optimizer import Optimizer
Expand All @@ -49,7 +48,7 @@
from pytorch_lightning.utilities import GradClipAlgorithmType
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_13
from pytorch_lightning.utilities.rank_zero import rank_zero_debug, rank_zero_warn
from pytorch_lightning.utilities.rank_zero import rank_zero_debug, rank_zero_warn, WarningCache
from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature
from pytorch_lightning.utilities.types import (
_METRIC_COLLECTION,
Expand Down
3 changes: 1 addition & 2 deletions src/pytorch_lightning/lite/lite.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@
from abc import ABC
from typing import List, Optional, Tuple, Union

from lightning_utilities.core.rank_zero import rank_zero_deprecation, rank_zero_warn

from lightning_lite.connector import _PLUGIN_INPUT as _LITE_PLUGIN_INPUT
from lightning_lite.connector import _PRECISION_INPUT
from lightning_lite.lite import LightningLite as _NewLightningLite
Expand Down Expand Up @@ -52,6 +50,7 @@
from pytorch_lightning.strategies import SingleTPUStrategy as PLSingleTPUStrategy
from pytorch_lightning.strategies import Strategy as PLStrategy
from pytorch_lightning.strategies import TPUSpawnStrategy as PLTPUSpawnStrategy
from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_warn

_PL_PLUGIN = Union[PLPrecisionPlugin, ClusterEnvironment, CheckpointIO]
_PL_PLUGIN_INPUT = Union[_PL_PLUGIN, str]
Expand Down
2 changes: 1 addition & 1 deletion src/pytorch_lightning/loops/epoch/prediction_epoch_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@
from typing import Any, Dict, Iterator, List, Tuple

import torch
from lightning_utilities.core.rank_zero import WarningCache

from lightning_lite.utilities import move_data_to_device
from pytorch_lightning.loops.loop import Loop
from pytorch_lightning.overrides.distributed import IndexBatchSamplerWrapper
from pytorch_lightning.trainer.progress import Progress
from pytorch_lightning.utilities.rank_zero import WarningCache

warning_cache = WarningCache()

Expand Down
3 changes: 1 addition & 2 deletions src/pytorch_lightning/loops/epoch/training_epoch_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
import numpy as np
import torch
from lightning_utilities.core.apply_func import apply_to_collection
from lightning_utilities.core.rank_zero import WarningCache

import pytorch_lightning as pl
from pytorch_lightning import loops # import as loops to avoid circular imports
Expand All @@ -32,7 +31,7 @@
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.fetching import AbstractDataFetcher, DataLoaderIterDataFetcher
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.rank_zero import rank_zero_warn
from pytorch_lightning.utilities.rank_zero import rank_zero_warn, WarningCache
from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature

_OUTPUTS_TYPE = List[_BATCH_OUTPUTS_TYPE]
Expand Down
3 changes: 1 addition & 2 deletions src/pytorch_lightning/loops/optimization/optimizer_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import torch
from lightning_utilities.core.rank_zero import WarningCache
from torch import Tensor
from torch.optim import Optimizer
from typing_extensions import OrderedDict
Expand All @@ -34,7 +33,7 @@
from pytorch_lightning.plugins.precision.native_amp import MixedPrecisionPlugin
from pytorch_lightning.trainer.progress import OptimizationProgress
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation
from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, WarningCache
from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature
from pytorch_lightning.utilities.types import STEP_OUTPUT

Expand Down
2 changes: 1 addition & 1 deletion src/pytorch_lightning/plugins/precision/colossalai.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,14 @@
# limitations under the License.
from typing import Any, Callable, Optional, Union

from lightning_utilities.core.rank_zero import WarningCache
from torch import Tensor
from torch.optim import Optimizer

import pytorch_lightning as pl
from lightning_lite.utilities.types import Steppable
from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin
from pytorch_lightning.utilities.enums import PrecisionType
from pytorch_lightning.utilities.rank_zero import WarningCache

warning_cache = WarningCache()

Expand Down
3 changes: 1 addition & 2 deletions src/pytorch_lightning/plugins/precision/deepspeed.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
from typing import Any, Callable, Optional, TYPE_CHECKING, Union

from lightning_utilities.core.imports import RequirementCache
from lightning_utilities.core.rank_zero import WarningCache
from torch import Tensor
from torch.optim import LBFGS, Optimizer

Expand All @@ -26,7 +25,7 @@
from pytorch_lightning.utilities import GradClipAlgorithmType
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation
from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, WarningCache

_DEEPSPEED_AVAILABLE = RequirementCache("deepspeed")
if TYPE_CHECKING and _DEEPSPEED_AVAILABLE:
Expand Down
2 changes: 1 addition & 1 deletion src/pytorch_lightning/plugins/precision/ipu.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
# limitations under the License.
from typing import Any, Callable, Union

from lightning_utilities.core.rank_zero import WarningCache
from torch import Tensor
from torch.optim import LBFGS, Optimizer

Expand All @@ -24,6 +23,7 @@
from pytorch_lightning.utilities import GradClipAlgorithmType
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.rank_zero import WarningCache

warning_cache = WarningCache()

Expand Down
3 changes: 1 addition & 2 deletions src/pytorch_lightning/profilers/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,14 @@
from typing import Any, Callable, ContextManager, Dict, List, Optional, Type, TYPE_CHECKING, Union

import torch
from lightning_utilities.core.rank_zero import WarningCache
from torch import nn, Tensor
from torch.autograd.profiler import record_function

from lightning_lite.accelerators.cuda import is_cuda_available
from pytorch_lightning.profilers.profiler import Profiler
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _KINETO_AVAILABLE
from pytorch_lightning.utilities.rank_zero import rank_zero_warn
from pytorch_lightning.utilities.rank_zero import rank_zero_warn, WarningCache

if TYPE_CHECKING:
from torch.autograd.profiler import EventList
Expand Down
2 changes: 1 addition & 1 deletion src/pytorch_lightning/strategies/colossalai.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@

import torch
from lightning_utilities.core.imports import RequirementCache
from lightning_utilities.core.rank_zero import rank_zero_warn
from torch import Tensor
from torch.nn import Module
from torch.optim.optimizer import Optimizer
Expand All @@ -35,6 +34,7 @@
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities.enums import PrecisionType
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.rank_zero import rank_zero_warn
from pytorch_lightning.utilities.types import STEP_OUTPUT

_COLOSSALAI_AVAILABLE = RequirementCache("colossalai")
Expand Down
3 changes: 1 addition & 2 deletions src/pytorch_lightning/strategies/deepspeed.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
import torch
from lightning_utilities.core.apply_func import apply_to_collection
from lightning_utilities.core.imports import RequirementCache
from lightning_utilities.core.rank_zero import WarningCache
from torch import Tensor
from torch.nn import Module
from torch.optim import Optimizer
Expand All @@ -45,7 +44,7 @@
from pytorch_lightning.utilities import GradClipAlgorithmType
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.rank_zero import rank_zero_info, rank_zero_only, rank_zero_warn
from pytorch_lightning.utilities.rank_zero import rank_zero_info, rank_zero_only, rank_zero_warn, WarningCache
from pytorch_lightning.utilities.types import LRSchedulerConfig, STEP_OUTPUT

log = logging.getLogger(__name__)
Expand Down
3 changes: 1 addition & 2 deletions src/pytorch_lightning/trainer/connectors/data_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
from weakref import proxy

from lightning_utilities.core.apply_func import apply_to_collection
from lightning_utilities.core.rank_zero import WarningCache
from torch.utils.data import BatchSampler, DataLoader, Sampler, SequentialSampler
from torch.utils.data.distributed import DistributedSampler

Expand All @@ -35,7 +34,7 @@
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _fault_tolerant_training
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.rank_zero import rank_zero_warn
from pytorch_lightning.utilities.rank_zero import rank_zero_warn, WarningCache
from pytorch_lightning.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS
from pytorch_lightning.utilities.warnings import PossibleUserWarning

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@

import torch
from lightning_utilities.core.apply_func import apply_to_collection, apply_to_collections
from lightning_utilities.core.rank_zero import WarningCache
from torch import Tensor
from torchmetrics import Metric
from typing_extensions import TypedDict
Expand All @@ -30,7 +29,7 @@
from pytorch_lightning.utilities.imports import _fault_tolerant_training
from pytorch_lightning.utilities.memory import recursive_detach
from pytorch_lightning.utilities.metrics import metrics_to_scalars
from pytorch_lightning.utilities.rank_zero import rank_zero_warn
from pytorch_lightning.utilities.rank_zero import rank_zero_warn, WarningCache
from pytorch_lightning.utilities.warnings import PossibleUserWarning

_IN_METRIC = Union[Metric, Tensor] # Do not include scalars as they were converted to tensors
Expand Down
3 changes: 1 addition & 2 deletions src/pytorch_lightning/trainer/states.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,9 @@
from enum import Enum, EnumMeta
from typing import Any, List, Optional

from lightning_utilities.core.rank_zero import rank_zero_deprecation

from pytorch_lightning.utilities import LightningEnum
from pytorch_lightning.utilities.enums import _FaultTolerantMode
from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation


class _DeprecationManagingEnumMeta(EnumMeta):
Expand Down
3 changes: 1 addition & 2 deletions src/pytorch_lightning/utilities/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@

import torch
from lightning_utilities.core.apply_func import is_dataclass_instance
from lightning_utilities.core.rank_zero import WarningCache
from torch import Tensor
from torch.utils.data import (
BatchSampler,
Expand All @@ -39,7 +38,7 @@
from pytorch_lightning.utilities.auto_restart import CaptureIterableDataset, CaptureMapDataset, FastForwardSampler
from pytorch_lightning.utilities.enums import _FaultTolerantMode
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_warn
from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_warn, WarningCache

# might be supported in later releases, see https://github.com/python/mypy/pull/13297
BType = Union[Tensor, str, Mapping[Any, "BType"], Iterable["BType"]] # type: ignore[misc]
Expand Down
3 changes: 1 addition & 2 deletions src/pytorch_lightning/utilities/migration/migration.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,10 @@
import re
from typing import Any, Callable, Dict, List

from lightning_utilities.core.rank_zero import rank_zero_warn

from lightning_lite.utilities.warnings import PossibleUserWarning
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from pytorch_lightning.utilities.rank_zero import rank_zero_warn

_CHECKPOINT = Dict[str, Any]

Expand Down
2 changes: 1 addition & 1 deletion src/pytorch_lightning/utilities/migration/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,14 @@
from types import ModuleType, TracebackType
from typing import Any, Dict, List, Optional, Tuple, Type

from lightning_utilities.core.rank_zero import rank_zero_warn
from packaging.version import Version

import pytorch_lightning as pl
from lightning_lite.utilities.imports import _IS_WINDOWS
from lightning_lite.utilities.types import _PATH
from lightning_lite.utilities.warnings import PossibleUserWarning
from pytorch_lightning.utilities.migration.migration import _migration_index
from pytorch_lightning.utilities.rank_zero import rank_zero_warn

_log = logging.getLogger(__name__)
_CHECKPOINT = Dict[str, Any]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,11 @@
import numpy as np
import torch
import torch.nn as nn
from lightning_utilities.core.rank_zero import WarningCache
from torch import Tensor
from torch.utils.hooks import RemovableHandle

import pytorch_lightning as pl
from pytorch_lightning.utilities.rank_zero import WarningCache

log = logging.getLogger(__name__)
warning_cache = WarningCache()
Expand Down
Loading

0 comments on commit 612d43e

Please sign in to comment.