diff --git a/ashlar/reg.py b/ashlar/reg.py index 718085f9..a2e57047 100644 --- a/ashlar/reg.py +++ b/ashlar/reg.py @@ -13,6 +13,7 @@ import skimage.util import skimage.util.dtype import skimage.io +import skimage.filters import skimage.exposure import skimage.transform import sklearn.linear_model @@ -455,7 +456,8 @@ class EdgeAligner(object): def __init__( self, reader, channel=0, max_shift=15, false_positive_ratio=0.01, - randomize=False, filter_sigma=0.0, do_make_thumbnail=True, verbose=False + randomize=False, filter_sigma=0.0, add_noise=False, + do_make_thumbnail=True, verbose=False ): self.channel = channel self.reader = CachingReader(reader, self.channel) @@ -466,6 +468,7 @@ def __init__( self.false_positive_ratio = false_positive_ratio self.randomize = randomize self.filter_sigma = filter_sigma + self.add_noise=add_noise self.do_make_thumbnail = do_make_thumbnail self._cache = {} @@ -474,6 +477,8 @@ def __init__( def run(self): self.make_thumbnail() self.check_overlaps() + self.find_permutation_pairs() + self.compute_edge_amplitude() self.compute_threshold() self.register_all() self.build_spanning_tree() @@ -502,11 +507,7 @@ def check_overlaps(self): elif any(failures): warn_data("Some neighboring tiles have zero overlap.") - def compute_threshold(self): - # Compute error threshold for rejecting aligments. We generate a - # distribution of error scores for many known non-overlapping image - # regions and take a certain percentile as the maximum allowable error. - # The percentile becomes our accepted false-positive ratio. + def find_permutation_pairs(self): edges = self.neighbors_graph.edges num_tiles = self.metadata.num_images # If not enough tiles overlap to matter, skip this whole thing. 
@@ -518,7 +519,8 @@ def compute_threshold(self): self.intersection(t1, t2).shape.min() for t1, t2 in edges ]) - w = widths.max() + self.permutation_width = widths.max() + w = self.permutation_width max_offset = self.metadata.size[0] - w # Number of possible pairs minus number of actual neighbor pairs. num_distant_pairs = num_tiles * (num_tiles - 1) // 2 - len(edges) @@ -526,8 +528,9 @@ def compute_threshold(self): # possible truly distinct strips with fewer tiles. The calculation here # is just a heuristic, not rigorously derived. n = 1000 if num_distant_pairs > 8 else (num_distant_pairs + 1) * 10 - pairs = np.empty((n, 2), dtype=int) - offsets = np.empty((n, 2), dtype=int) + + self.permutation_pairs = np.empty((n, 2), dtype=int) + self.permutation_offsets = np.empty((n, 2), dtype=int) # Generate n random non-overlapping image strips. Strips are always # horizontal, across the entire image width. max_tries = 100 @@ -561,10 +564,49 @@ def compute_threshold(self): else: # Retries exhausted. This should be very rare. warn_data( - "Could not find non-overlapping strips in {max_tries} tries" + f"Could not find non-overlapping strips in {max_tries} tries" ) - pairs[i] = t1, t2 - offsets[i] = o1, o2 + self.permutation_pairs[i] = t1, t2 + self.permutation_offsets[i] = o1, o2 + + def compute_edge_amplitude(self): + # When the edge amplitudes (Frobenius norm of the laplacian/LoG filtered + # image) in the overlapping pairs are low, our registration will yield + # low errors, which, in general, are not "correct". Here we compute edge + # amplitudes of all the permutation tiles. And find the minimum edge + # amplitude needed to generate confident registration error. 
+ if not self.add_noise: + self.min_edge_amplitude = 0 + return + pairs = self.permutation_pairs + offsets = self.permutation_offsets + w = self.permutation_width + n = len(self.permutation_pairs) * 2 + edge_amplitudes = np.empty(n) + for i, (t, o) in enumerate(zip(pairs.flatten(), offsets.flatten())): + if self.verbose and (i % 10 == 9 or i == n - 1): + sys.stdout.write( + '\r quantifying edge amplitude %d/%d' % (i + 1, n) + ) + sys.stdout.flush() + img = self.reader.read(t, self.channel)[o:o+w, :] + edge_amplitudes[i] = utils.edge_amplitude(img, self.filter_sigma) + if self.verbose: print() + self.edge_amplitudes = edge_amplitudes + # Triangle threshold seems to work in our limited tests + self.min_edge_amplitude = skimage.filters.threshold_triangle( + self.edge_amplitudes + ) + + def compute_threshold(self): + # Compute error threshold for rejecting alignments. We generate a + # distribution of error scores for many known non-overlapping image + # regions and take a certain percentile as the maximum allowable error. + # The percentile becomes our accepted false-positive ratio. + pairs = self.permutation_pairs + offsets = self.permutation_offsets + w = self.permutation_width + n = len(self.permutation_pairs) errors = np.empty(n) for i, ((t1, t2), (offset1, offset2)) in enumerate(zip(pairs, offsets)): if self.verbose and (i % 10 == 9 or i == n - 1): @@ -574,7 +616,10 @@ def compute_threshold(self): sys.stdout.flush() img1 = self.reader.read(t1, self.channel)[offset1:offset1+w, :] img2 = self.reader.read(t2, self.channel)[offset2:offset2+w, :] - _, errors[i] = utils.register(img1, img2, self.filter_sigma, upsample=1) + _, errors[i] = utils.register( + img1, img2, self.filter_sigma, + upsample=1, noise_factor=self.min_edge_amplitude + ) if self.verbose: print() self.errors_negative_sampled = errors @@ -687,7 +732,7 @@ def register_pair(self, t1, t2): # metric on these images. This should be even lower than the error # computed above. 
_, o1, o2 = self.overlap(key[0], key[1], shift=shift) - error = utils.nccw(o1, o2, self.filter_sigma) + error = utils.nccw(o1, o2, self.filter_sigma, self.min_edge_amplitude) self._cache[key] = (shift, error) if t1 > t2: shift = -shift @@ -702,7 +747,9 @@ def _register(self, t1, t2, min_size=0): sx = 1 if p1[1] >= p2[1] else -1 sy = 1 if p1[0] >= p2[0] else -1 padding = its.padding * [sy, sx] - shift, error = utils.register(img1, img2, self.filter_sigma) + shift, error = utils.register( + img1, img2, self.filter_sigma, noise_factor=self.min_edge_amplitude + ) shift += padding return shift, error diff --git a/ashlar/scripts/ashlar.py b/ashlar/scripts/ashlar.py index 80975442..9d7065bc 100644 --- a/ashlar/scripts/ashlar.py +++ b/ashlar/scripts/ashlar.py @@ -57,6 +57,10 @@ def main(argv=sys.argv): '-m', '--maximum-shift', type=float, default=15, metavar='SHIFT', help='maximum allowed per-tile corrective shift in microns' ) + parser.add_argument( + '--add-noise', default=False, action='store_true', + help=('add noise during stitching') + ) parser.add_argument( '--filter-sigma', type=float, default=0.0, metavar='SIGMA', help=('width in pixels of Gaussian filter to apply to images before' @@ -157,6 +161,7 @@ def main(argv=sys.argv): aligner_args['channel'] = args.align_channel aligner_args['verbose'] = not args.quiet aligner_args['max_shift'] = args.maximum_shift + aligner_args['add_noise'] = args.add_noise aligner_args['filter_sigma'] = args.filter_sigma mosaic_args = {} @@ -212,6 +217,7 @@ def process_single( reader = build_reader(filepaths[0], plate_well=plate_well) process_axis_flip(reader, flip_x, flip_y) ea_args = aligner_args.copy() + aligner_args = {k: v for k, v in aligner_args.items() if k != 'add_noise'}  # rebind locally; never mutate the caller's dict if len(filepaths) == 1: ea_args['do_make_thumbnail'] = False edge_aligner = reg.EdgeAligner(reader, **ea_args) diff --git a/ashlar/test_apply_noise.py b/ashlar/test_apply_noise.py new file mode 100644 index 00000000..0614a191 --- /dev/null +++ b/ashlar/test_apply_noise.py @@ -0,0 +1,135 @@ +import numpy 
as np +from ashlar import utils +import skimage.data +import skimage.filters +import skimage.transform + +import matplotlib.pyplot as plt +import tqdm + +def blob_noise(fraction, noise_sd=0, blob_img_seed=None): + blob_base = skimage.data.binary_blobs( + 100, + blob_size_fraction=5/100, + volume_fraction=fraction/(100*100), + seed=blob_img_seed + ).astype(float) + rgn = np.random.default_rng() + noise = rgn.normal(0, noise_sd, 100*100).reshape(100, 100) + return blob_base + noise + +# radial distortion +def radial_distort(xy, warp_center, k1=0.01, k2=0.002): + assert warp_center in ['left', 'center', 'right'] + half = xy.mean(axis=0) + if warp_center == 'right': + center = [0, half[1]] + elif warp_center == 'center': + center = half + elif warp_center == 'left': + center = [2*half[0], half[1]] + xy -= center + xy /= half + r = np.linalg.norm(xy, axis=1) + m_r = 1 + k1*r + k2*r**2 + xy /= m_r.reshape(-1, 1) + return xy * half + center + + +# test_range = np.linspace(0, 100*100, 1000) + +# simulates two modes: one mode contains very few objects in overlapping blocks +# and accounts for 30% of total overlaps; the rest of the overlaps have good quality +# for phase correlation +test_range = np.sort([ + *np.random.default_rng().normal(5, 10, 300), + *np.random.default_rng().normal(100, 20, 700) +]) +test_range = test_range[test_range > 0] + + +# testing and visualization function +# left panels show results from the original approach (w/o adding noise to the +# Laplacian-filtered image) while Gaussian noise is added to ALL the Laplacian- +# filtered images in the right panels + +# the 1-percentile error cutoff tends to fail when the image has little noise +# (when `NOISE_SD` is low) and/or the distortion is significant (`K2` is high) +def plot_tests( + test_range, + SIGMA=1, + NOISE_SD=0.005, + K1=0.01, + K2=0.002 +): + + permutation_errors = np.empty(test_range.shape) + edge_amplitudes = np.empty((*test_range.shape, 2)) + + for idx, i in enumerate(tqdm.tqdm(test_range, 
desc='Computing edge amp', ascii=True)): + # find edge amplitude threshold using "non-overlapping" blocks + img1 = blob_noise(i, noise_sd=NOISE_SD) + img2 = blob_noise(i, noise_sd=NOISE_SD) + permutation_errors[idx] = utils.register(img1, img2, SIGMA, upsample=1)[1] + edge_amplitudes[idx] = [ + utils.edge_amplitude(img1, SIGMA), + utils.edge_amplitude(img2, SIGMA) + ] + + noise_factor = skimage.filters.threshold_triangle(edge_amplitudes) + + permutation_errors_noise = np.empty(test_range.shape) + for idx, i in enumerate(tqdm.tqdm(test_range, desc='Computing errors', ascii=True)): + # calculate permutation errors w/ and w/o added noise + img1 = blob_noise(i, noise_sd=NOISE_SD) + img2 = blob_noise(i, noise_sd=NOISE_SD) + permutation_errors_noise[idx] = utils.register( + img1, img2, SIGMA, + upsample=1, noise_factor=noise_factor + )[1] + + + errors = np.empty(test_range.shape) + shifts = np.empty((*test_range.shape, 2)) + + errors_noise = np.empty(test_range.shape) + shifts_noise = np.empty((*test_range.shape, 2)) + + for idx, i in enumerate(tqdm.tqdm(test_range, desc='Registering imgs', ascii=True)): + # synthesize "overlapping blocks" and add gaussian noise and barrel + # distortion + img1 = blob_noise(i, noise_sd=NOISE_SD, blob_img_seed=1001) + img2 = blob_noise(i, noise_sd=NOISE_SD, blob_img_seed=1001) + img1 = skimage.transform.warp( + img1, radial_distort, map_args=dict(warp_center='left', k1=K1, k2=K2) + ) + img2 = skimage.transform.warp( + img2, radial_distort, map_args=dict(warp_center='right', k1=K1, k2=K2) + ) + + shifts[idx], errors[idx] = utils.register(img1, img2, SIGMA, upsample=1) + shifts_noise[idx], errors_noise[idx] = utils.register( + img1, img2, SIGMA, + upsample=1, noise_factor=noise_factor + ) + + passed = errors < np.percentile(permutation_errors, 1) + passed_noise = errors_noise < np.percentile(permutation_errors_noise, 1) + + fig, axs = plt.subplots(2, 2, sharex=True, sharey=False) + fig.suptitle(f'SIGMA={SIGMA}, NOISE_SD={NOISE_SD}, K1={K1}, 
K2={K2}') + + kwargs = dict(linewidths=0, s=8, alpha=0.5) + axs[0][0].set_title('error w/o noise') + axs[0][0].scatter(test_range, permutation_errors, c='#666666', **kwargs) + axs[0][0].scatter(test_range, errors, c=passed, cmap='PiYG', **kwargs) + axs[0][0].axhline(np.percentile(permutation_errors, 1), c='k', lw=1) + axs[0][1].set_title('error w/ noise') + axs[0][1].scatter(test_range, permutation_errors_noise, c='#666666', **kwargs) + axs[0][1].scatter(test_range, errors_noise, c=passed_noise, cmap='PiYG', **kwargs) + axs[0][1].axhline(np.percentile(permutation_errors_noise, 1), c='k', lw=1) + + axs[1][0].set_title('shift distance w/o noise') + axs[1][0].scatter(test_range, np.linalg.norm(shifts, axis=1), c=passed, cmap='PiYG', **kwargs) + axs[1][1].set_title('shift distance w/ noise') + axs[1][1].scatter(test_range, np.linalg.norm(shifts_noise, axis=1), c=passed_noise, cmap='PiYG', **kwargs) diff --git a/ashlar/utils.py b/ashlar/utils.py index 7372ef32..db1b1a37 100644 --- a/ashlar/utils.py +++ b/ashlar/utils.py @@ -14,6 +14,9 @@ # Pre-calculate the Laplacian operator kernel. We'll always be using 2D images. 
_laplace_kernel = skimage.restoration.uft.laplacian(2, (3, 3))[1] +# Fix random state for random noise generator +_noise_rgn = np.random.default_rng(0) + def whiten(img, sigma): img = skimage.img_as_float32(img) if sigma == 0: @@ -23,9 +26,27 @@ def whiten(img, sigma): return output -def register(img1, img2, sigma, upsample=10): +def edge_amplitude(img, sigma): + return np.linalg.norm( + whiten(img, sigma) + ) + + +def add_noise(img, noise_factor): + noise = _noise_rgn.normal(0, 0.1, img.shape) + noise /= np.linalg.norm(noise) + noise *= noise_factor + return img + noise + + +def register(img1, img2, sigma, upsample=10, noise_factor=0): img1w = whiten(img1, sigma) img2w = whiten(img2, sigma) + + if noise_factor != 0: + img1w = add_noise(img1w, noise_factor) + img2w = add_noise(img2w, noise_factor) + img1_f = scipy.fft.fft2(img1w) img2_f = scipy.fft.fft2(img2w) shift, _error, _phasediff = skimage.feature.register_translation( @@ -53,9 +74,14 @@ def register(img1, img2, sigma, upsample=10): return shift, error -def nccw(img1, img2, sigma): +def nccw(img1, img2, sigma, noise_factor=0): img1w = whiten(img1, sigma) img2w = whiten(img2, sigma) + + if noise_factor != 0: + img1w = add_noise(img1w, noise_factor) + img2w = add_noise(img2w, noise_factor) + correlation = np.abs(np.sum(img1w * img2w)) total_amplitude = np.linalg.norm(img1w) * np.linalg.norm(img2w) if correlation > 0 and total_amplitude > 0: