Skip to content

Commit

Permalink
Refactor WinMLAPI Tests to build both google and taef test based on p…
Browse files Browse the repository at this point in the history
…reprocessor definition (#2829)

* Add winml macro wrappers on top of google test macros

* change test methods to disabled

* Add custom winml macros for both taef and google tests

* PR comments

* Refactor winml api tests

* Move additional gtest specific macro definition into googleTestMacros.h
  • Loading branch information
ryanlai2 authored Jan 15, 2020
1 parent dbe7d97 commit dcdebb4
Show file tree
Hide file tree
Showing 11 changed files with 723 additions and 512 deletions.
1 change: 1 addition & 0 deletions cmake/winml_unittests.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ add_winml_test(
SOURCES ${winml_test_api_src}
LIBS winml_test_common
)
target_compile_definitions(winml_test_api PRIVATE BUILD_GOOGLE_TEST)
target_precompiled_header(winml_test_api testPch.h)

if (onnxruntime_USE_DML)
Expand Down
52 changes: 20 additions & 32 deletions winml/test/api/APITest.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,37 +5,25 @@
//-----------------------------------------------------------------------------

#pragma once
#include "fileHelpers.h"
namespace APITest {
static void LoadModel(const std::wstring& modelPath,
winrt::Windows::AI::MachineLearning::LearningModel& learningModel) {
std::wstring fullPath = FileHelpers::GetModulePath() + modelPath;
learningModel = winrt::Windows::AI::MachineLearning::LearningModel::LoadFromFilePath(fullPath);
};

#include <gtest/gtest.h>

class APITest : public ::testing::Test
{
protected:
void LoadModel(const std::wstring& modelPath)
{
std::wstring fullPath = FileHelpers::GetModulePath() + modelPath;
m_model = winrt::Windows::AI::MachineLearning::LearningModel::LoadFromFilePath(fullPath);
}

winrt::Windows::AI::MachineLearning::LearningModel m_model = nullptr;
winrt::Windows::AI::MachineLearning::LearningModelDevice m_device = nullptr;
winrt::Windows::AI::MachineLearning::LearningModelSession m_session = nullptr;

uint64_t GetAdapterIdQuadPart()
{
LARGE_INTEGER id;
id.LowPart = m_device.AdapterId().LowPart;
id.HighPart = m_device.AdapterId().HighPart;
return id.QuadPart;
};

_LUID GetAdapterIdAsLUID()
{
_LUID id;
id.LowPart = m_device.AdapterId().LowPart;
id.HighPart = m_device.AdapterId().HighPart;
return id;
}

bool m_runGPUTests = true;
static uint64_t GetAdapterIdQuadPart(winrt::Windows::AI::MachineLearning::LearningModelDevice& device) {
LARGE_INTEGER id;
id.LowPart = device.AdapterId().LowPart;
id.HighPart = device.AdapterId().HighPart;
return id.QuadPart;
};

static _LUID GetAdapterIdAsLUID(winrt::Windows::AI::MachineLearning::LearningModelDevice& device) {
_LUID id;
id.LowPart = device.AdapterId().LowPart;
id.HighPart = device.AdapterId().HighPart;
return id;
}
}; // namespace APITest
Loading

0 comments on commit dcdebb4

Please sign in to comment.