Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add option to add random normal noise to Laplacian filtered image #96

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 62 additions & 15 deletions ashlar/reg.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import skimage.util
import skimage.util.dtype
import skimage.io
import skimage.filters
import skimage.exposure
import skimage.transform
import sklearn.linear_model
Expand Down Expand Up @@ -455,7 +456,8 @@ class EdgeAligner(object):

def __init__(
self, reader, channel=0, max_shift=15, false_positive_ratio=0.01,
randomize=False, filter_sigma=0.0, do_make_thumbnail=True, verbose=False
randomize=False, filter_sigma=0.0, add_noise=False,
do_make_thumbnail=True, verbose=False
):
self.channel = channel
self.reader = CachingReader(reader, self.channel)
Expand All @@ -466,6 +468,7 @@ def __init__(
self.false_positive_ratio = false_positive_ratio
self.randomize = randomize
self.filter_sigma = filter_sigma
self.add_noise=add_noise
self.do_make_thumbnail = do_make_thumbnail
self._cache = {}

Expand All @@ -474,6 +477,8 @@ def __init__(
def run(self):
self.make_thumbnail()
self.check_overlaps()
self.find_permutation_pairs()
self.compute_edge_amplitude()
self.compute_threshold()
self.register_all()
self.build_spanning_tree()
Expand Down Expand Up @@ -502,11 +507,7 @@ def check_overlaps(self):
elif any(failures):
warn_data("Some neighboring tiles have zero overlap.")

def compute_threshold(self):
# Compute error threshold for rejecting aligments. We generate a
# distribution of error scores for many known non-overlapping image
# regions and take a certain percentile as the maximum allowable error.
# The percentile becomes our accepted false-positive ratio.
def find_permutation_pairs(self):
edges = self.neighbors_graph.edges
num_tiles = self.metadata.num_images
# If not enough tiles overlap to matter, skip this whole thing.
Expand All @@ -518,16 +519,18 @@ def compute_threshold(self):
self.intersection(t1, t2).shape.min()
for t1, t2 in edges
])
w = widths.max()
self.permutation_width = widths.max()
w = self.permutation_width
max_offset = self.metadata.size[0] - w
# Number of possible pairs minus number of actual neighbor pairs.
num_distant_pairs = num_tiles * (num_tiles - 1) // 2 - len(edges)
# Reduce permutation count for small datasets -- there are fewer
# possible truly distinct strips with fewer tiles. The calculation here
# is just a heuristic, not rigorously derived.
n = 1000 if num_distant_pairs > 8 else (num_distant_pairs + 1) * 10
pairs = np.empty((n, 2), dtype=int)
offsets = np.empty((n, 2), dtype=int)

self.permutation_pairs = np.empty((n, 2), dtype=int)
self.permutation_offsets = np.empty((n, 2), dtype=int)
# Generate n random non-overlapping image strips. Strips are always
# horizontal, across the entire image width.
max_tries = 100
Expand Down Expand Up @@ -561,10 +564,49 @@ def compute_threshold(self):
else:
# Retries exhausted. This should be very rare.
warn_data(
"Could not find non-overlapping strips in {max_tries} tries"
f"Could not find non-overlapping strips in {max_tries} tries"
)
pairs[i] = t1, t2
offsets[i] = o1, o2
self.permutation_pairs[i] = t1, t2
self.permutation_offsets[i] = o1, o2

