Skip to content

Commit

Permalink
Fix device slices for Buffer with fixed dimensionality in template.
Browse files Browse the repository at this point in the history
  • Loading branch information
mcourteaux committed Jun 24, 2024
1 parent 5f6fc26 commit 2079057
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 10 deletions.
26 changes: 16 additions & 10 deletions src/runtime/HalideBuffer.h
Original file line number Diff line number Diff line change
Expand Up @@ -302,15 +302,20 @@ class Buffer {
// Note that this is called "cropped" but can also encompass a slice/embed
// operation as well.
struct DevRefCountCropped : DeviceRefCount {
Buffer<T, Dims, InClassDimStorage> cropped_from;
explicit DevRefCountCropped(const Buffer<T, Dims, InClassDimStorage> &cropped_from)
// We store a fixed number of 6 dimensions in the in-class storage, because we can also slice
// from Buffers that have more dimensions than this one. As we cannot know the dimensionality
// of the original Buffer, we have to pick a fixed number that covers the worst-case
// scenario. Unfortunately, although 4 is the default if the dimensionality is variable (AnyDims),
// the user can still specify a higher number of dimensions.
Buffer<T, AnyDims, 6> cropped_from;
explicit DevRefCountCropped(const Buffer<T, AnyDims, 6> &cropped_from)
: cropped_from(cropped_from) {
ownership = BufferDeviceOwnership::Cropped;
}
};

/** Setup the device ref count for a buffer to indicate it is a crop (or slice, embed, etc) of cropped_from */
void crop_from(const Buffer<T, Dims, InClassDimStorage> &cropped_from) {
void crop_from(const Buffer<T, AnyDims, 6> &cropped_from) {
assert(dev_ref_count == nullptr);
dev_ref_count = new DevRefCountCropped(cropped_from);
}
Expand Down Expand Up @@ -513,15 +518,15 @@ class Buffer {
void complete_device_crop(Buffer<T, Dims, InClassDimStorage> &result_host_cropped) const {
assert(buf.device_interface != nullptr);
if (buf.device_interface->device_crop(nullptr, &this->buf, &result_host_cropped.buf) == halide_error_code_success) {
const Buffer<T, Dims, InClassDimStorage> *cropped_from = this;
// TODO: Figure out what to do if dev_ref_count is nullptr. Should incref logic run here?
// Is it possible to get to this point without incref having run at least once since
// the device field was set? (I.e. in the internal logic of crop, incref might have been
// called.)
if (dev_ref_count != nullptr && dev_ref_count->ownership == BufferDeviceOwnership::Cropped) {
cropped_from = &((DevRefCountCropped *)dev_ref_count)->cropped_from;
result_host_cropped.crop_from(((DevRefCountCropped *)dev_ref_count)->cropped_from);
} else {
result_host_cropped.crop_from(*this);
}
result_host_cropped.crop_from(*cropped_from);
}
}

Expand All @@ -545,16 +550,17 @@ class Buffer {
void complete_device_slice(Buffer<T, AnyDims, InClassDimStorage> &result_host_sliced, int d, int pos) const {
assert(buf.device_interface != nullptr);
if (buf.device_interface->device_slice(nullptr, &this->buf, d, pos, &result_host_sliced.buf) == halide_error_code_success) {
const Buffer<T, Dims, InClassDimStorage> *sliced_from = this;
// TODO: Figure out what to do if dev_ref_count is nullptr. Should incref logic run here?
// Is it possible to get to this point without incref having run at least once since
// the device field was set? (I.e. in the internal logic of slice, incref might have been
// called.)
if (dev_ref_count != nullptr && dev_ref_count->ownership == BufferDeviceOwnership::Cropped) {
sliced_from = &((DevRefCountCropped *)dev_ref_count)->cropped_from;
// crop_from() is correct here, despite the fact that we are slicing.
result_host_sliced.crop_from(((DevRefCountCropped *)dev_ref_count)->cropped_from);
} else {
// crop_from() is correct here, despite the fact that we are slicing.
result_host_sliced.crop_from(*this);
}
// crop_from() is correct here, despite the fact that we are slicing.
result_host_sliced.crop_from(*sliced_from);
}
}

Expand Down
26 changes: 26 additions & 0 deletions test/correctness/device_slice.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,32 @@ int main(int argc, char **argv) {
});
}

// Test case: slice a device-backed buffer whose dimensionality is fixed in the
// template arguments (Buffer<int32_t, 3> sliced down to Buffer<int32_t, 2>),
// then verify both the slice and the untouched original.
printf("Test nondestructive slicing with given dimensions.\n");
{
// Source buffer declared with a compile-time dimensionality of 3.
Halide::Runtime::Buffer<int32_t, 3> gpu_buf = make_gpu_buffer(hexagon_rpc);
// The checks below only make sense if the buffer has a device interface attached.
assert(gpu_buf.raw_buffer()->device_interface != nullptr);

// Drop dimension 0 (x) at coordinate 31.
const int slice_dim = 0;
const int slice_pos = 31;
// sliced() leaves gpu_buf itself unmodified; the result has one dimension fewer.
Halide::Runtime::Buffer<int32_t, 2> sliced = gpu_buf.sliced(slice_dim, slice_pos);
// The slice must still carry a device interface.
assert(sliced.raw_buffer()->device_interface != nullptr);

// With dimension 0 removed, the remaining extents are those of the original
// dimensions 1 and 2.
assert(sliced.dimensions() == 2);
assert(sliced.extent(0) == kEdges[1]);
assert(sliced.extent(1) == kEdges[2]);

// Bring the slice's data to the host and check every element against the
// expected value at x == slice_pos.
// NOTE(review): the formula x + y * 256 + c * 256 * 256 is presumably the
// pattern written by make_gpu_buffer -- confirm against that helper.
sliced.copy_to_host();
sliced.for_each_element([&](int y, int c) {
const int x = slice_pos;
assert(sliced(y, c) == x + y * 256 + c * 256 * 256);
});

// "Nondestructive": the original buffer must still hold the full pattern
// after it has been sliced.
gpu_buf.copy_to_host();
gpu_buf.for_each_element([&](int x, int y, int c) {
assert(gpu_buf(x, y, c) == x + y * 256 + c * 256 * 256);
});
}

printf("Test slice of a slice\n");
{
Halide::Runtime::Buffer<int32_t> gpu_buf = make_gpu_buffer(hexagon_rpc);
Expand Down

0 comments on commit 2079057

Please sign in to comment.