Skip to content

Commit

Permalink
PERF: Optimize FullSampler for if input image and mask have same domain
Browse files Browse the repository at this point in the history
Similar to the existing optimization in itkComputeImageExtremaFilter.hxx

When the mask and the image have exactly the same image domain, the sampler may become more than 3x as fast as before, from more than 0.60 sec. (before this commit) to less than 0.18 (after this commit), on an image of 4096x4096 pixels, using VS2022 Release.
  • Loading branch information
N-Dekker committed Feb 7, 2024
1 parent 50db326 commit b7029ed
Show file tree
Hide file tree
Showing 3 changed files with 127 additions and 28 deletions.
74 changes: 73 additions & 1 deletion Common/GTesting/itkImageFullSamplerGTest.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,15 @@
#include <itkImage.h>
#include <itkImageMaskSpatialObject.h>
#include <gtest/gtest.h>

#include <cmath> // For nextafter.

using elx::CoreMainGTestUtilities::CreateImage;
using elx::CoreMainGTestUtilities::CreateImageFilledWithSequenceOfNaturalNumbers;
using elx::CoreMainGTestUtilities::CreateRandomImageDomain;
using elx::CoreMainGTestUtilities::DerefRawPointer;
using elx::CoreMainGTestUtilities::GenerateRandomSign;
using elx::CoreMainGTestUtilities::ImageDomain;
using elx::CoreMainGTestUtilities::minimumImageSizeValue;


GTEST_TEST(ImageFullSampler, OutputHasSameSequenceOfPixelValuesAsInput)
Expand Down Expand Up @@ -119,4 +121,74 @@ GTEST_TEST(ImageFullSampler, HasSameOutputWhenUsingFullyFilledMask)
};

EXPECT_EQ(generateSamples(true), generateSamples(false));
}


// Tests that the sampler produces the same output when using a mask whose domain is exactly equal to the image domain
// as when the domains are only slightly different.
GTEST_TEST(ImageFullSampler, ExactlyEqualVersusSlightlyDifferentMaskImageDomain)
{
using PixelType = int;
enum
{
Dimension = 2U
};
using SamplerType = itk::ImageFullSampler<itk::Image<PixelType, Dimension>>;

std::mt19937 randomNumberEngine{};
const auto image =
CreateImageFilledWithSequenceOfNaturalNumbers<PixelType>(CreateRandomImageDomain<Dimension>(randomNumberEngine));

const auto generateSamples = [image](const bool exactlyEqualImageDomain) {
elx::DefaultConstruct<SamplerType> sampler{};
sampler.SetUseMultiThread(false);
sampler.SetInput(image);

using MaskSpatialObjectType = itk::ImageMaskSpatialObject<Dimension>;
const auto maskImage = CreateImage<MaskSpatialObjectType::PixelType>(ImageDomain(*image));

std::mt19937 randomNumberEngine{};

for (MaskSpatialObjectType::PixelType & maskPixel : itk::ImageBufferRange(*maskImage))
{
maskPixel = static_cast<MaskSpatialObjectType::PixelType>(randomNumberEngine() % 2);
}

if (!exactlyEqualImageDomain)
{
// Make the domain of the mask image slightly different by making very small changes to the origin and the
// spacing.
auto origin = image->GetOrigin();

for (double & value : origin)
{
value = std::nextafter(value, GenerateRandomSign(randomNumberEngine) * std::numeric_limits<double>::max());
}
maskImage->SetOrigin(origin);

auto spacing = image->GetSpacing();

for (double & value : spacing)
{
value = std::nextafter(value, GenerateRandomSign(randomNumberEngine) * std::numeric_limits<double>::max());
}
maskImage->SetSpacing(spacing);
}

const auto maskSpatialObject = MaskSpatialObjectType::New();
maskSpatialObject->SetImage(maskImage);
maskSpatialObject->Update();

sampler.SetMask(maskSpatialObject);
sampler.Update();
return std::move(DerefRawPointer(sampler.GetOutput()).CastToSTLContainer());
};

const auto samplesOnExactlyEqualImageDomains = generateSamples(true);
const auto samplesOnSlightlyDifferentImageDomains = generateSamples(false);

// The test would be trivial (uninteresting) if there were no samples.
EXPECT_FALSE(samplesOnExactlyEqualImageDomains.empty());

EXPECT_EQ(samplesOnExactlyEqualImageDomains, samplesOnSlightlyDifferentImageDomains);
}
5 changes: 3 additions & 2 deletions Common/ImageSamplers/itkImageFullSampler.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#define itkImageFullSampler_h

