diff --git a/src/compare.py b/src/compare.py index 699e92eb..76cf5db1 100644 --- a/src/compare.py +++ b/src/compare.py @@ -72,6 +72,8 @@ def compare_l2_norm_masked(source, capture, mask) -> float: # The L2 Error is summed across all pixels, so this normalizes max_error = (3 * np.count_nonzero(mask) * 255 * 255) ** 0.5 + if not max_error: + return 0 return 1 - (error / max_error) def compare_template(source, capture) -> float: @@ -155,10 +157,20 @@ def compare_phash_masked(source, capture, mask): source_hash = imagehash.phash(source) capture_hash = imagehash.phash(capture) + if not source_hash + capture_hash: + return 0 return 1 - ((source_hash - capture_hash) / 64.0) def checkIfImageHasTransparency(image): - # TODO check for first transparent pixel, no need to iterate through the whole image # Check if there's a transparency channel (4th channel) and if at least one pixel is transparent (< 255) - return image.shape[2] == 4 and np.mean(image[:, :, 3]) != 255 + if image.shape[2] != 4: + return False + mean = np.mean(image[:, :, 3]) + if mean != 0: + # Non-transparent images code path is usually faster and simpler, so let's return that + return False + # TODO error message if all pixels are transparent + # (the image appears as all black in windows, so it's not obvious for the user what they did wrong) + + return mean != 255