def compute_edge_amplitude(self):
    """Estimate the minimum edge amplitude used as the noise level.

    The edge amplitude of a strip is the value returned by
    ``utils.edge_amplitude`` for the current ``filter_sigma``. Strips whose
    edge amplitude is low tend to register with low error even though the
    result is generally not "correct", so the amplitudes of all permutation
    strips are collected and a cutoff is derived from their distribution.

    Reads ``self.permutation_pairs``, ``self.permutation_offsets`` and
    ``self.permutation_width``. Always sets ``self.min_edge_amplitude``
    (to 0 when ``self.add_noise`` is false); when noise is enabled it also
    stores the per-strip values in ``self.edge_amplitudes``.
    """
    if not self.add_noise:
        # Noise disabled: downstream code receives a zero noise factor.
        self.min_edge_amplitude = 0
        return

    strip_height = self.permutation_width
    tile_ids = self.permutation_pairs.flatten()
    strip_starts = self.permutation_offsets.flatten()
    # Every permutation pair contributes two strips.
    total = len(self.permutation_pairs) * 2
    amplitudes = np.empty(total)

    for idx, (tile, start) in enumerate(zip(tile_ids, strip_starts)):
        done = idx + 1
        # Refresh the progress line every 10th strip and on the last one.
        if self.verbose and (done % 10 == 0 or done == total):
            sys.stdout.write(
                '\r quantifying edge amplitude %d/%d' % (done, total)
            )
            sys.stdout.flush()
        tile_img = self.reader.read(tile, self.channel)
        strip = tile_img[start:start + strip_height, :]
        amplitudes[idx] = utils.edge_amplitude(strip, self.filter_sigma)

    if self.verbose:
        print()

    self.edge_amplitudes = amplitudes
    # Triangle threshold seems to work in our limited tests
    self.min_edge_amplitude = skimage.filters.threshold_triangle(amplitudes)

def compute_threshold(self):
# Compute error threshold for rejecting aligments. We generate a
# distribution of error scores for many known non-overlapping image
# regions and take a certain percentile as the maximum allowable error.
# The percentile becomes our accepted false-positive ratio.
pairs = self.permutation_pairs
offsets = self.permutation_offsets
w = self.permutation_width
n = len(self.permutation_pairs)
errors = np.empty(n)
for i, ((t1, t2), (offset1, offset2)) in enumerate(zip(pairs, offsets)):
if self.verbose and (i % 10 == 9 or i == n - 1):
Expand All @@ -574,7 +616,10 @@ def compute_threshold(self):
sys.stdout.flush()
img1 = self.reader.read(t1, self.channel)[offset1:offset1+w, :]
img2 = self.reader.read(t2, self.channel)[offset2:offset2+w, :]
_, errors[i] = utils.register(img1, img2, self.filter_sigma, upsample=1)
_, errors[i] = utils.register(
img1, img2, self.filter_sigma,
upsample=1, noise_factor=self.min_edge_amplitude
)
if self.verbose:
print()
self.errors_negative_sampled = errors
Expand Down Expand Up @@ -687,7 +732,7 @@ def register_pair(self, t1, t2):
# metric on these images. This should be even lower than the error
# computed above.
_, o1, o2 = self.overlap(key[0], key[1], shift=shift)
error = utils.nccw(o1, o2, self.filter_sigma)
error = utils.nccw(o1, o2, self.filter_sigma, self.min_edge_amplitude)
self._cache[key] = (shift, error)
if t1 > t2:
shift = -shift
Expand All @@ -702,7 +747,9 @@ def _register(self, t1, t2, min_size=0):
sx = 1 if p1[1] >= p2[1] else -1
sy = 1 if p1[0] >= p2[0] else -1
padding = its.padding * [sy, sx]
shift, error = utils.register(img1, img2, self.filter_sigma)
shift, error = utils.register(
img1, img2, self.filter_sigma, noise_factor=self.min_edge_amplitude
)
shift += padding
return shift, error

