Skip to content

Commit

Permalink
Merge 55e77c9 into d860bc6
Browse files Browse the repository at this point in the history
  • Loading branch information
stephenegriffin authored Jul 26, 2021
2 parents d860bc6 + 55e77c9 commit b5b6af9
Show file tree
Hide file tree
Showing 6 changed files with 104 additions and 120 deletions.
100 changes: 49 additions & 51 deletions UnitTest/tests/namedproptest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,52 +60,35 @@ namespace namedproptest
TEST_METHOD(Test_Match)
{
// Test all forms of match
Assert::AreEqual(true, formStorage1->match(formStorage1, true, true, true));
Assert::AreEqual(true, formStorage1->match(formStorage1, false, true, true));
Assert::AreEqual(true, formStorage1->match(formStorage1, true, false, true));
Assert::AreEqual(true, formStorage1->match(formStorage1, true, true, false));
Assert::AreEqual(true, formStorage1->match(formStorage1, true, false, false));
Assert::AreEqual(true, formStorage1->match(formStorage1, false, true, false));
Assert::AreEqual(true, formStorage1->match(formStorage1, false, false, true));
Assert::AreEqual(true, formStorage1->match(formStorage1, false, false, false));
Assert::AreEqual(true, formStorage1->match(formStorage1, true, true));
Assert::AreEqual(true, formStorage1->match(formStorage1, false, true));
Assert::AreEqual(true, formStorage1->match(formStorage1, true, false));
Assert::AreEqual(true, formStorage1->match(formStorage1, false, false));

// Odd comparisons
Assert::AreEqual(
true,
formStorage1->match(cache::namedPropCacheEntry::make(&formStorageID, 0x1111, sig1), true, true, true));
true, formStorage1->match(cache::namedPropCacheEntry::make(&formStorageID, 0x1111, sig1), true, true));
Assert::AreEqual(
true,
formStorage1->match(cache::namedPropCacheEntry::make(&formStorageID, 0x1112, sig1), true, false, true));
true, formStorage1->match(cache::namedPropCacheEntry::make(&formStorageID, 0x1112, sig1), false, true));
Assert::AreEqual(
true,
formStorage1->match(
cache::namedPropCacheEntry::make(&pageDirStreamID, 0x1111, sig1), true, true, false));
formStorage1->match(cache::namedPropCacheEntry::make(&pageDirStreamID, 0x1111, sig1), true, false));
Assert::AreEqual(
true,
formStorage1->match(
cache::namedPropCacheEntry::make(&formStorageName, 0x1111, sig1), true, true, false));
formStorage1->match(cache::namedPropCacheEntry::make(&formStorageName, 0x1111, sig1), true, false));
Assert::AreEqual(
false,
formStorage1->match(
cache::namedPropCacheEntry::make(&formStorageName, 0x1111, sig1), true, true, true));
formStorage1->match(cache::namedPropCacheEntry::make(&formStorageName, 0x1111, sig1), true, true));

// Should fail
Assert::AreEqual(
false,
formStorage1->match(cache::namedPropCacheEntry::make(&formStorageID, 0x1110, sig1), true, true, true));
false, formStorage1->match(cache::namedPropCacheEntry::make(&formStorageID, 0x1110, sig1), true, true));
Assert::AreEqual(
false,
formStorage1->match(
cache::namedPropCacheEntry::make(&pageDirStreamID, 0x1111, sig1), true, true, true));
Assert::AreEqual(false, formStorage1->match(nullptr, true, true, true));
Assert::AreEqual(false, formStorage1->match(formStorage2, true, true, true));
Assert::AreEqual(false, formStorage1->match(formStorageLog, true, true, true));

