Skip to content

Commit

Permalink
Support DOS-like wildcards in -requires (#280)
Browse files Browse the repository at this point in the history
* Support DOS-like wildcards in -requires

Fixes #276. Also resolves #270 by including link to more details on versioning syntax.

* Update minor version for new feature
  • Loading branch information
heaths authored Nov 15, 2022
1 parent 3dc7847 commit 2c2ff30
Show file tree
Hide file tree
Showing 11 changed files with 252 additions and 29 deletions.
73 changes: 71 additions & 2 deletions src/vswhere.lib/CommandArgs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@ static wstring ParseArgument(IteratorType& it, const IteratorType& end, const Co
template <class IteratorType>
static void ParseArgumentArray(IteratorType& it, const IteratorType& end, const CommandParser::Token& arg, vector<wstring>& arr);

template <class IteratorType>
static void ParseRequiresArray(IteratorType& it, const IteratorType& end, const CommandParser::Token& arg, vector<wstring>& literals, vector<wregex>& patterns);

const vector<wstring> CommandArgs::s_Products
{
L"Microsoft.VisualStudio.Product.Enterprise",
Expand Down Expand Up @@ -71,7 +74,7 @@ void CommandArgs::Parse(_In_ vector<CommandParser::Token> args)
}
else if (ArgumentEquals(arg.Value, L"requires"))
{
ParseArgumentArray(it, args.end(), arg, m_requires);
ParseRequiresArray(it, args.end(), arg, m_requires, m_requiresPattern);
hasSelection = true;
}
else if (ArgumentEquals(arg.Value, L"requiresAny"))
Expand Down Expand Up @@ -218,7 +221,7 @@ void CommandArgs::Parse(_In_ vector<CommandParser::Token> args)
void CommandArgs::Usage(_In_ Console& console) const
{
auto pos = m_path.find_last_of(L"\\");
auto path = ++pos != wstring::npos ? m_path.substr(pos) : m_path;
auto& path = ++pos != wstring::npos ? m_path.substr(pos) : m_path;

console.WriteLine(ResourceManager::FormatString(IDS_USAGE, path.c_str()));

Expand All @@ -231,6 +234,37 @@ void CommandArgs::Usage(_In_ Console& console) const
}
}

std::wregex CommandArgs::ParseRegex(_In_ const std::wstring& pattern) noexcept
{
// Reserve ~125% of the incoming pattern to hold any changes.
wstring accumulator;
accumulator.reserve(pattern.size() * 1.25);

for (auto it = pattern.begin(); it != pattern.end(); ++it)
{
switch (*it)
{
case L'.':
accumulator += L"\\.";
break;

case L'*':
accumulator += L".*";
break;

case L'?':
accumulator += L".";
break;

default:
accumulator += *it;
break;
}
}

return std::move(wregex(accumulator, wregex::basic | wregex::icase | wregex::nosubs));
}

static bool ArgumentEquals(_In_ const wstring& name, _In_ LPCWSTR expect)
{
_ASSERT(expect && *expect);
Expand Down Expand Up @@ -281,3 +315,38 @@ static void ParseArgumentArray(IteratorType& it, const IteratorType& end, const
arr.push_back(it->Value);
}
}

template <class IteratorType>
static void ParseRequiresArray(IteratorType& it, const IteratorType& end, const CommandParser::Token& arg, vector<wstring>& literals, vector<wregex>& patterns)
{
wstring& param = it->Value;
auto nit = next(it);

// Require arguments if the parameter is specified.
if (nit == end || CommandParser::Token::eArgument != nit->Type)
{
auto message = ResourceManager::FormatString(IDS_E_ARGREQUIRED, param.c_str());
throw win32_error(ERROR_INVALID_PARAMETER, message);
}

while (nit != end)
{
if (CommandParser::Token::eParameter == nit->Type)
{
break;
}

++it;
++nit;

if (it->Value.find(L'*', 0) == wstring::npos && it->Value.find(L'?', 0) == wstring::npos)
{
literals.push_back(it->Value);
}
else
{
auto pattern = CommandArgs::ParseRegex(it->Value);
patterns.push_back(std::move(pattern));
}
}
}
9 changes: 9 additions & 0 deletions src/vswhere.lib/CommandArgs.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ class CommandArgs
m_productsAll(obj.m_productsAll),
m_products(obj.m_products),
m_requires(obj.m_requires),
m_requiresPattern(obj.m_requiresPattern),
m_version(obj.m_version),
m_latest(obj.m_latest),
m_legacy(obj.m_legacy),
Expand Down Expand Up @@ -72,6 +73,11 @@ class CommandArgs
return m_requires;
}

