Docstrings revamp (#589)
tesfaldet committed Jul 29, 2023
1 parent b987605 commit e70291d
Showing 20 changed files with 450 additions and 160 deletions.
30 changes: 27 additions & 3 deletions .pre-commit-config.yaml
@@ -40,10 +40,33 @@ repos:

# python docstring formatting
- repo: https://github.com/myint/docformatter
rev: v1.5.1
rev: v1.7.4
hooks:
- id: docformatter
args: [--in-place, --wrap-summaries=99, --wrap-descriptions=99]
args:
[
--in-place,
--wrap-summaries=99,
--wrap-descriptions=99,
--style=sphinx,
--black,
]
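
For reference, with `--style=sphinx` and `--black`, docformatter rewrites docstrings into a one-line summary followed by Sphinx-style field lists, wrapped at 99 characters to stay compatible with Black. A small sketch of the target shape (the function and its fields are illustrative, not taken from the repo):

```python
def scale(value: float, factor: float = 2.0) -> float:
    """Scale a value by a constant factor.

    :param value: The value to scale.
    :param factor: The multiplicative factor. Defaults to `2.0`.
    :return: The scaled value.
    """
    return value * factor
```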

# python docstring coverage checking
- repo: https://github.com/econchick/interrogate
rev: 1.5.0 # or master if you're bold
hooks:
- id: interrogate
args:
[
--verbose,
--fail-under=80,
--ignore-init-module,
--ignore-init-method,
--ignore-module,
--ignore-nested-functions,
-vv,
]
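
Interrogate measures docstring coverage rather than formatting: with `--fail-under=80`, the hook fails if fewer than 80% of the checked modules, classes, and functions carry a docstring (init modules/methods, module docstrings, and nested functions are excluded by the flags above). A hypothetical pair of functions showing what counts toward the score:

```python
def documented(x: int) -> int:
    """Return twice the input."""  # counted as covered
    return 2 * x


def undocumented(x: int) -> int:  # counted as missing; lowers the coverage percentage
    return 2 * x
```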

# python check (PEP8), programming errors and code complexity
- repo: https://github.com/PyCQA/flake8
@@ -53,10 +76,11 @@ repos:
args:
[
"--extend-ignore",
"E203,E402,E501,F401,F841",
"E203,E402,E501,F401,F841,RST2,RST301",
"--exclude",
"logs/*,data/*",
]
additional_dependencies: [flake8-rst-docstrings==0.3.0]

# python security linter
- repo: https://github.com/PyCQA/bandit
117 changes: 89 additions & 28 deletions src/data/mnist_datamodule.py
@@ -8,24 +8,42 @@


class MNISTDataModule(LightningDataModule):
"""Example of LightningDataModule for MNIST dataset.
"""`LightningDataModule` for the MNIST dataset.
A DataModule implements 6 key methods:
The MNIST database of handwritten digits has a training set of 60,000 examples, and a test set of 10,000 examples.
It is a subset of a larger set available from NIST. The digits have been size-normalized and centered in a
fixed-size image. The original black and white images from NIST were size normalized to fit in a 20x20 pixel box
while preserving their aspect ratio. The resulting images contain grey levels as a result of the anti-aliasing
technique used by the normalization algorithm. The images were centered in a 28x28 image by computing the center of
mass of the pixels, and translating the image so as to position this point at the center of the 28x28 field.
A `LightningDataModule` implements 7 key methods:
```python
def prepare_data(self):
# things to do on 1 GPU/TPU (not on every GPU/TPU in DDP)
# download data, pre-process, split, save to disk, etc...
# Things to do on 1 GPU/TPU (not on every GPU/TPU in DDP).
# Download data, pre-process, split, save to disk, etc...
def setup(self, stage):
# things to do on every process in DDP
# load data, set variables, etc...
# Things to do on every process in DDP.
# Load data, set variables, etc...
def train_dataloader(self):
# return train dataloader
# return train dataloader
def val_dataloader(self):
# return validation dataloader
# return validation dataloader
def test_dataloader(self):
# return test dataloader
def teardown(self):
# called on every process in DDP
# clean up after fit or test
# return test dataloader
def predict_dataloader(self):
# return predict dataloader
def teardown(self, stage):
# Called on every process in DDP.
# Clean up after fit or test.
```
This allows you to share a full dataset without explaining how to download,
split, transform and process the data.
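
As a usage sketch (assuming the defaults above, write access to `data/`, and network access for the initial MNIST download), the hooks documented in this class are exercised like so:

```python
dm = MNISTDataModule(batch_size=64)
dm.prepare_data()                  # downloads MNIST into `data/` if needed
dm.setup(stage="fit")              # builds the 55_000 / 5_000 / 10_000 splits
images, labels = next(iter(dm.train_dataloader()))
print(images.shape)                # torch.Size([64, 1, 28, 28])
print(dm.num_classes)              # 10
```
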
@@ -41,7 +59,15 @@ def __init__(
batch_size: int = 64,
num_workers: int = 0,
pin_memory: bool = False,
):
) -> None:
"""Initialize a `MNISTDataModule`.
:param data_dir: The data directory. Defaults to `"data/"`.
:param train_val_test_split: The train, validation and test split. Defaults to `(55_000, 5_000, 10_000)`.
:param batch_size: The batch size. Defaults to `64`.
:param num_workers: The number of workers. Defaults to `0`.
:param pin_memory: Whether to pin memory. Defaults to `False`.
"""
super().__init__()

# this line allows to access init params with 'self.hparams' attribute
@@ -58,22 +84,33 @@ def __init__(
self.data_test: Optional[Dataset] = None

@property
def num_classes(self):
def num_classes(self) -> int:
"""Get the number of classes.
:return: The number of MNIST classes (10).
"""
return 10

def prepare_data(self):
"""Download data if needed.
def prepare_data(self) -> None:
"""Download data if needed. Lightning ensures that `self.prepare_data()` is called only
within a single process on CPU, so you can safely add your downloading logic within. In
case of multi-node training, the execution of this hook depends upon
`self.prepare_data_per_node()`.
Do not use it to assign state (self.x = y).
"""
MNIST(self.hparams.data_dir, train=True, download=True)
MNIST(self.hparams.data_dir, train=False, download=True)

def setup(self, stage: Optional[str] = None):
def setup(self, stage: Optional[str] = None) -> None:
"""Load data. Set variables: `self.data_train`, `self.data_val`, `self.data_test`.
This method is called by lightning with both `trainer.fit()` and `trainer.test()`, so be
careful not to execute things like random split twice!
This method is called by Lightning before `trainer.fit()`, `trainer.validate()`, `trainer.test()`, and
`trainer.predict()`, so be careful not to execute things like random split twice! Also, it is called after
`self.prepare_data()` and there is a barrier in between which ensures that all the processes proceed to
`self.setup()` once the data is prepared and available for use.
:param stage: The stage to set up. Either `"fit"`, `"validate"`, `"test"`, or `"predict"`. Defaults to ``None``.
"""
# load and split datasets only if not loaded already
if not self.data_train and not self.data_val and not self.data_test:
@@ -86,7 +123,11 @@ def setup(self, stage: Optional[str] = None):
generator=torch.Generator().manual_seed(42),
)

def train_dataloader(self):
def train_dataloader(self) -> DataLoader[Any]:
"""Create and return the train dataloader.
:return: The train dataloader.
"""
return DataLoader(
dataset=self.data_train,
batch_size=self.hparams.batch_size,
@@ -95,7 +136,11 @@ def train_dataloader(self):
shuffle=True,
)

def val_dataloader(self):
def val_dataloader(self) -> DataLoader[Any]:
"""Create and return the validation dataloader.
:return: The validation dataloader.
"""
return DataLoader(
dataset=self.data_val,
batch_size=self.hparams.batch_size,
@@ -104,7 +149,11 @@ def val_dataloader(self):
shuffle=False,
)

def test_dataloader(self):
def test_dataloader(self) -> DataLoader[Any]:
"""Create and return the test dataloader.
:return: The test dataloader.
"""
return DataLoader(
dataset=self.data_test,
batch_size=self.hparams.batch_size,
@@ -113,16 +162,28 @@ def test_dataloader(self):
shuffle=False,
)

def teardown(self, stage: Optional[str] = None):
"""Clean up after fit or test."""
def teardown(self, stage: Optional[str] = None) -> None:
"""Lightning hook for cleaning up after `trainer.fit()`, `trainer.validate()`,
`trainer.test()`, and `trainer.predict()`.
:param stage: The stage being torn down. Either `"fit"`, `"validate"`, `"test"`, or `"predict"`.
Defaults to ``None``.
"""
pass

def state_dict(self):
"""Extra things to save to checkpoint."""
def state_dict(self) -> Dict[Any, Any]:
"""Called when saving a checkpoint. Implement to generate and save the datamodule state.
:return: A dictionary containing the datamodule state that you want to save.
"""
return {}

def load_state_dict(self, state_dict: Dict[str, Any]):
"""Things to do when loading checkpoint."""
def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
"""Called when loading a checkpoint. Implement to reload datamodule state given datamodule
`state_dict()`.
:param state_dict: The datamodule state returned by `self.state_dict()`.
"""
pass
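
Both hooks are intentionally no-ops here. A hypothetical override (the `batch_size_per_device` attribute is illustrative, not part of this module) could persist runtime state across checkpoints:

```python
def state_dict(self) -> Dict[Any, Any]:
    """Save datamodule state, e.g. a batch size adjusted at runtime."""
    return {"batch_size": self.batch_size_per_device}

def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
    """Restore the state produced by `state_dict()`."""
    self.batch_size_per_device = state_dict["batch_size"]
```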


16 changes: 8 additions & 8 deletions src/eval.py
@@ -1,4 +1,4 @@
from typing import List, Tuple
from typing import Any, Dict, List, Tuple

import hydra
import pyrootutils
@@ -30,19 +30,15 @@


@utils.task_wrapper
def evaluate(cfg: DictConfig) -> Tuple[dict, dict]:
def evaluate(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
"""Evaluates given checkpoint on a datamodule testset.
This method is wrapped in an optional @task_wrapper decorator, which controls the behavior during
failure. Useful for multiruns, saving info about the crash, etc.
Args:
cfg (DictConfig): Configuration composed by Hydra.
Returns:
Tuple[dict, dict]: Dict with metrics and dict with all instantiated objects.
:param cfg: DictConfig configuration composed by Hydra.
:return: Tuple[dict, dict] with metrics and dict with all instantiated objects.
"""

assert cfg.ckpt_path

log.info(f"Instantiating datamodule <{cfg.data._target_}>")
@@ -82,6 +78,10 @@ def evaluate(cfg: DictConfig) -> Tuple[dict, dict]:

@hydra.main(version_base="1.3", config_path="../configs", config_name="eval.yaml")
def main(cfg: DictConfig) -> None:
"""Main entry point for evaluation.
:param cfg: DictConfig configuration composed by Hydra.
"""
# apply extra utilities
# (e.g. ask for tags if none are provided in cfg, print cfg tree, etc.)
utils.extras(cfg)
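
For context, `utils.task_wrapper` referenced above decorates task functions such as `evaluate`. A simplified sketch of the idea (not the template's exact implementation; the logging calls stand in for helpers assumed to live in `src/utils`):

```python
import logging
from functools import wraps
from typing import Any, Callable, Dict, Tuple

from omegaconf import DictConfig

log = logging.getLogger(__name__)


def task_wrapper(task_func: Callable) -> Callable:
    """Wrap a task function so failures are logged and loggers are always closed (sketch)."""

    @wraps(task_func)
    def wrap(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        try:
            metric_dict, object_dict = task_func(cfg=cfg)
        except Exception as ex:
            log.exception("Task failed")  # keep the traceback around, useful for multiruns
            raise ex
        finally:
            log.info("Closing loggers...")  # the real helper also finishes wandb runs, etc.
        return metric_dict, object_dict

    return wrap
```
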
20 changes: 18 additions & 2 deletions src/models/components/simple_dense_net.py
@@ -1,15 +1,26 @@
import torch
from torch import nn


class SimpleDenseNet(nn.Module):
"""A simple fully-connected neural net for computing predictions."""

def __init__(
self,
input_size: int = 784,
lin1_size: int = 256,
lin2_size: int = 256,
lin3_size: int = 256,
output_size: int = 10,
):
) -> None:
"""Initialize a `SimpleDenseNet` module.
:param input_size: The number of input features.
:param lin1_size: The number of output features of the first linear layer.
:param lin2_size: The number of output features of the second linear layer.
:param lin3_size: The number of output features of the third linear layer.
:param output_size: The number of output features of the final linear layer.
"""
super().__init__()

self.model = nn.Sequential(
@@ -25,7 +36,12 @@ def __init__(
nn.Linear(lin3_size, output_size),
)

def forward(self, x):
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Perform a single forward pass through the network.
:param x: The input tensor.
:return: A tensor of predictions.
"""
batch_size, channels, width, height = x.size()

# (batch, 1, width, height) -> (batch, 1*width*height)
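
A quick sanity check of the network on an MNIST-sized batch (illustrative usage only):

```python
import torch

net = SimpleDenseNet()              # defaults: 784 -> 256 -> 256 -> 256 -> 10
dummy = torch.randn(64, 1, 28, 28)  # (batch, channels, width, height)
logits = net(dummy)                 # forward() flattens to (64, 784) before the linear layers
print(logits.shape)                 # torch.Size([64, 10])
```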