// Should all work
Assert::AreEqual(true, formStorage1->match(formStorage2, false, true, true));
Assert::AreEqual(true, formStorage1->match(formStorage2, false, false, true));
Assert::AreEqual(true, formStorage1->match(formStorage2, false, true, false));
Assert::AreEqual(true, formStorage1->match(formStorage2, false, false, false));
formStorage1->match(cache::namedPropCacheEntry::make(&pageDirStreamID, 0x1111, sig1), true, true));
Assert::AreEqual(false, formStorage1->match(nullptr, true, true));
Assert::AreEqual(false, formStorage1->match(formStorage2, true, true));
Assert::AreEqual(false, formStorage1->match(formStorageLog, true, true));

// Compare given a signature, MAPINAMEID
// _Check_return_ bool match(_In_ const std::vector<BYTE>& _sig, _In_ const MAPINAMEID& _mapiNameId) const;
Expand Down Expand Up @@ -134,41 +117,56 @@ namespace namedproptest
Assert::AreEqual(false, formStorageProp->match(0x1111, formStorageName2));

// String prop
Assert::AreEqual(true, formStorageProp->match(formStorageProp, true, true, true));
Assert::AreEqual(false, formStorageProp->match(formStorageProp1, true, true, true));
Assert::AreEqual(true, formStorageProp->match(formStorageProp, true, true));
Assert::AreEqual(false, formStorageProp->match(formStorageProp1, true, true));
}

TEST_METHOD(Test_Cache)
{
cache::namedPropCache::add(ids1, sig1);
cache::namedPropCache::add(ids1, {});
cache::namedPropCache::add(ids2, {});
cache::namedPropCache::add(ids1, sig1); // Add prop1, prop2 with signature
cache::namedPropCache::add(ids1, {}); // Try to add them again without signature - this is a no-op
cache::namedPropCache::add(ids2, {}); // Again, adding without a signature is a no op

Assert::AreEqual(true, cache::namedPropCache::find(prop1, true, true)->match(prop1, true, true));
Assert::AreEqual(true, cache::namedPropCache::find(prop2, true, true)->match(prop2, true, true));
Assert::AreEqual(
true, cache::namedPropCache::find(prop1, true, true, true)->match(prop1, true, true, true));
Assert::AreEqual(
true, cache::namedPropCache::find(prop2, true, true, true)->match(prop2, true, true, true));
Assert::AreEqual(
true, cache::namedPropCache::find(prop3, true, true, true)->match(prop3, true, true, true));
Assert::AreEqual(true, cache::namedPropCache::find(0x1111, formStorageID)->match(prop1, true, true, true));
Assert::AreEqual(true, cache::namedPropCache::find(sig1, 0x1111)->match(prop1, true, true, true));
Assert::AreEqual(true, cache::namedPropCache::find(sig1, formStorageID)->match(prop1, true, true, true));
true,
cache::namedPropCache::find(prop3, true, true)
->match(cache::namedPropCacheEntry::empty(), true, true)); // Shouldn't find prop3 in the cache
Assert::AreEqual(true, cache::namedPropCache::find(0x1111, formStorageID)->match(prop1, true, true));
Assert::AreEqual(true, cache::namedPropCache::find(sig1, 0x1111)->match(prop1, true, true));
Assert::AreEqual(true, cache::namedPropCache::find(sig1, formStorageID)->match(prop1, true, true));

Assert::AreEqual(
true,
cache::namedPropCache::find(0x1110, formStorageID)
->match(cache::namedPropCacheEntry::empty(), true, true, true));
->match(cache::namedPropCacheEntry::empty(), true, true));

Assert::AreEqual(false, cache::namedPropCache::find(0x1110, formStorageID)->match(prop1, true, true, true));
Assert::AreEqual(false, cache::namedPropCache::find(sig2, 0x1111)->match(prop1, true, true, true));
Assert::AreEqual(false, cache::namedPropCache::find(sig2, formStorageID)->match(prop1, true, true, true));
Assert::AreEqual(false, cache::namedPropCache::find(0x1110, formStorageID)->match(prop1, true, true));
Assert::AreEqual(false, cache::namedPropCache::find(sig2, 0x1111)->match(prop1, true, true));
Assert::AreEqual(false, cache::namedPropCache::find(sig2, formStorageID)->match(prop1, true, true));