const std::vector<std::wregex>& get_RequiresPattern() const noexcept
{
return m_requiresPattern;
}

const bool get_RequiresAny() const noexcept
{
return m_requiresAny;
Expand Down Expand Up @@ -157,6 +163,8 @@ class CommandArgs
void Parse(_In_ int argc, _In_ LPCWSTR argv[]);
void Usage(_In_ Console& console) const;

static std::wregex ParseRegex(_In_ const std::wstring& pattern) noexcept;

private:
static const std::vector<std::wstring> s_Products;
static const std::wstring s_Format;
Expand All @@ -168,6 +176,7 @@ class CommandArgs
bool m_productsAll;
std::vector<std::wstring> m_products;
std::vector<std::wstring> m_requires;
std::vector<std::wregex> m_requiresPattern;
bool m_requiresAny;
std::wstring m_version;
bool m_latest;
Expand Down
3 changes: 2 additions & 1 deletion src/vswhere.lib/Formatter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -342,7 +342,7 @@ void Formatter::WritePackages(_In_ ISetupInstance* pInstance)
StartArray(L"packages");

SafeArray<ISetupPackageReference*> saPackages(psaPackages);
const auto packages = saPackages.Elements();
const auto& packages = saPackages.Elements();

for (const auto& package : packages)
{
Expand Down Expand Up @@ -431,6 +431,7 @@ bool Formatter::WriteProperties(_In_ ISetupPropertyStore* pProperties, _In_opt_

SafeArray<BSTR> saNames(psaNames);

// Copy the elements so we can sort them.
auto elems = saNames.Elements();
sort(elems.begin(), elems.end(), less);

Expand Down
55 changes: 32 additions & 23 deletions src/vswhere.lib/InstanceSelector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
using namespace std;
using std::placeholders::_1;

ci_equal InstanceSelector::s_comparer;

InstanceSelector::InstanceSelector(_In_ const CommandArgs& args, _In_ ILegacyProvider& provider, _In_opt_ ISetupHelper* pHelper) :
m_args(args),
m_provider(provider),
Expand All @@ -17,7 +19,7 @@ InstanceSelector::InstanceSelector(_In_ const CommandArgs& args, _In_ ILegacyPro
m_helper = pHelper;
if (m_helper)
{
auto version = args.get_Version();
auto& version = args.get_Version();
if (!version.empty())
{
auto hr = m_helper->ParseVersionRange(version.c_str(), &m_ullMinimumVersion, &m_ullMaximumVersion);
Expand Down Expand Up @@ -224,7 +226,7 @@ bool InstanceSelector::IsProductMatch(_In_ ISetupInstance2* pInstance) const
}

// Asterisk on command line will clear the array to find any products.
const auto products = m_args.get_Products();
const auto& products = m_args.get_Products();
if (products.empty())
{
return true;
Expand All @@ -250,21 +252,19 @@ bool InstanceSelector::IsWorkloadMatch(_In_ ISetupInstance2* pInstance) const
{
_ASSERT(pInstance);

const auto requires = m_args.get_Requires();
if (requires.empty())
// Create copies and erase elements as found.
auto literals = m_args.get_Requires();
auto literals_count = literals.size();

auto patterns = m_args.get_RequiresPattern();
auto patterns_count = patterns.size();

if (literals.empty() && patterns.empty())
{
// No workloads required matches every instance.
return true;
}

// Keep track of which requirements we matched.
typedef map<wstring, bool, ci_less> MapType;
MapType found;
for (const auto& require : requires)
{
found.emplace(make_pair(require, false));
}

LPSAFEARRAY psa = NULL;
auto hr = pInstance->GetPackages(&psa);
if (FAILED(hr))
Expand All @@ -277,25 +277,34 @@ bool InstanceSelector::IsWorkloadMatch(_In_ ISetupInstance2* pInstance) const
{
auto id = GetId(package);

auto it = found.find(id);
if (it != found.end())
for (auto it = literals.cbegin(); it != literals.cend(); ++it)
{
if (s_comparer(id, *it))
{
literals.erase(it);
goto next;
}
}

for (auto it = patterns.cbegin(); it != patterns.cend(); ++it)
{
it->second = true;
if (regex_match(id, *it))
{
patterns.erase(it);
goto next;
}
}

next: continue;
}

if (m_args.get_RequiresAny())
{
return any_of(found.begin(), found.end(), [](MapType::const_reference pair) -> bool
{
return pair.second;
});
return literals.size() < literals_count
|| patterns.size() < patterns_count;
}

return all_of(found.begin(), found.end(), [](MapType::const_reference pair) -> bool
{
return pair.second;
});
return literals.empty() && patterns.empty();
}

bool InstanceSelector::IsVersionMatch(_In_ ISetupInstance* pInstance) const
Expand Down
2 changes: 2 additions & 0 deletions src/vswhere.lib/InstanceSelector.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ class InstanceSelector
std::vector<ISetupInstancePtr> Select(_In_opt_ IEnumSetupInstances* pEnum) const;

private:
static ci_equal s_comparer;

static std::wstring GetId(_In_ ISetupPackageReference* pPackageReference);
bool IsMatch(_In_ ISetupInstance* pInstance) const;
bool IsProductMatch(_In_ ISetupInstance2* pInstance) const;
Expand Down
2 changes: 1 addition & 1 deletion src/vswhere.lib/Module.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ const wstring& Module::get_Path() noexcept

const wstring& Module::get_FileVersion() noexcept
{
auto path = get_Path();
auto& path = get_Path();
if (path.empty())
{
return m_fileVersion;
Expand Down
3 changes: 3 additions & 0 deletions src/vswhere.lib/vswhere.lib.rc
Original file line number Diff line number Diff line change
Expand Up @@ -83,9 +83,12 @@ BEGIN
\n See https://aka.ms/vs/workloads for a list of product IDs.\
\n -requires arg One or more workload or component IDs required when finding instances.\
\n All specified IDs must be installed unless -requiresAny is specified.\
\n You can specify wildcards including ""?"" to match any one character,\
\n or ""*"" to match zero or more of any characters.\
\n See https://aka.ms/vs/workloads for a list of workload and component IDs.\
\n -requiresAny Find instances with any one or more workload or components IDs passed to -requires.\
\n -version arg A version range for instances to find. Example: [15.0,16.0) will find versions 15.*.\
\n See https://aka.ms/vswhere/versions for more information about versions.\
\n -latest Return only the newest version and last installed.\
\n -sort Sorts the instances from newest version and last installed to oldest.\
\n When used with ""find"", first instances are sorted then files are sorted lexigraphically.\
Expand Down
2 changes: 1 addition & 1 deletion src/vswhere/Program.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ void WriteLogo(_In_ const CommandArgs& args, _In_ Console& console, _In_ Module&
{
if (args.get_Logo())
{
const auto version = module.get_FileVersion();
const auto& version = module.get_FileVersion();
const auto nID = version.empty() ? IDS_PROGRAMINFO : IDS_PROGRAMINFOEX;

console.WriteLine(ResourceManager::FormatString(nID, NBGV_INFORMATIONAL_VERSION, version.c_str()));
Expand Down
52 changes: 52 additions & 0 deletions test/vswhere.test/CommandArgsTests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -401,4 +401,56 @@ TEST_CLASS(CommandArgsTests)
Assert::IsFalse(args.get_Logo());
Assert::IsTrue(args.get_UTF8());
}

BEGIN_TEST_METHOD_ATTRIBUTE(Parse_Requires_Patterns)
TEST_WORKITEM(276)
END_TEST_METHOD_ATTRIBUTE()
TEST_METHOD(Parse_Requires_Patterns)
{
CommandArgs args;
args.Parse(L"vswhere.exe -requires foo ba* qux");

const auto& literals = args.get_Requires();
const auto& patterns = args.get_RequiresPattern();

Assert::AreEqual(1, count(literals.cbegin(), literals.cend(), wstring(L"foo")));
Assert::AreEqual(1, count(literals.cbegin(), literals.cend(), wstring(L"qux")));
Assert::AreEqual<size_t>(1, patterns.size());
}

BEGIN_TEST_METHOD_ATTRIBUTE(ParseRegex_Theory)
TEST_WORKITEM(276)
END_TEST_METHOD_ATTRIBUTE()
TEST_METHOD(ParseRegex_Theory)
{
const wstring id = L"Foo.Bar";
vector<tuple<wstring, bool>> data =
{
{ L"Foo.Bar", true },
{ L"Foo.*", true },
{ L"*.Bar", true },
{ L"F*R", true },
{ L"foo?bar", true },
{ L"f??", false },
{ L"f??.??r", true },
{ L"*", true },
{ L".*", false },
{ L"?", false },
{ L"Baz", false },
{ L"*baz", false },
{ L"foo.bar*", true },
};

for (const auto& item : data)
{
wstring pattern;
bool expected;

tie(pattern, expected) = item;
auto re = CommandArgs::ParseRegex(pattern);
bool actual = regex_match(id, re);

Assert::AreEqual(expected, actual, format(L"\"%ls\" =~ /%ls/", id.c_str(), pattern.c_str()).c_str());
}
}
};
Loading

0 comments on commit 2c2ff30

Please sign in to comment.