Expand Down
6 changes: 6 additions & 0 deletions ashlar/scripts/ashlar.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,10 @@ def main(argv=sys.argv):
'-m', '--maximum-shift', type=float, default=15, metavar='SHIFT',
help='maximum allowed per-tile corrective shift in microns'
)
parser.add_argument(
'--add-noise', default=False, action='store_true',
help=('add noise during stitching')
)
parser.add_argument(
'--filter-sigma', type=float, default=0.0, metavar='SIGMA',
help=('width in pixels of Gaussian filter to apply to images before'
Expand Down Expand Up @@ -157,6 +161,7 @@ def main(argv=sys.argv):
aligner_args['channel'] = args.align_channel
aligner_args['verbose'] = not args.quiet
aligner_args['max_shift'] = args.maximum_shift
aligner_args['add_noise'] = args.add_noise
aligner_args['filter_sigma'] = args.filter_sigma

mosaic_args = {}
Expand Down Expand Up @@ -212,6 +217,7 @@ def process_single(
reader = build_reader(filepaths[0], plate_well=plate_well)
process_axis_flip(reader, flip_x, flip_y)
ea_args = aligner_args.copy()
del aligner_args['add_noise']
if len(filepaths) == 1:
ea_args['do_make_thumbnail'] = False
edge_aligner = reg.EdgeAligner(reader, **ea_args)
Expand Down
135 changes: 135 additions & 0 deletions ashlar/test_apply_noise.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
import numpy as np
from ashlar import utils
import skimage.data
import skimage.filters
import skimage.transform

import matplotlib.pyplot as plt
import tqdm

def blob_noise(fraction, noise_sd=0, blob_img_seed=None):
    """Return a 100x100 float image of binary blobs plus Gaussian noise.

    Parameters
    ----------
    fraction
        Divided by ``100 * 100`` and passed to
        ``skimage.data.binary_blobs`` as ``volume_fraction``.
    noise_sd
        Standard deviation of the zero-mean normal noise added to the blobs.
    blob_img_seed
        Forwarded as the ``seed`` of the blob generator. The additive noise
        is always drawn from a fresh, unseeded generator.
    """
    side = 100
    blobs = skimage.data.binary_blobs(
        side,
        blob_size_fraction=5/100,
        volume_fraction=fraction / (side * side),
        seed=blob_img_seed,
    )
    gaussian = np.random.default_rng().normal(0, noise_sd, side * side)
    return blobs.astype(float) + gaussian.reshape(side, side)

# radial distortion
def radial_distort(xy, warp_center, k1=0.01, k2=0.002):
    """Apply a radial distortion to an array of 2-column coordinates.

    Coordinates are shifted by a distortion center, normalized by the mean
    coordinate (``half``), divided by ``1 + k1*r + k2*r**2`` where ``r`` is
    the normalized distance from the center, and mapped back.

    Parameters
    ----------
    xy
        Array-like of shape (N, 2). It is NOT modified; a float copy is
        used internally, so integer coordinate arrays are accepted too.
        NOTE(review): presumably (x, y) column order as passed by
        ``skimage.transform.warp`` -- confirm.
    warp_center
        One of ``'left'``, ``'center'``, ``'right'``. With ``half`` the
        per-column mean of ``xy``: ``'right'`` puts the center at
        ``[0, half[1]]``, ``'center'`` at ``half``, and ``'left'`` at
        ``[2*half[0], half[1]]``.
    k1, k2
        Linear and quadratic radial distortion coefficients.

    Returns
    -------
    numpy.ndarray
        New float array of distorted coordinates with the shape of ``xy``.

    Raises
    ------
    ValueError
        If ``warp_center`` is not one of the accepted names.
    """
    if warp_center not in ('left', 'center', 'right'):
        # Explicit raise rather than `assert`, which is stripped under -O.
        raise ValueError(
            "warp_center must be 'left', 'center' or 'right', "
            f"got {warp_center!r}"
        )
    # Work on a float copy so the caller's array is never mutated in place
    # and integer inputs do not fail on in-place float arithmetic.
    xy = np.array(xy, dtype=float)
    half = xy.mean(axis=0)
    if warp_center == 'right':
        center = np.array([0, half[1]])
    elif warp_center == 'center':
        center = half
    else:  # 'left'
        center = np.array([2 * half[0], half[1]])
    normalized = (xy - center) / half
    r = np.linalg.norm(normalized, axis=1)
    m_r = 1 + k1 * r + k2 * r**2
    return normalized / m_r.reshape(-1, 1) * half + center


# test_range = np.linspace(0, 100*100, 1000)

# Simulate a bimodal population of overlaps: 300 of the 1000 draws (30% of
# total overlaps) come from a mode with very few objects in the overlapping
# blocks, while the remaining 700 come from a mode with good quality for
# phase correlation. Non-positive draws are discarded after sorting.
test_range = np.sort(np.concatenate((
    np.random.default_rng().normal(5, 10, 300),
    np.random.default_rng().normal(100, 20, 700),
)))
test_range = test_range[test_range > 0]


# testing and visualization function
# left panels show results from the original approach (w/o adding noise to the
# laplacian filtered image) while gaussian noise is added to ALL the laplacian
# filtered images in the right panels

# the 1-percentile error cutoff tends to fail when the image has little noise
# (when `NOISE_SD` is low) and/or the distortion is significant (`K2` is high)
def plot_tests(
    test_range,
    SIGMA=1,
    NOISE_SD=0.005,
    K1=0.01,
    K2=0.002
):
    """Compare registration with and without added noise on synthetic data.

    For every value in ``test_range`` (forwarded to ``blob_noise`` as its
    ``fraction`` argument) this:

    1. registers two independently generated blob images to collect
       "permutation" errors and edge amplitudes, then derives a noise
       factor from the edge amplitudes with a triangle threshold;
    2. repeats the permutation registration with that noise factor;
    3. registers two images built from the same blob seed, warped with
       opposite radial distortions, both without and with the noise factor.

    The results go into a 2x2 figure: errors in the top row, shift distances
    in the bottom row, without noise on the left and with noise on the
    right. Points are colored by whether their error falls below the
    1st percentile of the matching permutation errors.

    Parameters
    ----------
    test_range
        Array of values to sweep; its shape sizes all result arrays.
    SIGMA
        Filter sigma handed to ``utils.register`` / ``utils.edge_amplitude``.
    NOISE_SD
        Noise standard deviation handed to ``blob_noise``.
    K1, K2
        Distortion coefficients handed to ``radial_distort``.
    """
    result_shape = test_range.shape

    def progress(label):
        # Enumerated sweep over test_range with a labelled progress bar.
        return enumerate(tqdm.tqdm(test_range, desc=label, ascii=True))

    # Pass 1: errors and edge amplitudes of "non-overlapping" blocks, used
    # to find the edge amplitude threshold.
    permutation_errors = np.empty(result_shape)
    edge_amplitudes = np.empty(result_shape + (2,))
    for k, amount in progress('Computing edge amp'):
        first = blob_noise(amount, noise_sd=NOISE_SD)
        second = blob_noise(amount, noise_sd=NOISE_SD)
        _, permutation_errors[k] = utils.register(
            first, second, SIGMA, upsample=1
        )
        edge_amplitudes[k] = (
            utils.edge_amplitude(first, SIGMA),
            utils.edge_amplitude(second, SIGMA),
        )

    noise_factor = skimage.filters.threshold_triangle(edge_amplitudes)

    # Pass 2: permutation errors again, this time with noise added.
    permutation_errors_noise = np.empty(result_shape)
    for k, amount in progress('Computing errors'):
        first = blob_noise(amount, noise_sd=NOISE_SD)
        second = blob_noise(amount, noise_sd=NOISE_SD)
        _, permutation_errors_noise[k] = utils.register(
            first, second, SIGMA, upsample=1, noise_factor=noise_factor
        )

    # Pass 3: synthesize "overlapping blocks" (same blob seed) with gaussian
    # noise and opposite radial distortions, then register both ways.
    errors = np.empty(result_shape)
    shifts = np.empty(result_shape + (2,))
    errors_noise = np.empty(result_shape)
    shifts_noise = np.empty(result_shape + (2,))
    for k, amount in progress('Registering imgs'):
        first = skimage.transform.warp(
            blob_noise(amount, noise_sd=NOISE_SD, blob_img_seed=1001),
            radial_distort,
            map_args={'warp_center': 'left', 'k1': K1, 'k2': K2},
        )
        second = skimage.transform.warp(
            blob_noise(amount, noise_sd=NOISE_SD, blob_img_seed=1001),
            radial_distort,
            map_args={'warp_center': 'right', 'k1': K1, 'k2': K2},
        )
        shifts[k], errors[k] = utils.register(
            first, second, SIGMA, upsample=1
        )
        shifts_noise[k], errors_noise[k] = utils.register(
            first, second, SIGMA, upsample=1, noise_factor=noise_factor
        )

    # A registration "passes" when its error is below the 1st percentile of
    # the corresponding permutation errors.
    cutoff = np.percentile(permutation_errors, 1)
    cutoff_noise = np.percentile(permutation_errors_noise, 1)
    passed = errors < cutoff
    passed_noise = errors_noise < cutoff_noise

    fig, axs = plt.subplots(2, 2, sharex=True, sharey=False)
    fig.suptitle(f'SIGMA={SIGMA}, NOISE_SD={NOISE_SD}, K1={K1}, K2={K2}')

    dots = dict(linewidths=0, s=8, alpha=0.5)

    top_row = (
        ('error w/o noise', permutation_errors, errors, passed, cutoff),
        ('error w/ noise', permutation_errors_noise, errors_noise,
         passed_noise, cutoff_noise),
    )
    for ax, (title, null_errors, pair_errors, ok, limit) in zip(axs[0], top_row):
        ax.set_title(title)
        ax.scatter(test_range, null_errors, c='#666666', **dots)
        ax.scatter(test_range, pair_errors, c=ok, cmap='PiYG', **dots)
        ax.axhline(limit, c='k', lw=1)

    bottom_row = (
        ('shift distance w/o noise', shifts, passed),
        ('shift distance w/ noise', shifts_noise, passed_noise),
    )
    for ax, (title, found_shifts, ok) in zip(axs[1], bottom_row):
        ax.set_title(title)
        ax.scatter(
            test_range, np.linalg.norm(found_shifts, axis=1),
            c=ok, cmap='PiYG', **dots
        )
30 changes: 28 additions & 2 deletions ashlar/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@
# Pre-calculate the Laplacian operator kernel. We'll always be using 2D images.
_laplace_kernel = skimage.restoration.uft.laplacian(2, (3, 3))[1]

# Fix random state for random noise generator
_noise_rgn = np.random.default_rng(0)

def whiten(img, sigma):
img = skimage.img_as_float32(img)
if sigma == 0:
Expand All @@ -23,9 +26,27 @@ def whiten(img, sigma):
return output


def register(img1, img2, sigma, upsample=10):
def edge_amplitude(img, sigma):
    """Return the norm of `img` after it has been passed through `whiten`.

    `sigma` is forwarded unchanged to `whiten`. The norm is
    `np.linalg.norm` with its default settings.
    """
    whitened = whiten(img, sigma)
    return np.linalg.norm(whitened)


def add_noise(img, noise_factor):
    """Return `img` plus normal noise whose overall norm is `noise_factor`.

    The noise is drawn from the module-level generator `_noise_rgn`, so each
    call advances that shared generator's state. The draw is rescaled to unit
    norm before being multiplied by `noise_factor`, which makes the standard
    deviation used for the draw irrelevant to the final magnitude. `img`
    itself is not modified.
    """
    raw = _noise_rgn.normal(0, 0.1, img.shape)
    unit_norm = raw / np.linalg.norm(raw)
    return img + unit_norm * noise_factor


def register(img1, img2, sigma, upsample=10, noise_factor=0):
img1w = whiten(img1, sigma)
img2w = whiten(img2, sigma)

if noise_factor != 0:
img1w = add_noise(img1w, noise_factor)
img2w = add_noise(img2w, noise_factor)

img1_f = scipy.fft.fft2(img1w)
img2_f = scipy.fft.fft2(img2w)
shift, _error, _phasediff = skimage.feature.register_translation(
Expand Down Expand Up @@ -53,9 +74,14 @@ def register(img1, img2, sigma, upsample=10):
return shift, error


def nccw(img1, img2, sigma):
def nccw(img1, img2, sigma, noise_factor=0):
img1w = whiten(img1, sigma)
img2w = whiten(img2, sigma)

if noise_factor != 0:
img1w = add_noise(img1w, noise_factor)
img2w = add_noise(img2w, noise_factor)

correlation = np.abs(np.sum(img1w * img2w))
total_amplitude = np.linalg.norm(img1w) * np.linalg.norm(img2w)
if correlation > 0 and total_amplitude > 0:
Expand Down