// None of these should match prop3 since we never cached it!
Assert::AreEqual(
true,
cache::namedPropCache::find(0x3333, pageDirStreamID)
->match(cache::namedPropCacheEntry::empty(), true, true));
Assert::AreEqual(
true, cache::namedPropCache::find(0x3333, pageDirStreamID)->match(prop3, true, true, true));
Assert::AreEqual(true, cache::namedPropCache::find({}, 0x3333)->match(prop3, true, true, true));
true, cache::namedPropCache::find({}, 0x3333)->match(cache::namedPropCacheEntry::empty(), true, true));
Assert::AreEqual(
true,
cache::namedPropCache::find(std::vector<BYTE>{}, pageDirStreamID)->match(prop3, true, true, true));
cache::namedPropCache::find(std::vector<BYTE>{}, pageDirStreamID)
->match(cache::namedPropCacheEntry::empty(), true, true));

cache::namedPropCache::add(ids2, sig2); // Now add prop3 with a signature and our lookups should work
Assert::AreEqual(
true,
cache::namedPropCache::find(prop3, true, true)
->match(prop3, true, true)); // Shouldn't find prop3 in the cache
Assert::AreEqual(true, cache::namedPropCache::find(0x3333, pageDirStreamID)->match(prop3, true, true));
Assert::AreEqual(true, cache::namedPropCache::find({}, 0x3333)->match(prop3, true, true));
Assert::AreEqual(
true, cache::namedPropCache::find(std::vector<BYTE>{}, pageDirStreamID)->match(prop3, true, true));
}

TEST_METHOD(Test_Valid)
Expand Down
52 changes: 14 additions & 38 deletions core/mapi/cache/namedPropCache.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -262,50 +262,29 @@ namespace cache
}
} // namespace directMapi

std::list<std::shared_ptr<namedPropCacheEntry>>& namedPropCache::getCache() noexcept
std::vector<std::shared_ptr<namedPropCacheEntry>>& namedPropCache::getCache() noexcept
{
// We keep a list of named prop cache entries
static std::list<std::shared_ptr<namedPropCacheEntry>> cache;
static std::vector<std::shared_ptr<namedPropCacheEntry>> cache;
return cache;
}

_Check_return_ std::shared_ptr<namedPropCacheEntry>
namedPropCache::find(const std::function<bool(const std::shared_ptr<namedPropCacheEntry>&)>& compare)
{
const auto& cache = getCache();
const auto entry =
find_if(cache.begin(), cache.end(), [compare](const auto& _entry) { return compare(_entry); });

if (entry != cache.end())
{
output::DebugPrint(output::dbgLevel::NamedPropCache, L"find: found match\n");
return *entry;
}
else
{
output::DebugPrint(output::dbgLevel::NamedPropCache, L"find: no match\n");
return namedPropCacheEntry::empty();
}
}

_Check_return_ std::shared_ptr<namedPropCacheEntry> namedPropCache::find(
const std::shared_ptr<cache::namedPropCacheEntry>& entry,
bool bMatchSig,
bool bMatchID,
bool bMatchName)
{
if (fIsSet(output::dbgLevel::NamedPropCache))
{
output::DebugPrint(
output::dbgLevel::NamedPropCache,
L"find: bMatchSig=%d, bMatchID=%d, bMatchName=%d\n",
bMatchSig,
L"find: bMatchID=%d, bMatchName=%d\n",
bMatchID,
bMatchName);
entry->output();
}

return find([&](const auto& _entry) { return _entry->match(entry, bMatchSig, bMatchID, bMatchName); });
return cache::find(getCache(), [&](const auto& _entry) { return _entry->match(entry, bMatchID, bMatchName); });
}

_Check_return_ std::shared_ptr<namedPropCacheEntry>
Expand All @@ -329,7 +308,7 @@ namespace cache
}
}

return find([&](const auto& _entry) { return _entry->match(_sig, _mapiNameId); });
return cache::find(getCache(), [&](const auto& _entry) { return _entry->match(_sig, _mapiNameId); });
}

