-
Notifications
You must be signed in to change notification settings - Fork 35
/
ControlNetTest.cpp
66 lines (54 loc) · 2.17 KB
/
ControlNetTest.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
#include "pch.h"
#include "CppUnitTest.h"
#include "Storage/FileIO.h"
using namespace Axodox::Graphics;
using namespace Axodox::Storage;
using namespace Axodox::MachineLearning;
using namespace Axodox::MachineLearning::Imaging::StableDiffusion;
using namespace Axodox::MachineLearning::Sessions;
using namespace Microsoft::VisualStudio::CppUnitTestFramework;
using namespace std;
namespace Axodox::MachineLearning::Test
{
  /// End-to-end smoke test of the ControlNet pipeline:
  /// text embedding -> ControlNet inference -> VAE decode -> PNG written to lib_folder().
  /// Prerequisites (all resolved relative to lib_folder(), three levels up):
  ///   models/stable_diffusion, models/controlnet and inputs/depth.png.
  /// The ControlNet session is created with the DirectML executor (OnnxExecutorType::Dml).
  TEST_CLASS(ControlNetTest)
  {
    TEST_METHOD(TestControlNet)
    {
      // All relative paths use forward slashes; std::filesystem accepts them as
      // separators on Windows as well, so they stay consistent and portable.
      StableDiffusionDirectorySessionParameters sessionParameters{ lib_folder() / "../../../models/stable_diffusion" };
      ControlNetOptions options{};

      //Create text embedding
      {
        TextEmbedder textEmbedder(sessionParameters);
        auto positiveEmbedding = textEmbedder.ProcessPrompt("a clean bedroom");
        auto negativeEmbedding = textEmbedder.ProcessPrompt("blurry, render");

        // The negative embedding is concatenated first, the positive one second.
        // NOTE(review): the weights presumably pair with them in that order
        // (-1 = negative prompt, +1 = positive prompt) — confirm against ControlNetInferer.
        options.TextEmbeddings.Tensor = negativeEmbedding.Concat(positiveEmbedding);
        options.TextEmbeddings.Weights = { -1.f, 1.f };
      }

      //Load conditioning input
      {
        auto imagePath = (lib_folder() / "../../../inputs/depth.png").lexically_normal();
        auto imageData = read_file(imagePath);
        // Fail with a clear message rather than handing an empty buffer to the decoder.
        Assert::IsFalse(imageData.empty(), L"The conditioning image is missing or empty.");

        auto imageTexture = TextureData::FromBuffer(imageData);
        options.ConditionInput = Tensor::FromTextureData(imageTexture, ColorNormalization::LinearZeroToOne);
      }

      //Run ControlNet
      Tensor image;
      {
        auto controlnetParameters = OnnxSessionParameters::Create(lib_folder() / "../../../models/controlnet", OnnxExecutorType::Dml);
        ControlNetInferer controlNet{ controlnetParameters, sessionParameters };
        image = controlNet.RunInference(options);
      }

      //Decode VAE
      {
        VaeDecoder vaeDecoder{ sessionParameters };
        image = vaeDecoder.DecodeVae(image);
      }

      //Save result
      {
        auto imageTexture = image.ToTextureData(ColorNormalization::LinearPlusMinusOne);
        // Guard the [0] access below: indexing an empty result would be undefined behavior
        // instead of a reported test failure.
        Assert::IsFalse(imageTexture.empty(), L"VAE decoding produced no images.");

        auto imageBuffer = imageTexture[0].ToBuffer();
        write_file(lib_folder() / "controlnet.png", imageBuffer);
      }
    }
  };
}