-
Notifications
You must be signed in to change notification settings - Fork 9.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix issue in refine_bboxes and add doctest #1962
Changes from 2 commits
f0dba7b
bfc1248
4b1328c
e59d3cb
159bbf7
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
import numpy as np | ||
import torch | ||
|
||
|
||
def ensure_rng(rng=None): | ||
""" | ||
Simple version of the ``kwarray.ensure_rng`` | ||
|
||
Args: | ||
rng (int | numpy.random.RandomState | None): | ||
if None, then defaults to the global rng. Otherwise this can be an | ||
integer or a RandomState class | ||
Returns: | ||
(numpy.random.RandomState) : rng - | ||
a numpy random number generator | ||
|
||
References: | ||
https://gitlab.kitware.com/computer-vision/kwarray/blob/master/kwarray/util_random.py#L270 | ||
""" | ||
|
||
if rng is None: | ||
rng = np.random.mtrand._rand | ||
elif isinstance(rng, int): | ||
rng = np.random.RandomState(rng) | ||
else: | ||
rng = rng | ||
return rng | ||
|
||
|
||
def random_boxes(num=1, scale=1, rng=None): | ||
""" | ||
Simple version of ``kwimage.Boxes.random`` | ||
|
||
Returns: | ||
Tensor: shape (n, 4) in x1, y1, x2, y2 format. | ||
|
||
References: | ||
https://gitlab.kitware.com/computer-vision/kwimage/blob/master/kwimage/structs/boxes.py#L1390 | ||
|
||
Example: | ||
>>> num = 10 | ||
>>> scale = 512 | ||
>>> rng = 0 | ||
>>> boxes = random_boxes() | ||
""" | ||
rng = ensure_rng(rng) | ||
|
||
tlbr = rng.rand(num, 4).astype(np.float32) | ||
|
||
tl_x = np.minimum(tlbr[:, 0], tlbr[:, 2]) | ||
tl_y = np.minimum(tlbr[:, 1], tlbr[:, 3]) | ||
br_x = np.maximum(tlbr[:, 0], tlbr[:, 2]) | ||
br_y = np.maximum(tlbr[:, 1], tlbr[:, 3]) | ||
|
||
tlbr[:, 0] = tl_x * scale | ||
tlbr[:, 1] = tl_y * scale | ||
tlbr[:, 2] = br_x * scale | ||
tlbr[:, 3] = br_y * scale | ||
|
||
boxes = torch.FloatTensor(tlbr) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Are they not equivalent in this case? The numpy array is already a float32, so there shouldn't be any difference in calling There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Good to know, thanks for the info! |
||
return boxes |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -178,7 +178,8 @@ def refine_bboxes(self, rois, labels, bbox_preds, pos_is_gts, img_metas): | |
|
||
Args: | ||
rois (Tensor): Shape (n*bs, 5), where n is image number per GPU, | ||
and bs is the sampled RoIs per image. | ||
and bs is the sampled RoIs per image. The first column is | ||
the image id and the next 4 columns are x1, y1, x2, y2. | ||
labels (Tensor): Shape (n*bs, ). | ||
bbox_preds (Tensor): Shape (n*bs, 4) or (n*bs, 4*#class). | ||
pos_is_gts (list[Tensor]): Flags indicating if each positive bbox | ||
|
@@ -187,13 +188,45 @@ def refine_bboxes(self, rois, labels, bbox_preds, pos_is_gts, img_metas): | |
|
||
Returns: | ||
list[Tensor]: Refined bboxes of each image in a mini-batch. | ||
|
||
Example: | ||
>>> # xdoctest: +REQUIRES(module:kwarray) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. As a temporary solution, |
||
>>> import kwarray | ||
>>> self = BBoxHead(reg_class_agnostic=True) | ||
>>> n_roi = 2 | ||
>>> n_img = 4 | ||
>>> scale = 512 | ||
>>> rng = np.random.RandomState(0) | ||
>>> img_metas = [{'img_shape': (scale, scale)} | ||
... for _ in range(n_img)] | ||
>>> # Create rois in the expected format | ||
>>> roi_boxes = random_boxes(n_roi, scale=scale, rng=rng) | ||
>>> img_ids = torch.randint(0, n_img, (n_roi,)) | ||
>>> img_ids = img_ids | ||
>>> rois = torch.cat([img_ids[:, None], roi_boxes], dim=1) | ||
>>> # Create other args | ||
>>> labels = torch.randint(0, 2, (n_roi,)).long() | ||
>>> bbox_preds = random_boxes(n_roi, scale=scale, rng=rng) | ||
>>> # For each image, | ||
>>> is_label_pos = (labels.numpy() > 0).astype(np.int) | ||
>>> lbl_per_img = kwarray.group_items(is_label_pos, | ||
... img_ids.numpy()) | ||
>>> pos_per_img = [sum(lbl_per_img.get(gid, [])) | ||
... for gid in range(n_img)] | ||
>>> pos_is_gts = [ | ||
>>> torch.randint(0, 2, (npos,)).byte().sort( | ||
>>> descending=True)[0] | ||
>>> for npos in pos_per_img | ||
>>> ] | ||
>>> bboxes_list = self.refine_bboxes(rois, labels, bbox_preds, | ||
>>> pos_is_gts, img_metas) | ||
""" | ||
img_ids = rois[:, 0].long().unique(sorted=True) | ||
assert img_ids.numel() == len(img_metas) | ||
assert img_ids.numel() <= len(img_metas) | ||
|
||
bboxes_list = [] | ||
for i in range(len(img_metas)): | ||
inds = torch.nonzero(rois[:, 0] == i).squeeze() | ||
inds = torch.nonzero(rois[:, 0] == i).squeeze(dim=1) | ||
num_rois = inds.numel() | ||
|
||
bboxes_ = rois[inds, 1:] | ||
|
@@ -204,6 +237,7 @@ def refine_bboxes(self, rois, labels, bbox_preds, pos_is_gts, img_metas): | |
|
||
bboxes = self.regress_by_class(bboxes_, label_, bbox_pred_, | ||
img_meta_) | ||
|
||
# filter gt bboxes | ||
pos_keep = 1 - pos_is_gts_ | ||
keep_inds = pos_is_gts_.new_ones(num_rois) | ||
|
@@ -226,7 +260,7 @@ def regress_by_class(self, rois, label, bbox_pred, img_meta): | |
Returns: | ||
Tensor: Regressed bboxes, the same shape as input rois. | ||
""" | ||
assert rois.size(1) == 4 or rois.size(1) == 5 | ||
assert rois.size(1) == 4 or rois.size(1) == 5, repr(rois.shape) | ||
|
||
if not self.reg_class_agnostic: | ||
label = label * 4 | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Did you intend
random(num, scale, rng)
?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I did yes. Fixing.