GPU MultiPaste #2681

Merged · 34 commits · Mar 12, 2021

Changes from 1 commit

Commits
67772b3
First version for multipaste
TheTimemaster Dec 21, 2020
7c12a8e
Actually run the copy
TheTimemaster Dec 21, 2020
6e2843f
Implementation is now working. Fixed documentation comments
TheTimemaster Jan 3, 2021
111993c
Merge branch 'master' of github.com:NVIDIA/DALI
TheTimemaster Jan 7, 2021
cd6d288
Removed BatchIndex operator, Multipaste out_idx is now optional. Fixe…
TheTimemaster Jan 7, 2021
c2e67d1
Added python tests
TheTimemaster Jan 12, 2021
9a3b3f4
Most inputs are now named arguments. Output_width+height changed into…
TheTimemaster Jan 12, 2021
3ba0a5f
Removed redundant 'input_out_ids' argument
TheTimemaster Jan 12, 2021
6b3574e
Merge branch 'master' of github.com:NVIDIA/DALI
TheTimemaster Jan 15, 2021
6ba81d0
Merge branch 'master' of github.com:NVIDIA/DALI
TheTimemaster Jan 16, 2021
58419ff
Shapes and anchors are now optional. Output is now zero-ed before pas…
TheTimemaster Jan 18, 2021
fc16cc6
In python tests, validation is no longer done pixel by pixel
TheTimemaster Jan 18, 2021
73b048a
Fixed issues with optional parameters
TheTimemaster Jan 19, 2021
84031c1
Fixed issues from review
TheTimemaster Jan 27, 2021
6ac9b57
Merge branch 'master' of github.com:NVIDIA/DALI
TheTimemaster Jan 31, 2021
16231c6
Changed syntax for setting the output size
TheTimemaster Jan 31, 2021
ce70740
Extracted large code blocks from TYPE_SWITCH to separate functions. R…
TheTimemaster Feb 11, 2021
8fea944
Changed tab size in RunImplExplicitlyTyped
TheTimemaster Feb 13, 2021
02f8ebf
Tab fix
TheTimemaster Feb 13, 2021
c1b3bce
Tab fix
TheTimemaster Feb 13, 2021
c7d0fde
First version
TheTimemaster Feb 16, 2021
28a90a1
Python test fix
TheTimemaster Feb 16, 2021
1640dfc
Test parenthesis fix
TheTimemaster Feb 17, 2021
7c23f0a
Merge branch 'master' into gpu-multipaste
TheTimemaster Feb 17, 2021
e183ab5
WIP
TheTimemaster Feb 23, 2021
1ed27b9
Merge branch 'master' of github.com:NVIDIA/DALI
TheTimemaster Feb 23, 2021
c03f1b0
Merge
TheTimemaster Feb 23, 2021
70ba2f2
First semi-working version
TheTimemaster Feb 24, 2021
c9a1b4d
Fixed kernel glitch and test file
TheTimemaster Feb 25, 2021
43b423c
Improved kernel performance
TheTimemaster Feb 25, 2021
1d33ba3
Patches are now generated inside kernel
TheTimemaster Feb 28, 2021
1491833
Changed SampleInput structure
TheTimemaster Mar 1, 2021
2c05e16
Operator now throws if pastes are out of bounds
TheTimemaster Mar 3, 2021
9e04585
Fixed signed/unsigned comparisons
TheTimemaster Mar 11, 2021
14 changes: 7 additions & 7 deletions dali/kernels/imgproc/paste/paste_gpu.h
@@ -66,7 +66,7 @@ void FillOutDetails(span<SampleDescriptorGPU<OutputType, InputType, ndims - 1>>
template <class OutputType, class InputType, int ndims>
void CreateSampleDescriptors(
vector<SampleDescriptorGPU<OutputType, InputType, ndims - 1>> &out_descs,
vector<PatchDesc<InputType, ndims - 1>> &out_patchs,
vector<PatchDesc<InputType, ndims - 1>> &out_patches,
const InListGPU<InputType, ndims> &in,
span<paste::MultiPasteSampleInput<ndims - 1>> samples
) {
@@ -75,7 +75,7 @@ void CreateSampleDescriptors(
int batch_size = samples.size();

out_descs.resize(batch_size);
out_patchs.clear();
out_patches.clear();

for (int out_idx = 0; out_idx < batch_size; out_idx++) {
const auto &sample = samples[out_idx];
@@ -130,15 +130,15 @@ void CreateSampleDescriptors(
y_ending[y_patch_cnt].emplace_back(-1, 0, x_patch_cnt);

// Filling sample
int prev_patch_count = out_patchs.size();
int prev_patch_count = out_patches.size();
auto &out_sample = out_descs[out_idx];
out_sample.patch_start_idx = prev_patch_count;
out_sample.patch_counts[1] = x_patch_cnt;
out_sample.patch_counts[0] = y_patch_cnt;
out_sample.out_pitch[1] = sample.out_size[1] * channels;

// And now the sweeping itself
out_patchs.resize(prev_patch_count + x_patch_cnt * y_patch_cnt);
out_patches.resize(prev_patch_count + x_patch_cnt * y_patch_cnt);
vector<std::unordered_set<int>> starting(x_patch_cnt + 1);
vector<std::unordered_set<int>> ending(x_patch_cnt + 1);
std::set<int> open_pastes;
@@ -157,7 +157,7 @@ void CreateSampleDescriptors(

// Take top most region
int max_paste = *(--open_pastes.end());
auto& patch = out_patchs[prev_patch_count + y * x_patch_cnt + x];
auto& patch = out_patches[prev_patch_count + y * x_patch_cnt + x];


// And fill the patch
@@ -198,12 +198,12 @@ void CreateSampleDescriptors(
template <class OutputType, class InputType, int ndims>
__global__ void
PasteKernel(const SampleDescriptorGPU<OutputType, InputType, ndims> *samples,
const PatchDesc<InputType, ndims> *patchs,
const PatchDesc<InputType, ndims> *patches,
const BlockDesc<ndims> *blocks) {
static_assert(ndims == 2, "Function requires 2 dimensions in the input");
const auto &block = blocks[blockIdx.x];
const auto &sample = samples[block.sample_idx];
const PatchDesc<InputType, ndims> *my_patches = patchs + sample.patch_start_idx;
const PatchDesc<InputType, ndims> *my_patches = patches + sample.patch_start_idx;


auto *__restrict__ out = sample.out;
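Aside from the patchs -> patches rename, the context above shows the sweep performed by CreateSampleDescriptors: each output image is cut into a grid of rectangular patches along the paste-region edges, and every patch records the single topmost paste visible in it (or none, in which case it stays zeroed). A rough Python sketch of that decomposition, with hypothetical names and simplified bookkeeping compared to the kernel code:

def grid_patches(out_size, regions):
    # Illustrative only: split an output of size (H, W) into grid cells so that
    # each cell is covered by at most one visible paste region.
    # `regions` is a list of (anchor_y, anchor_x, h, w); later entries are on top,
    # matching the max-index ("top most region") choice in the C++ sweep.
    ys = sorted({0, out_size[0], *(r[0] for r in regions), *(r[0] + r[2] for r in regions)})
    xs = sorted({0, out_size[1], *(r[1] for r in regions), *(r[1] + r[3] for r in regions)})

    patches = []
    for y0, y1 in zip(ys[:-1], ys[1:]):
        for x0, x1 in zip(xs[:-1], xs[1:]):
            top = -1  # -1 means no paste covers this cell; it stays zero-filled
            for idx, (ry, rx, rh, rw) in enumerate(regions):
                # Cell edges are region edges, so coverage is all-or-nothing.
                if ry <= y0 and y1 <= ry + rh and rx <= x0 and x1 <= rx + rw:
                    top = idx
            patches.append(((y0, x0), (y1, x1), top))
    return patches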
13 changes: 10 additions & 3 deletions dali/operators/image/paste/multipaste.h
@@ -143,12 +143,19 @@ class MultiPasteOp : public Operator<Backend> {
}

bool found_intersection = false;

for (int j = 0; j < n_paste; j++) {
auto out_anchor = GetOutAnchors(i, j);
auto in_anchor = GetInAnchors(i, j);
auto j_idx = in_idx_[i].data[j];
const auto &shape = GetShape(i, j, Coords(raw_input_size_mem_.data() + 2 * j_idx,
dali::TensorShape<>(2)));
auto in_shape = Coords(raw_input_size_mem_.data() + 2 * j_idx, dali::TensorShape<>(2));
const auto &shape = GetShape(i, j, in_shape);
for (int k = 0; k < spatial_ndim; k++) {
DALI_ENFORCE(out_anchor.data[k] >= 0 && in_anchor.data[k] >= 0 &&
out_anchor.data[k] + shape.data[k] <= output_size_[i].data[k] &&
in_anchor.data[k] + shape.data[k] <= in_shape.data[k],
"Paste in/out coords should be within inout/output bounds.");
}

for (int k = 0; k < j; k++) {
auto k_idx = in_idx_[i].data[k];
if (Intersects(out_anchor, shape, GetOutAnchors(i, k), GetShape(
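The multipaste.h hunk above adds per-dimension validation of every paste rectangle. As a plain-Python illustration of the enforced condition (a hypothetical helper, not part of the operator), a paste is accepted only when it lies fully inside both the source and the destination image:

def paste_in_bounds(in_anchor, out_anchor, shape, in_size, out_size):
    # Mirrors the DALI_ENFORCE added above: non-negative anchors, and the
    # rectangle must end inside both the input and the output image.
    return all(
        in_anchor[d] >= 0 and out_anchor[d] >= 0 and
        in_anchor[d] + shape[d] <= in_size[d] and
        out_anchor[d] + shape[d] <= out_size[d]
        for d in range(len(shape))
    )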
94 changes: 69 additions & 25 deletions dali/test/python/test_operator_multipaste.py
@@ -61,7 +61,11 @@ def prepare_cuts(
full_input=False,
in_anchor_top_left=False,
out_anchor_top_left=False,
out_of_bounds_count=0,
):
# Those two will not work together
assert(out_of_bounds_count == 0 or not no_intersections)

in_idx_l = [np.zeros(shape=(0,), dtype=np.int32) for _ in range(batch_size)]
in_anchors_l = [np.zeros(shape=(0, 2), dtype=np.int32) for _ in range(batch_size)]
shapes_l = [np.zeros(shape=(0, 2), dtype=np.int32) for _ in range(batch_size)]
@@ -104,6 +108,25 @@ def prepare_cuts(
in_anchors_l[out_idx] = np.append(in_anchors_l[out_idx], [in_anchor], axis=0)
shapes_l[out_idx] = np.append(shapes_l[out_idx], [shape], axis=0)
out_anchors_l[out_idx] = np.append(out_anchors_l[out_idx], [out_anchor], axis=0)
for i in range(out_of_bounds_count):
clip_out_idx = np.random.randint(batch_size)
while len(in_idx_l[clip_out_idx]) == 0:
clip_out_idx = np.random.randint(batch_size)
clip_in_idx = np.random.randint(len(in_idx_l[clip_out_idx]))
change_in = np.random.randint(2) == 0
below_zero = np.random.randint(2) == 0
change_dim_idx = np.random.randint(dim)
if below_zero:
(in_anchors_l if change_in else out_anchors_l)[clip_out_idx][clip_in_idx][change_dim_idx] = \
np.int32(np.random.randint(5) - 5)
else:
(in_anchors_l if change_in else out_anchors_l)[clip_out_idx][clip_in_idx][change_dim_idx] = \
np.int32(
(input_size if change_in else output_size)[change_dim_idx] -
shapes_l[clip_out_idx][clip_in_idx][change_dim_idx] +
np.random.randint(5) + 1
)

return in_idx_l, in_anchors_l, shapes_l, out_anchors_l


@@ -118,7 +141,8 @@ def get_pipeline(
full_input=False,
in_anchor_top_left=False,
out_anchor_top_left=False,
use_gpu=False
use_gpu=False,
num_out_of_bounds=0
):
pipe = Pipeline(batch_size=batch_size, num_threads=4, device_id=0)
with pipe:
@@ -127,7 +151,7 @@ def get_pipeline(
resized = fn.resize(decoded, resize_x=in_size[1], resize_y=in_size[0])
in_idx_l, in_anchors_l, shapes_l, out_anchors_l = prepare_cuts(
k, batch_size, in_size, out_size, even_paste_count,
no_intersections, full_input, in_anchor_top_left, out_anchor_top_left)
no_intersections, full_input, in_anchor_top_left, out_anchor_top_left, num_out_of_bounds)
in_idx = fn.external_source(lambda: in_idx_l)
in_anchors = fn.external_source(lambda: in_anchors_l)
shapes = fn.external_source(lambda: shapes_l)
@@ -147,14 +171,24 @@ def get_pipeline(
if not out_anchor_top_left:
kwargs["out_anchors"] = out_anchors

if use_gpu:
kwargs["device"] = "gpu"

pasted = fn.multi_paste(resized.gpu() if use_gpu else resized, **kwargs)
pipe.set_outputs(pasted, resized)
return pipe, in_idx_l, in_anchors_l, shapes_l, out_anchors_l


def verify_out_of_bounds(batch_size, in_idx_l, in_anchors_l, shapes_l, out_anchors_l, in_size, out_size):
for i in range(batch_size):
for j, idx in enumerate(in_idx_l[i]):
dim = len(in_anchors_l[i][j])
for d in range(dim):
if in_anchors_l[i][j][d] < 0 or out_anchors_l[i][j][d] < 0 or \
in_anchors_l[i][j][d] + shapes_l[i][j][d] > in_size[d] or \
out_anchors_l[i][j][d] + shapes_l[i][j][d] > out_size[d]:
return True
return False



def manual_verify(batch_size, inp, output, in_idx_l, in_anchors_l, shapes_l, out_anchors_l, out_size_l, dtype):
for i in range(batch_size):
out = output.at(i)
@@ -189,7 +223,7 @@ def show_images(batch_size, image_batch):


def check_operator_multipaste(bs, pastes, in_size, out_size, even_paste_count, no_intersections, full_input, in_anchor_top_left,
out_anchor_top_left, out_dtype, device):
out_anchor_top_left, out_dtype, num_out_of_bounds, device):
pipe, in_idx_l, in_anchors_l, shapes_l, out_anchors_l = get_pipeline(
batch_size=bs,
in_size=in_size,
Expand All @@ -201,15 +235,22 @@ def check_operator_multipaste(bs, pastes, in_size, out_size, even_paste_count, n
full_input=full_input,
in_anchor_top_left=in_anchor_top_left,
out_anchor_top_left=out_anchor_top_left,
num_out_of_bounds=num_out_of_bounds,
use_gpu=device == 'gpu'
)
pipe.build()
result, input = pipe.run()
r = result.as_cpu() if device == 'gpu' else result
if SHOW_IMAGES:
show_images(bs, r)
manual_verify(bs, input, r, in_idx_l, in_anchors_l, shapes_l, out_anchors_l, [out_size + (3,)] * bs, out_dtype)

try:
Contributor:

Nose has assert_raises for this kind of test. Maybe it would be good to split it.

Contributor Author:

I checked the assert_raises method, but it only asserts the type (RuntimeError) of the exception, and any DALI_FAIL results in this type.
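(For reference, a minimal sketch of the distinction discussed here, assuming a pipe built so that a paste goes out of bounds; only the second form also checks which error was raised, which is roughly what check_operator_multipaste does below.)

from nose.tools import assert_raises

# assert_raises only verifies the exception type; every DALI_FAIL surfaces as a
# RuntimeError, so this would also pass for unrelated pipeline failures:
assert_raises(RuntimeError, pipe.run)

# Checking the message as well pins the failure to the new bounds validation:
try:
    pipe.run()
    assert False, "expected an out-of-bounds error"
except RuntimeError as e:
    assert "Paste in/out coords should be within inout/output bounds" in str(e)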

result, input = pipe.run()
r = result.as_cpu() if device == 'gpu' else result
if SHOW_IMAGES:
show_images(bs, r)
assert not verify_out_of_bounds(bs, in_idx_l, in_anchors_l, shapes_l, out_anchors_l, in_size, out_size)
manual_verify(bs, input, r, in_idx_l, in_anchors_l, shapes_l, out_anchors_l, [out_size + (3,)] * bs, out_dtype)
except RuntimeError as e:
if "Paste in/out coords should be within inout/output bounds" in str(e):
assert verify_out_of_bounds(bs, in_idx_l, in_anchors_l, shapes_l, out_anchors_l, in_size, out_size)
else:
assert False

def test_operator_multipaste():
tests = [
@@ -225,19 +266,22 @@ def test_operator_multipaste():
# - should "out_anchors" parameter be omitted
# - output dtype
# - backend
[4, 2, (128, 256), (128, 128), False, False, False, False, False, types.UINT8],
[4, 2, (256, 128), (128, 128), False, True, False, False, False, types.UINT8],
[4, 2, (128, 128), (256, 128), True, False, False, False, False, types.UINT8],
[4, 2, (128, 128), (128, 256), True, True, False, False, False, types.UINT8],

[4, 2, (64, 64), (128, 128), False, False, True, False, False, types.UINT8],
[4, 2, (64, 64), (128, 128), False, False, False, True, False, types.UINT8],
[4, 2, (64, 64), (128, 128), False, False, False, False, True, types.UINT8],

[4, 2, (128, 128), (128, 128), False, False, False, False, False, types.UINT8],
[4, 2, (128, 128), (128, 128), False, False, False, False, False, types.INT16],
[4, 2, (128, 128), (128, 128), False, False, False, False, False, types.INT32],
[4, 2, (128, 128), (128, 128), False, False, False, False, False, types.FLOAT],
# - number of out-of-bounds anchor changes
[4, 2, (128, 256), (128, 128), False, False, False, False, False, types.UINT8, 0],
[4, 2, (256, 128), (128, 128), False, True, False, False, False, types.UINT8, 0],
[4, 2, (128, 128), (256, 128), True, False, False, False, False, types.UINT8, 0],
[4, 2, (128, 128), (128, 256), True, True, False, False, False, types.UINT8, 0],

[4, 2, (64, 64), (128, 128), False, False, True, False, False, types.UINT8, 0],
[4, 2, (64, 64), (128, 128), False, False, False, True, False, types.UINT8, 0],
[4, 2, (64, 64), (128, 128), False, False, False, False, True, types.UINT8, 0],

[4, 2, (128, 128), (128, 128), False, False, False, False, False, types.UINT8, 0],
[4, 2, (128, 128), (128, 128), False, False, False, False, False, types.INT16, 0],
[4, 2, (128, 128), (128, 128), False, False, False, False, False, types.INT32, 0],
[4, 2, (128, 128), (128, 128), False, False, False, False, False, types.FLOAT, 0],

[4, 2, (128, 256), (128, 128), False, False, False, False, False, types.UINT8, 4],
]
for t in tests:
yield (check_operator_multipaste, *t, "cpu")