Skip to content

Commit

Permalink
Merge pull request #98 from grok-ai/feature/lightning-import
Browse files Browse the repository at this point in the history
Import lightning.pytorch instead of pytorch_lightning
  • Loading branch information
lucmos authored Oct 12, 2023
2 parents e22c1a3 + aef8323 commit 9e25c94
Show file tree
Hide file tree
Showing 8 changed files with 14 additions and 14 deletions.
10 changes: 5 additions & 5 deletions {{ cookiecutter.repository_name }}/conf/train/default.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -24,23 +24,23 @@ monitor:
mode: 'min'

callbacks:
- _target_: pytorch_lightning.callbacks.EarlyStopping
- _target_: lightning.pytorch.callbacks.EarlyStopping
patience: 42
verbose: False
monitor: ${train.monitor.metric}
mode: ${train.monitor.mode}

- _target_: pytorch_lightning.callbacks.ModelCheckpoint
- _target_: lightning.pytorch.callbacks.ModelCheckpoint
save_top_k: 1
verbose: False
monitor: ${train.monitor.metric}
mode: ${train.monitor.mode}

- _target_: pytorch_lightning.callbacks.LearningRateMonitor
- _target_: lightning.pytorch.callbacks.LearningRateMonitor
logging_interval: "step"
log_momentum: False

- _target_: pytorch_lightning.callbacks.progress.tqdm_progress.TQDMProgressBar
- _target_: lightning.pytorch.callbacks.progress.tqdm_progress.TQDMProgressBar
refresh_rate: 20

logging:
Expand All @@ -49,7 +49,7 @@ logging:
source: true

logger:
_target_: pytorch_lightning.loggers.WandbLogger
_target_: lightning.pytorch.loggers.WandbLogger

project: ${core.project_name}
entity: null
Expand Down
2 changes: 1 addition & 1 deletion {{ cookiecutter.repository_name }}/setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ package_dir=
=src
packages=find:
install_requires =
nn-template-core==0.3.*
nn-template-core==0.4.*
anypy==0.0.*

# Add project specific dependencies
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
# thus the logging configuration defined in the __init__.py must be called before
# the lightning import otherwise it has no effect.
# See https://github.com/PyTorchLightning/pytorch-lightning/issues/1503
lightning_logger = logging.getLogger("pytorch_lightning")
lightning_logger = logging.getLogger("lightning.pytorch")
# Remove all handlers associated with the lightning logger.
for handler in lightning_logger.handlers[:]:
lightning_logger.removeHandler(handler)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
from typing import List, Mapping, Optional

import hydra
import lightning.pytorch as pl
import omegaconf
import pytorch_lightning as pl
from omegaconf import DictConfig
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.dataloader import default_collate
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
from typing import Any, Dict, Mapping, Optional, Sequence, Tuple, Union

import hydra
import lightning.pytorch as pl
import omegaconf
import pytorch_lightning as pl
import torch
import torch.nn.functional as F
import torchmetrics
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@
from typing import List, Optional

import hydra
import lightning.pytorch as pl
import omegaconf
import pytorch_lightning as pl
import torch
from lightning.pytorch import Callback
from omegaconf import DictConfig, ListConfig
from pytorch_lightning import Callback

from nn_core.callbacks import NNTemplateCore
from nn_core.common import PROJECT_ROOT
Expand Down
2 changes: 1 addition & 1 deletion {{ cookiecutter.repository_name }}/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@
import pytest
from hydra import compose, initialize
from hydra.core.hydra_config import HydraConfig
from lightning.pytorch import seed_everything
from omegaconf import DictConfig, OmegaConf, open_dict
from pytest import FixtureRequest, TempPathFactory
from pytorch_lightning import seed_everything

from nn_core.serialization import NNCheckpointIO

Expand Down
4 changes: 2 additions & 2 deletions {{ cookiecutter.repository_name }}/tests/test_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
from pathlib import Path
from typing import Any, Dict

from lightning.pytorch import LightningModule
from lightning.pytorch.core.saving import _load_state
from omegaconf import DictConfig, OmegaConf
from pytorch_lightning import LightningModule
from pytorch_lightning.core.saving import _load_state

from nn_core.serialization import NNCheckpointIO
from tests.conftest import load_checkpoint
Expand Down

0 comments on commit 9e25c94

Please sign in to comment.