#include "itkImageSamplerBase.h"
#include "elxMaskHasSameImageDomain.h"

namespace itk
{
Expand Down Expand Up @@ -132,7 +133,7 @@ class ITK_TEMPLATE_EXPORT ImageFullSampler : public ImageSamplerBase<TInputImage
std::vector<WorkUnit> WorkUnits{};
};

template <bool VUseMask>
template <elastix::MaskCondition VMaskCondition>
static ITK_THREAD_RETURN_FUNCTION_CALL_CONVENTION
ThreaderCallback(void * arg);

Expand All @@ -156,7 +157,7 @@ class ITK_TEMPLATE_EXPORT ImageFullSampler : public ImageSamplerBase<TInputImage
std::vector<ImageSampleType> & samples);

/** Generates the data for one specific work unit. */
template <bool VUseMask>
template <elastix::MaskCondition VMaskCondition>
static void
GenerateDataForWorkUnit(WorkUnit &, const InputImageType &, const MaskType *, const WorldToObjectTransformType *);
};
Expand Down
76 changes: 51 additions & 25 deletions Common/ImageSamplers/itkImageFullSampler.hxx
Original file line number Diff line number Diff line change
Expand Up @@ -75,14 +75,22 @@ ImageFullSampler<TInputImage>::SingleThreadedGenerateData(const TInputImage &

if (mask)
{
GenerateDataForWorkUnit<true>(workUnit, inputImage, mask, mask->GetObjectToWorldTransformInverse());
if (elastix::MaskHasSameImageDomain(*mask, inputImage))
{
GenerateDataForWorkUnit<elastix::MaskCondition::HasSameImageDomain>(workUnit, inputImage, mask, nullptr);
}
else
{
GenerateDataForWorkUnit<elastix::MaskCondition::HasDifferentImageDomain>(
workUnit, inputImage, mask, mask->GetObjectToWorldTransformInverse());
}

assert(workUnit.NumberOfSamples <= samples.size());
samples.resize(workUnit.NumberOfSamples);
}
else
{
GenerateDataForWorkUnit<false>(workUnit, inputImage, nullptr, nullptr);
GenerateDataForWorkUnit<elastix::MaskCondition::IsNull>(workUnit, inputImage, nullptr, nullptr);
}
}