_Check_return_ std::shared_ptr<namedPropCacheEntry>
Expand All @@ -349,7 +328,7 @@ namespace cache
}
}

return find([&](const auto& _entry) { return _entry->match(_sig, _ulPropID); });
return cache::find(getCache(), [&](const auto& _entry) { return _entry->match(_sig, _ulPropID); });
}

_Check_return_ std::shared_ptr<namedPropCacheEntry>
Expand All @@ -365,7 +344,7 @@ namespace cache
nameidString.c_str());
}

return find([&](const auto& _entry) { return _entry->match(_ulPropID, _mapiNameId); });
return cache::find(getCache(), [&](const auto& _entry) { return _entry->match(_ulPropID, _mapiNameId); });
}

// Add a mapping to the cache if it doesn't already exist
Expand All @@ -374,6 +353,9 @@ namespace cache
void
namedPropCache::add(const std::vector<std::shared_ptr<namedPropCacheEntry>>& entries, const std::vector<BYTE>& sig)
{
// Don't bother adding entries to the cache if they have no signature - we cannot trust entries without a signature.
if (sig.empty()) return;

auto& cache = getCache();
for (auto& entry : entries)
{
Expand All @@ -386,15 +368,8 @@ namespace cache
}

auto match = std::shared_ptr<namedPropCacheEntry>{};
if (sig.empty())
{
match = find(entry, false, true, true);
}
else
{
entry->setSig(sig);
match = find(entry, true, true, true);
}
entry->setSig(sig);
match = find(entry, true, true);

if (!namedPropCacheEntry::valid(match))
{
Expand Down Expand Up @@ -428,7 +403,8 @@ namespace cache
if (!lpMAPIProp) return {};

// If this is a get all names call, we have to go direct to MAPI since we cannot trust the cache is full.
if (!lpPropTags)
// Same if we don't have a signature at all as anything cached could be wrong
if (!lpPropTags || sig.empty())
{
output::DebugPrint(output::dbgLevel::NamedPropCache, L"GetNamesFromIDs: making direct all for all props\n");

Expand Down
12 changes: 2 additions & 10 deletions core/mapi/cache/namedPropCache.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,6 @@ namespace cache
ULONG ulFlags,
std::vector<std::shared_ptr<namedPropCacheEntry>> &names);

// Returns a vector of NamedPropCacheEntry for the input tags
// Sourced directly from MAPI
_Check_return_ std::vector<std::shared_ptr<namedPropCacheEntry>>
GetNamesFromIDs(_In_ LPMAPIPROP lpMAPIProp, _In_ const std::vector<ULONG> tags, ULONG ulFlags);

// Returns a vector of tags for the input names
// Sourced directly from MAPI
_Check_return_ LPSPropTagArray
Expand All @@ -35,14 +30,11 @@ namespace cache
class namedPropCache
{
private:
static std::list<std::shared_ptr<namedPropCacheEntry>>& getCache() noexcept;

_Check_return_ static std::shared_ptr<namedPropCacheEntry>
find(const std::function<bool(const std::shared_ptr<namedPropCacheEntry>&)>& compare);
static std::vector<std::shared_ptr<namedPropCacheEntry>>& getCache() noexcept;

public:
_Check_return_ static std::shared_ptr<namedPropCacheEntry>
find(const std::shared_ptr<cache::namedPropCacheEntry>& entry, bool bMatchSig, bool bMatchID, bool bMatchName);
find(const std::shared_ptr<cache::namedPropCacheEntry>& entry, bool bMatchID, bool bMatchName);
_Check_return_ static std::shared_ptr<namedPropCacheEntry>
find(_In_ const std::vector<BYTE>& _sig, _In_ const MAPINAMEID& _mapiNameId);
_Check_return_ static std::shared_ptr<namedPropCacheEntry>
Expand Down
40 changes: 28 additions & 12 deletions core/mapi/cache/namedProps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,14 +71,11 @@ namespace cache
}
}

_Check_return_ bool namedPropCacheEntry::match(
const std::shared_ptr<namedPropCacheEntry>& entry,
bool bMatchSig,
bool bMatchID,
bool bMatchName) const
_Check_return_ bool
namedPropCacheEntry::match(const std::shared_ptr<namedPropCacheEntry>& entry, bool bMatchID, bool bMatchName) const
{
if (!entry) return false;
if (bMatchSig && entry->sig != sig) return false;
if (entry->sig != sig) return false;
if (bMatchID && entry->ulPropID != ulPropID) return false;

if (bMatchName)
Expand Down Expand Up @@ -202,9 +199,10 @@ namespace cache
ULONG ulFlags)
{
if (!lpMAPIProp) return {};
const auto sigValid = sig && sig->lpb && sig->cb;

// Check if we're bypassing the cache:
if (!registry::cacheNamedProps ||
if (!sigValid || !registry::cacheNamedProps ||
// None of my code uses these flags, but bypass the cache if we see them
ulFlags)
{
Expand All @@ -214,7 +212,7 @@ namespace cache
}

auto sigv = std::vector<BYTE>{};
if (sig && sig->lpb && sig->cb) sigv = {sig->lpb, sig->lpb + sig->cb};
if (sigValid) sigv = {sig->lpb, sig->lpb + sig->cb};
return namedPropCache::GetNamesFromIDs(lpMAPIProp, sigv, lpPropTags);
}

Expand All @@ -239,7 +237,7 @@ namespace cache

if (lpProp && PT_BINARY == PROP_TYPE(lpProp->ulPropTag))
{
const auto &bin = mapi::getBin(lpProp);
const auto& bin = mapi::getBin(lpProp);
sig = {bin.lpb, bin.lpb + bin.cb};
}

Expand All @@ -253,15 +251,15 @@ namespace cache
NamePropNames NameIDToStrings(_In_ const MAPINAMEID* lpNameID, ULONG ulPropTag)
{
// Can't generate strings without a MAPINAMEID structure
if (!lpNameID) return {};
if (!lpNameID || !lpNameID->lpguid) return {};

auto lpNamedPropCacheEntry = std::shared_ptr<namedPropCacheEntry>{};

// If we're using the cache, look up the answer there and return
if (registry::cacheNamedProps)
{
lpNamedPropCacheEntry = namedPropCache::find(PROP_ID(ulPropTag), *lpNameID);
if (lpNamedPropCacheEntry && lpNamedPropCacheEntry->hasCachedStrings())
if (cache::namedPropCacheEntry::valid(lpNamedPropCacheEntry) && lpNamedPropCacheEntry->hasCachedStrings())
{
return lpNamedPropCacheEntry->getNamePropNames();
}
Expand Down Expand Up @@ -370,7 +368,7 @@ namespace cache
}

// We've built our strings - if we're caching, put them in the cache
if (lpNamedPropCacheEntry)
if (cache::namedPropCacheEntry::valid(lpNamedPropCacheEntry))
{
lpNamedPropCacheEntry->setNamePropNames(namePropNames);
}
Expand Down Expand Up @@ -487,4 +485,22 @@ namespace cache

return ulHighestKnown;
}

_Check_return_ std::shared_ptr<namedPropCacheEntry> find(
const std::vector<std::shared_ptr<namedPropCacheEntry>>& list,
const std::function<bool(const std::shared_ptr<namedPropCacheEntry>&)>& compare)
{
const auto entry = find_if(list.begin(), list.end(), [compare](const auto& _entry) { return compare(_entry); });

if (entry != list.end())
{
output::DebugPrint(output::dbgLevel::NamedPropCache, L"find: found match\n");
return *entry;
}
else
{
output::DebugPrint(output::dbgLevel::NamedPropCache, L"find: no match\n");
return namedPropCacheEntry::empty();
}
}
} // namespace cache
Loading

0 comments on commit b5b6af9

Please sign in to comment.