
RoiAlign CPU is not aligned to pixel centers (per the Mask RCNN paper and Facebook's Detectron2 implementation) #6921

Open
fdwr opened this issue Mar 6, 2021 · 4 comments

fdwr commented Mar 6, 2021

Describe the bug
The RoiAlign operator, per the Mask RCNN paper and Facebook Research's Detectron 2 implementation aligns sampling points over the center of the pixels, but ORT's CPU implementation is misaligned by a half pixel. After comparing ORT to various references (table below), I see current ORT code duplicated PyTorch's earlier bug in roi_align which applied an offset the output subsample by 0.5 but forgot to adjust the input sample to compensate (see their comment in the code: "the original roi_align (aligned=False) does not subtract the 0.5 when computing neighboring pixel indices and therefore it uses pixels with a slightly incorrect alignment (relative to our pixel model) when performing bilinear interpolation").

From the paper (figure omitted here), note the pixel centers used for interpolation.

This is less evident for larger input image regions, where the misalignment matters less relative to the overall region size, but it makes quite a difference for smaller regions. Even identity cases are misaligned (where the region of interest exactly matches the output tensor size): taking the middle 2x2 slice of a 4x4 input to a 2x2 output (integer coordinates, no scale factor) should yield exactly that input slice, but ORT's results are shifted half a pixel off.
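To make the identity case concrete, here is a minimal pure-Python RoiAlign sketch (avg pooling, sampling_ratio=1, i.e. one bilinear sample per output bin; the helper names are mine, not ORT's). With the half-pixel (aligned=True) convention the 2x2 identity ROI reproduces the input slice exactly, and the same function reproduces the Detectron numbers shown later in this issue:

```python
# Minimal RoiAlign sketch: avg pooling, sampling_ratio=1 (one bilinear sample
# per output bin), pure Python. Helper names are illustrative, not ORT code.

def bilinear(img, y, x):
    """Bilinear interpolation on a 2-D list of lists, clamped at the border."""
    h, w = len(img), len(img[0])
    y = min(max(y, 0.0), h - 1)
    x = min(max(x, 0.0), w - 1)
    y0, x0 = int(y), int(x)
    y1, x1 = min(y0 + 1, h - 1), min(x0 + 1, w - 1)
    ly, lx = y - y0, x - x0
    return (img[y0][x0] * (1 - ly) * (1 - lx) + img[y0][x1] * (1 - ly) * lx
            + img[y1][x0] * ly * (1 - lx) + img[y1][x1] * ly * lx)

def roi_align(img, roi, out_h, out_w, aligned):
    x1, y1, x2, y2 = roi
    offset = 0.5 if aligned else 0.0  # aligned=True samples at pixel centers
    bin_h = (y2 - y1) / out_h
    bin_w = (x2 - x1) / out_w
    return [[bilinear(img, y1 - offset + (oy + 0.5) * bin_h,
                           x1 - offset + (ox + 0.5) * bin_w)
             for ox in range(out_w)] for oy in range(out_h)]

# Identity case: the middle 2x2 of a 4x4 input should come back unchanged.
img4 = [[0, 1, 2, 3], [10, 11, 12, 13], [20, 21, 22, 23], [30, 31, 32, 33]]
print(roi_align(img4, (1, 1, 3, 3), 2, 2, aligned=True))   # [[11.0, 12.0], [21.0, 22.0]]
print(roi_align(img4, (1, 1, 3, 3), 2, 2, aligned=False))  # shifted by half a pixel

# The 6x6 Detectron example: aligned=True gives the expected first row
# [8.25, 8.75, 9.25, 9.75]; aligned=False gives the legacy [13.75, 14.25, ...].
img6 = [[r * 10 + c for c in range(6)] for r in range(6)]
print(roi_align(img6, (1, 1, 3, 3), 4, 4, aligned=True)[0])
print(roi_align(img6, (1, 1, 3, 3), 4, 4, aligned=False)[0])
```

The only difference between the two conventions is the single `offset` subtraction on the input coordinates, which is exactly the half-pixel correction this issue is about.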

Urgency
No deadline.

System information

  • OS Platform and Distribution: NA, but Windows 10 recent selfhost
  • ONNX Runtime installed from (source or binary): source
  • ONNX Runtime version: 1.7
  • Python version: NA
  • Visual Studio version (if applicable): VS2019
  • GCC/Compiler version (if compiling from source): NA
  • CUDA/cuDNN version: NA
  • GPU model and memory: NA

To Reproduce

Expected behavior

  • For the identity test case:
    • Expected output: [[[[11, 12], [21, 22]]]]
    • Actual output: [[[[5.50, 5.75], [8.00, 8.25]]]]
  • For the detectron test case:
    • Expected output: [[[[ 8.25, 8.75, 9.25, 9.75], [13.25, 13.75, 14.25, 14.75], [18.25, 18.75, 19.25, 19.75], [23.25, 23.75, 24.25, 24.75]]]]
    • Actual output: [[[[6.1875, 6.75, 6.75, 7.3125], [11.8125, 12.375, 12.375, 12.9375], [11.8125, 12.375, 12.375, 12.9375], [17.4375, 18, 18, 18.5625]]]]

Screenshots

Additional context

This affects the faster_rcnn and mask_rcnn models in WinML: their expected output results appear to have been recorded using the incorrect CPU alignment in the first place, whereas DML follows half-pixel alignment (matching Detectron 2) and therefore gets different results than the output .PB files.

For an example case (modified from the Detectron test case), and comparison to other framework results:

Input Tensor =
                   <------->
              0.0 1.0 2.0 3.0 4.0 5.0 6.0
               |.5 |.5 |.5 |.5 |.5 |.5 |
        0.0___ |_|_|_|_|_|_|_|_|_|_|_|_|
        1.0___[| 0,| 1,| 2,| 3,| 4,| 5 ]
    /|\ 2.0___[|10,┃11,┃12,┃13,|14,|15 ]
    \|/ 3.0___[|20,┃21,┃22,┃23,|24,|25 ]
        4.0___[|30,|31,|32,|33,|34,|35 ]
        5.0___[|40,|41,|42,|43,|44,|45 ]
        6.0___[|50,|51,|52,|53,|54,|55 ]

Active region of interest = [[1.0, 1.0, 3.0, 3.0]] // a 2x2 window over the input elements
Input tensor window = [[11,12],[21,22]]
Output tensor size = [4,4]
| Source | Output 4x4, from first 2x2 region |
| --- | --- |
| ✔ FB Research Detectron 2 (Mask RCNN paper) | [ 8.25,  8.75,  9.25,  9.75],<br>[13.25, 13.75, 14.25, 14.75],<br>[18.25, 18.75, 19.25, 19.75],<br>[23.25, 23.75, 24.25, 24.75] |
| ✔ ONNX Runtime DML EP (ROI_ALIGN 0) | [ 8.25,  8.75,  9.25,  9.75],<br>[13.25, 13.75, 14.25, 14.75],<br>[18.25, 18.75, 19.25, 19.75],<br>[23.25, 23.75, 24.25, 24.75] |
| ✔ ONNX Runtime 1.7 CPU Resize + Slice, coordinate_transformation_mode=half_pixel | [ 8.25,  8.75,  9.25,  9.75],<br>[13.25, 13.75, 14.25, 14.75],<br>[18.25, 18.75, 19.25, 19.75],<br>[23.25, 23.75, 24.25, 24.75] |
| torchvision.ops.roi_align(aligned=True…) | [ 8.25,  8.75,  9.25,  9.75],<br>[13.25, 13.75, 14.25, 14.75],<br>[18.25, 18.75, 19.25, 19.75],<br>[23.25, 23.75, 24.25, 24.75] |
| torchvision.ops.roi_align(aligned=False…) (deprecated, legacy flag still exists) | [13.75, 14.25, 14.75, 15.25],<br>[18.75, 19.25, 19.75, 20.25],<br>[23.75, 24.25, 24.75, 25.25],<br>[28.75, 29.25, 29.75, 30.25] |
| ONNX Runtime 1.7 CPU EP RoiAlign | [13.75, 14.25, 14.75, 15.25],<br>[18.75, 19.25, 19.75, 20.25],<br>[23.75, 24.25, 24.75, 25.25],<br>[28.75, 29.25, 29.75, 30.25] |
| tf.image.crop_and_resize(…) (note: boxes are normalized 0 to 1, so /5 each ROI element) | [11.00, 11.66, 12.33, 13.00],<br>[17.66, 18.33, 19.00, 19.66],<br>[24.33, 25.00, 25.66, 26.33],<br>[31.00, 31.66, 32.33, 33.00] |
| tf.image.resize_bilinear(align_corners=True…) + tf.slice | [11.00, 11.66, 12.33, 13.00],<br>[17.66, 18.33, 19.00, 19.66],<br>[24.33, 25.00, 25.66, 26.33],<br>[31.00, 31.66, 32.33, 33.00] |
| tf.image.resize_bilinear(align_corners=False…) + tf.slice | [11.00, 11.50, 12.00, 12.50],<br>[16.00, 16.50, 17.00, 17.50],<br>[21.00, 21.50, 22.00, 22.50],<br>[26.00, 26.50, 27.00, 27.50] |
| tf.image.resize_bilinear(half_pixel_centers=True…) + tf.slice | [ 8.25,  8.75,  9.25,  9.75],<br>[13.25, 13.75, 14.25, 14.75],<br>[18.25, 18.75, 19.25, 19.75],<br>[23.25, 23.75, 24.25, 24.75] |
| torch.nn.functional.interpolate | (todo) |
| tf.keras.layers.UpSampling2D | (todo) |

Even the ONNX backend conformance test case has these misaligned numbers: https://github.com/onnx/onnx/blob/master/onnx/backend/test/case/node/roialign.py


PyTorch sample code:

# pip install torch==1.7.1+cpu torchvision==0.8.2+cpu torchaudio===0.7.2 -f https://download.pytorch.org/whl/torch_stable.html
import torch
import torchvision
print("PyTorch version:", torch.__version__)

input = [[[[ 0, 1, 2, 3, 4, 5], # NCHW
            [10,11,12,13,14,15],
            [20,21,22,23,24,25],
            [30,31,32,33,34,35],
            [40,41,42,43,44,45],
            [50,51,52,53,54,55]]]]
boxes = [[0, 1,1,3,3]]
output_size = [4,4]
aligned=True # Correct
#aligned=False # Legacy setting
sampling_ratio=1
spatial_scale=1

# https://pytorch.org/vision/0.8/_modules/torchvision/ops/roi_align.html
output = torchvision.ops.roi_align(
    torch.tensor(input, dtype=torch.float),
    torch.tensor(boxes, dtype=torch.float),
    output_size,
    spatial_scale=spatial_scale,
    sampling_ratio=sampling_ratio,
    aligned=aligned
)

torch.set_printoptions(sci_mode=False)
print(input)
print(boxes)
print(output)

TensorFlow sample code:

# pip install tensorflow-gpu==1.15.0
import numpy as np
import tensorflow.compat.v1 as tf

input = [[ # NHWC
            [[ 0.], [ 1.], [ 2.], [ 3.], [ 4.], [ 5.]],
            [[10.], [11.], [12.], [13.], [14.], [15.]],
            [[20.], [21.], [22.], [23.], [24.], [25.]],
            [[30.], [31.], [32.], [33.], [34.], [35.]],
            [[40.], [41.], [42.], [43.], [44.], [45.]],
            [[50.], [51.], [52.], [53.], [54.], [55.]]
        ]]
boxes = [[1/5,1/5,3/5,3/5],[3/5,3/5,4/5,4/5]] # Normalized 0.0 to 1.0 (where 1.0 = width - 1 and height - 1)
box_indices = [0, 0] # Batch indices per corresponding region
crop_size = [4, 4] # Output tensor size HW

print("TensorFlow version:", tf.__version__) # 1.15.0 (cpu/cuda)

# Using half_pixel_centers=True is correct (not align_corners=True)
output_size = [6*2, 6*2]
resize_output = tf.image.resize_bilinear(tf.constant(input), output_size, align_corners=False, half_pixel_centers=True)
resize_bilinear_slice_output = tf.slice(resize_output, [0,2,2,0], [1,4,4,1])

# Note crop_and_resize doesn't scale the image boundaries to pixel centers, but always to corners,
# and there is sadly no flag to influence this (unlike resize_bilinear).
method = 'bilinear'
extrapolation_value = 0
crop_and_resize_output = tf.image.crop_and_resize(
    image=tf.constant(input, dtype=tf.float32), # NHWC
    boxes=tf.constant(boxes, dtype=tf.float32),
    box_ind=tf.constant(box_indices, dtype=tf.int32),
    crop_size=tf.constant(crop_size, dtype=tf.int32),
    method=method,
    extrapolation_value=extrapolation_value
)

with tf.Session() as session:
    with np.printoptions(precision=3, suppress=True):
        print("input:\n", input)
        print("crop_and_resize:\n", session.run(crop_and_resize_output))
        print("resize_bilinear_and_slice:\n", session.run(resize_bilinear_slice_output))

Facebook Research's Detectron 2 test code (excerpt; imports added for context):

import unittest
import numpy as np

class ROIAlignTest(unittest.TestCase):
    def test_forward_output(self):
        input = np.arange(25).reshape(5, 5).astype("float32")
        """
        0  1  2   3 4
        5  6  7   8 9
        10 11 12 13 14
        15 16 17 18 19
        20 21 22 23 24
        """

        output = self._simple_roialign(input, [1, 1, 3, 3], (4, 4), aligned=False)
        output_correct = self._simple_roialign(input, [1, 1, 3, 3], (4, 4), aligned=True)

        # without correction:
        old_results = [
            [7.5, 8, 8.5, 9],
            [10, 10.5, 11, 11.5],
            [12.5, 13, 13.5, 14],
            [15, 15.5, 16, 16.5],
        ]

        # with 0.5 correction:
        correct_results = [
            [4.5, 5.0, 5.5, 6.0],
            [7.0, 7.5, 8.0, 8.5],
            [9.5, 10.0, 10.5, 11.0],
            [12.0, 12.5, 13.0, 13.5],
        ]
        # This is an upsampled version of [[6, 7], [11, 12]]
...
@hariharans29
Member

I will investigate this

@fdwr
Contributor Author

fdwr commented Mar 9, 2021

Thanks Hari. Note that the coordinates of RoiAlign lie on a continuous floating point grid, unlike RoiPool (with integers and that weirdo -1 size adjustment). So an ROI (x1,y1,x2,y2) of [1.6, 1.3, 3.3, 2.75] means a region of width=1.7 (3.3 - 1.6) and height=1.45 (2.75 - 1.3).


Added some diagrams in the above table, with blue as the input image and orange as the regions to write to output (the dots are the blue input sample points and orange output points).
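The continuous-coordinate geometry above can be sketched directly with the numbers from that example ROI (a trivial check, but it is exactly the point where RoiPool and RoiAlign differ):

```python
# RoiAlign treats ROI corners as points on a continuous grid, so the region
# size is just the coordinate difference -- no RoiPool-style integer snapping
# or "-1" size adjustment.
x1, y1, x2, y2 = 1.6, 1.3, 3.3, 2.75
width = x2 - x1
height = y2 - y1
print(round(width, 6), round(height, 6))  # -> 1.7 1.45

# For a 4x4 output, each output bin then covers (width/4) x (height/4) of input.
print(round(width / 4, 6), round(height / 4, 6))  # -> 0.425 0.3625
```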

@hariharans29
Member

hariharans29 commented Mar 15, 2021

Hi Dwayne,

I investigated this and here are my findings:

  1. In your table above, this doesn't seem right:

     torchvision.ops.roi_align(aligned=False…) (deprecated, legacy flag still exists): [13.75, 14.25, 14.75, 15.25], [18.75, 19.25, 19.75, 20.25], [23.75, 24.25, 24.75, 25.25], [28.75, 29.25, 29.75, 30.25]
     ??? 🙃 ONNX Runtime 1.7 CPU EP RoiAlign: [6.1875, 6.75, 6.75, 7.3125], [11.8125, 12.375, 12.375, 12.9375], [11.8125, 12.375, 12.375, 12.9375], [17.4375, 18, 18, 18.5625]

The reason is that these results are produced by ORT when the operator's pooling mode is 'max'. From the looks of it, Torch only seems to support 'avg' pooling. As soon as you change the mode to 'avg', you will see that ORT's results match TorchVision (aligned=False).

So, my conclusion is: ORT's CPU implementation (in avg mode) == legacy Torch RoiAlign (aligned=False).

  2. To ensure backwards compatibility, TorchVision introduced the aligned flag, giving the user the option of "fixing the misalignment" (the legacy aligned=False is still the default mode, though). In the Detectron project, their RoiAlign wrapper defaults the 'aligned' flag to True, and this nuance is seemingly the root cause of the diffs with respect to ORT. To summarize: ORT's CPU backend implements the legacy mode, and the DirectML backend (like Detectron) implements the "misalignment fixed" logic. Does this make sense?

Unfortunately, I don't know how to fix this without breaking backwards compatibility. The way I see it, we may have to introduce a new attribute in ONNX for this op (just like TorchVision introduced the aligned flag) that lets the user pick which implementation they would like. What do you think?
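To make that proposal concrete, here is a sketch of what such a node could look like, modeled on Resize's coordinate_transformation_mode attribute. This is purely illustrative (a plain dict, not actual ONNX API usage); the attribute name and values are hypothetical and were not part of the RoiAlign spec at the time of this issue:

```python
# Hypothetical sketch of the proposed attribute. A new RoiAlign attribute,
# analogous to Resize's coordinate_transformation_mode, would let models opt
# into half-pixel (Detectron 2 / aligned=True) sampling, while converters from
# old opsets pin the legacy behaviour for back compat.
proposed_roi_align_node = {
    "op_type": "RoiAlign",
    "inputs": ["X", "rois", "batch_indices"],
    "outputs": ["Y"],
    "attributes": {
        "mode": "avg",
        "output_height": 4,
        "output_width": 4,
        "sampling_ratio": 1,
        "spatial_scale": 1.0,
        # Hypothetical values: "half_pixel" = corrected alignment (proposed
        # default), "legacy" = today's CPU EP behaviour.
        "coordinate_transformation_mode": "half_pixel",
    },
}
print(proposed_roi_align_node["attributes"]["coordinate_transformation_mode"])
```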

@fdwr
Contributor Author

fdwr commented Mar 16, 2021

I fixed the table above and the wording to show it matches the deprecated PyTorch behavior (oops, sorry for the "max" vs "avg" mixup when I recorded that result 😅). Yeah, back compat is a concern. Really, ONNX should add an attribute to RoiAlign like it did with Resize, defaulting to half_pixel (ONNX converters from old opsets can set it to the legacy behavior). Thanks for investigating.

I'll plot the output rectangles from both approaches over the original image to contrast them. I've opened an ONNX issue.