Expand All @@ -101,12 +109,24 @@ ImageFullSampler<TInputImage>::MultiThreadedGenerateData(MultiThreaderBase &
{
samples.resize(croppedInputImageRegion.GetNumberOfPixels());

const bool maskHasSameImageDomain = mask ? elastix::MaskHasSameImageDomain(*mask, inputImage) : false;

UserData userData{ inputImage,
mask,
mask ? mask->GetObjectToWorldTransformInverse() : nullptr,
(mask == nullptr || maskHasSameImageDomain) ? nullptr : mask->GetObjectToWorldTransformInverse(),
GenerateWorkUnits(numberOfWorkUnits, croppedInputImageRegion, samples) };

multiThreader.SetSingleMethod(mask ? &Self::ThreaderCallback<true> : &Self::ThreaderCallback<false>, &userData);
if (mask)
{
multiThreader.SetSingleMethod(elastix::MaskHasSameImageDomain(*mask, inputImage)
? &Self::ThreaderCallback<elastix::MaskCondition::HasSameImageDomain>
: &Self::ThreaderCallback<elastix::MaskCondition::HasDifferentImageDomain>,
&userData);
}
else
{
multiThreader.SetSingleMethod(&Self::ThreaderCallback<elastix::MaskCondition::IsNull>, &userData);
}
multiThreader.SingleMethodExecute();

if (mask)
Expand Down Expand Up @@ -172,43 +192,40 @@ ImageFullSampler<TInputImage>::GenerateData()


template <class TInputImage>
template <bool VUseMask>
template <elastix::MaskCondition VMaskCondition>
ITK_THREAD_RETURN_FUNCTION_CALL_CONVENTION
ImageFullSampler<TInputImage>::ThreaderCallback(void * const arg)
{
assert(arg);
const auto & info = *static_cast<const MultiThreaderBase::WorkUnitInfo *>(arg);

assert(info.UserData);
auto & userData = *static_cast<UserData *>(info.UserData);

const auto workUnitID = info.WorkUnitID;

if (workUnitID >= userData.WorkUnits.size())
if (const auto workUnitID = info.WorkUnitID; workUnitID < userData.WorkUnits.size())
{
return ITK_THREAD_RETURN_DEFAULT_VALUE;
GenerateDataForWorkUnit<VMaskCondition>(
userData.WorkUnits[workUnitID], userData.InputImage, userData.Mask, userData.WorldToObjectTransform);
}

GenerateDataForWorkUnit<VUseMask>(
userData.WorkUnits[workUnitID], userData.InputImage, userData.Mask, userData.WorldToObjectTransform);

return ITK_THREAD_RETURN_DEFAULT_VALUE;
}


template <class TInputImage>
template <bool VUseMask>
template <elastix::MaskCondition VMaskCondition>
void
ImageFullSampler<TInputImage>::GenerateDataForWorkUnit(WorkUnit & workUnit,
const InputImageType & inputImage,
const MaskType * const mask,
const WorldToObjectTransformType * const worldToObjectTransform)
{
assert((mask == nullptr) == (!VUseMask));
assert((worldToObjectTransform == nullptr) == (!VUseMask));
assert((mask == nullptr) == (VMaskCondition == elastix::MaskCondition::IsNull));
assert((worldToObjectTransform == nullptr) == (VMaskCondition != elastix::MaskCondition::HasDifferentImageDomain));

auto * samples = workUnit.Samples;

[[maybe_unused]] const auto * const maskImage =
(VMaskCondition == elastix::MaskCondition::HasSameImageDomain) ? mask->GetImage() : nullptr;

/** Simply loop over the image and store all samples in the container. */
for (ImageRegionConstIteratorWithIndex<InputImageType> iter(&inputImage, workUnit.imageRegion); !iter.IsAtEnd();
++iter)
Expand All @@ -221,7 +238,22 @@ ImageFullSampler<TInputImage>::GenerateDataForWorkUnit(WorkUnit &

using RealType = typename ImageSampleType::RealType;

if constexpr (VUseMask)
if constexpr (VMaskCondition == elastix::MaskCondition::IsNull)
{
// Store sample in container.
*samples = { point, static_cast<RealType>(inputImage.GetPixel(index)) };
++samples;
}
if constexpr (VMaskCondition == elastix::MaskCondition::HasSameImageDomain)
{
if (maskImage->GetPixel(index) != 0)
{
// Store sample in container.
*samples = { point, static_cast<RealType>(inputImage.GetPixel(index)) };
++samples;
}
}
if constexpr (VMaskCondition == elastix::MaskCondition::HasDifferentImageDomain)
{
// Equivalent to `mask->IsInsideInWorldSpace(point)`, but much faster.
if (mask->MaskType::IsInsideInObjectSpace(
Expand All @@ -232,15 +264,9 @@ ImageFullSampler<TInputImage>::GenerateDataForWorkUnit(WorkUnit &
++samples;
}
}
else
{
// Store sample in container.
*samples = { point, static_cast<RealType>(inputImage.GetPixel(index)) };
++samples;
}
}

if constexpr (VUseMask)
if constexpr (VMaskCondition != elastix::MaskCondition::IsNull)
{
workUnit.NumberOfSamples = samples - workUnit.Samples;
}
Expand Down

0 comments on commit b7029ed

Please sign in to comment.