Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adjust formatting for SSL project code #39

Open
wants to merge 8 commits into
base: master
Choose a base branch
from
12 changes: 1 addition & 11 deletions medsegpy/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -758,7 +758,7 @@ def summary(self, additional_vars=None):
super().summary(summary_vars)


class ContextUNetConfig(Config):
class ContextUNetConfig(ContextEncoderConfig):
"""
Configuration for the ContextUNet model.

Expand All @@ -769,16 +769,6 @@ class ContextUNetConfig(Config):
"""

MODEL_NAME = "ContextUNet"
NUM_FILTERS = [[32, 32], [64, 64], [128, 128], [256, 256]]

def __init__(self, state="training", create_dirs=True):
super().__init__(self.MODEL_NAME, state, create_dirs=create_dirs)

def summary(self, additional_vars=None):
summary_vars = ["NUM_FILTERS"]
if additional_vars:
summary_vars.extend(additional_vars)
super().summary(summary_vars)


class ContextInpaintingConfig(ContextUNetConfig):
Expand Down
9 changes: 3 additions & 6 deletions medsegpy/cross_validation/cv_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,12 +147,9 @@ def init_cv_experiments(self, num_valid_bins=1, num_test_bins=1):

for i in range(len(temp)):
for j in range(i + 1, len(temp)):
assert (
len(set(temp[i]) & set(temp[j])) == 0
), "Test bins %d and %d not mutually exclusive - %d overlap" % (
i,
j,
len(set(temp[i]) & set(temp[j])),
assert len(set(temp[i]) & set(temp[j])) == 0, (
"Test bins %d and %d not mutually exclusive - %d overlap"
% (i, j, len(set(temp[i]) & set(temp[j])))
)

self.num_valid_bins = num_valid_bins
Expand Down
4 changes: 3 additions & 1 deletion medsegpy/data/data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import Sequence, Union

import numpy as np
from numba import njit


def collect_mask(mask: np.ndarray, index: Sequence[Union[int, Sequence[int], int]]):
Expand Down Expand Up @@ -195,7 +196,7 @@ def generate_poisson_disc_mask(
x /= x.max()
y = np.maximum(abs(y - img_shape[-2] / 2), 0)
y /= y.max()
r = np.sqrt(x**2 + y**2)
r = np.sqrt(x ** 2 + y ** 2)

# Quick checks
assert int(num_samples) == num_samples, (
Expand Down Expand Up @@ -233,6 +234,7 @@ def generate_poisson_disc_mask(
return mask, patch_mask


@njit
def _poisson(nx, ny, K, R, num_samples=None, patch_size=0.0, seed=None):
mask = np.zeros((ny, nx))
patch_mask = np.zeros((ny, nx))
Expand Down
2 changes: 2 additions & 0 deletions medsegpy/data/datasets/abct.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,4 +153,6 @@ def register_all_abct():
txt_file_or_scan_root = os.path.join(
Cluster.working_cluster().data_dir, txt_file_or_scan_root
)
if not os.path.exists(txt_file_or_scan_root):
continue
register_abct(dataset_name, txt_file_or_scan_root)
2 changes: 2 additions & 0 deletions medsegpy/data/datasets/oai.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,4 +176,6 @@ def register_all_oai():
for dataset_name, scan_root in _DATA_CATALOG.items():
if not os.path.isabs(scan_root):
scan_root = os.path.join(Cluster.working_cluster().data_dir, scan_root)
if not os.path.exists(scan_root):
continue
register_oai(dataset_name, scan_root)
6 changes: 6 additions & 0 deletions medsegpy/data/datasets/qdess_mri.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import re

from medsegpy.data.catalog import DatasetCatalog, MetadataCatalog
from medsegpy.utils.cluster import Cluster

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -161,6 +162,7 @@ def load_2d_from_filepaths(filepaths: list, source_path: str, dataset_name: str
corresponding ground truth segmentations.
total_num_slices: The total number of slices for this dataset.
dataset_name: The name of the dataset.

Returns:
dataset_dicts: A list of dictionaries, described above in the
docstring.
Expand Down Expand Up @@ -336,4 +338,8 @@ def register_all_qdess_datasets():
Registers all qDESS MRI datasets listed in _DATA_CATALOG.
"""
for dataset_name, scan_root in _DATA_CATALOG.items():
if not os.path.isabs(scan_root):
scan_root = os.path.join(Cluster.working_cluster().data_dir, scan_root)
if not os.path.exists(scan_root):
continue
register_qdess_dataset(scan_root=scan_root, dataset_name=dataset_name)
8 changes: 3 additions & 5 deletions medsegpy/data/im_gens.py
Original file line number Diff line number Diff line change
Expand Up @@ -1234,11 +1234,9 @@ def __validate_img_size__(self, total_volume_shape):
# this means shape of total volume must be perfectly divisible into
# cubes of size IMG_SIZE
for dim in range(3):
assert (
total_volume_shape[dim] % self.config.IMG_SIZE[dim] == 0
), "Cannot divide volume of size %s to blocks of size %s" % (
total_volume_shape,
self.config.IMG_SIZE,
assert total_volume_shape[dim] % self.config.IMG_SIZE[dim] == 0, (
"Cannot divide volume of size %s to blocks of size %s"
% (total_volume_shape, self.config.IMG_SIZE)
)

def img_generator_test(self, model=None):
Expand Down
5 changes: 5 additions & 0 deletions medsegpy/data/transforms/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ def apply_image(self, img: np.ndarray):
img (ndarray): of shape NxHxWxC, or HxWxC or HxW. The array can be
of type uint8 in range [0, 255], or floating point in range
[0, 1] or [0, 255].

Returns:
ndarray: image after apply the transformation.
"""
Expand Down Expand Up @@ -147,6 +148,7 @@ def _apply(self, x: _T, meth: str) -> _T:
Args:
x: input to apply the transform operations.
meth (str): meth.

Returns:
x: after apply the transformation.
"""
Expand All @@ -167,6 +169,7 @@ def __add__(self, other: "TransformList") -> "TransformList":
"""
Args:
other (TransformList): transformation to add.

Returns:
TransformList: list of transforms.
"""
Expand All @@ -177,6 +180,7 @@ def __iadd__(self, other: "TransformList") -> "TransformList":
"""
Args:
other (TransformList): transformation to add.

Returns:
TransformList: list of transforms.
"""
Expand All @@ -188,6 +192,7 @@ def __radd__(self, other: "TransformList") -> "TransformList":
"""
Args:
other (TransformList): transformation to add.

Returns:
TransformList: list of transforms.
"""
Expand Down
Loading