Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

allow setting runtime handles (alternative) #1002

Merged
merged 8 commits into from
Oct 24, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 17 additions & 1 deletion src/config.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "generators.h"
#include "runtime_settings.h"
#include "json.h"
#include <fstream>
#include <sstream>
Expand Down Expand Up @@ -664,8 +665,23 @@ void ParseConfig(const fs::path& filename, Config& config) {
}
}

Config::Config(const fs::path& path) : config_path{path} {
void ParseConfig(std::string_view json, Config& config) {
fs-eire marked this conversation as resolved.
Show resolved Hide resolved
Root_Element root{config};
RootObject_Element root_object{root};
try {
JSON::Parse(root_object, json);
} catch (const std::exception& message) {
std::ostringstream oss;
oss << "Error encountered while parsing JSON " << message.what();
throw std::runtime_error(oss.str());
}
}

Config::Config(const fs::path& path, std::string_view json_overlay) : config_path{path} {
ParseConfig(path / "genai_config.json", *this);
if (!json_overlay.empty()) {
ParseConfig(json_overlay, *this);
}

if (model.context_length == 0)
throw std::runtime_error("model context_length is 0 or was not set. It must be greater than 0");
Expand Down
4 changes: 3 additions & 1 deletion src/config.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,11 @@

namespace Generators {

struct RuntimeSettings;

struct Config {
Config() = default;
Config(const fs::path& path);
Config(const fs::path& path, std::string_view json_overlay);

struct Defaults {
static constexpr std::string_view InputIdsName = "input_ids";
Expand Down
3 changes: 2 additions & 1 deletion src/generators.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ using cudaStream_t = void*;
#include "models/debugging.h"
#include "config.h"
#include "logging.h"
#include "runtime_settings.h"
#include "tensor.h"

namespace Generators {
Expand Down Expand Up @@ -135,7 +136,7 @@ std::unique_ptr<OrtGlobals>& GetOrtGlobals();
void Shutdown(); // Do this once at exit, Ort code will fail after this call
OrtEnv& GetOrtEnv();

std::shared_ptr<Model> CreateModel(OrtEnv& ort_env, const char* config_path);
std::shared_ptr<Model> CreateModel(OrtEnv& ort_env, const char* config_path, const RuntimeSettings* settings = nullptr);
std::shared_ptr<GeneratorParams> CreateGeneratorParams(const Model& model);
std::shared_ptr<GeneratorParams> CreateGeneratorParams(const Config& config); // For benchmarking purposes only
std::unique_ptr<Generator> CreateGenerator(const Model& model, const GeneratorParams& params);
Expand Down
11 changes: 9 additions & 2 deletions src/models/model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -474,6 +474,9 @@ void Model::CreateSessionOptionsFromConfig(const Config::SessionOptions& config_
} else if (provider_options.name == "web") {
device_type_ = DeviceType::WEBGPU;
std::unordered_map<std::string, std::string> opts;
for (auto& option : provider_options.options) {
opts.emplace(option.first, option.second);
}
session_options.AppendExecutionProvider("WebGPU", opts);
} else
throw std::runtime_error("Unknown provider type: " + provider_options.name);
Expand Down Expand Up @@ -510,8 +513,12 @@ std::shared_ptr<MultiModalProcessor> Model::CreateMultiModalProcessor() const {
return std::make_shared<MultiModalProcessor>(*config_, *session_info_);
}

std::shared_ptr<Model> CreateModel(OrtEnv& ort_env, const char* config_path) {
auto config = std::make_unique<Config>(fs::path(config_path));
std::shared_ptr<Model> CreateModel(OrtEnv& ort_env, const char* config_path, const RuntimeSettings* settings /*= nullptr*/) {
std::string config_overlay;
if (settings) {
config_overlay = settings->GenerateConfigOverlay();
}
auto config = std::make_unique<Config>(fs::path(config_path), config_overlay);

if (config->model.type == "gpt2")
return std::make_shared<Gpt_Model>(std::move(config), ort_env);
Expand Down
22 changes: 22 additions & 0 deletions src/ort_genai.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,12 +59,34 @@ inline void OgaCheckResult(OgaResult* result) {
}
}

struct OgaRuntimeSettings : OgaAbstract {
static std::unique_ptr<OgaRuntimeSettings> Create() {
OgaRuntimeSettings* p;
OgaCheckResult(OgaCreateRuntimeSettings(&p));
return std::unique_ptr<OgaRuntimeSettings>(p);
}

void SetHandle(const char* name, void* handle) {
OgaCheckResult(OgaRuntimeSettingsSetHandle(this, name, handle));
}
void SetHandle(const std::string& name, void* handle) {
SetHandle(name.c_str(), handle);
}

static void operator delete(void* p) { OgaDestroyRuntimeSettings(reinterpret_cast<OgaRuntimeSettings*>(p)); }
};

struct OgaModel : OgaAbstract {
static std::unique_ptr<OgaModel> Create(const char* config_path) {
OgaModel* p;
OgaCheckResult(OgaCreateModel(config_path, &p));
return std::unique_ptr<OgaModel>(p);
}
static std::unique_ptr<OgaModel> Create(const char* config_path, const OgaRuntimeSettings& settings) {
OgaModel* p;
OgaCheckResult(OgaCreateModelWithRuntimeSettings(config_path, &settings, &p));
return std::unique_ptr<OgaModel>(p);
}

std::unique_ptr<OgaSequences> Generate(const OgaGeneratorParams& params) const {
OgaSequences* p;
Expand Down
28 changes: 26 additions & 2 deletions src/ort_genai_c.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include "ort_genai_c.h"
#include "generators.h"
#include "models/model.h"
#include "runtime_settings.h"
#include "search.h"

namespace Generators {
Expand Down Expand Up @@ -134,15 +135,26 @@ OgaResult* OGA_API_CALL OgaLoadAudios(const OgaStringArray* audio_paths, OgaAudi
OGA_CATCH
}

OgaResult* OGA_API_CALL OgaCreateModel(const char* config_path, OgaModel** out) {
OgaResult* OGA_API_CALL OgaCreateRuntimeSettings(OgaRuntimeSettings** out) {
OGA_TRY
*out = reinterpret_cast<OgaRuntimeSettings*>(Generators::CreateRuntimeSettings().release());
return nullptr;
OGA_CATCH
}

OgaResult* OGA_API_CALL OgaCreateModelWithRuntimeSettings(const char* config_path, const OgaRuntimeSettings* settings, OgaModel** out) {
OGA_TRY
auto model = Generators::CreateModel(Generators::GetOrtEnv(), config_path);
auto model = Generators::CreateModel(Generators::GetOrtEnv(), config_path, reinterpret_cast<const Generators::RuntimeSettings*>(settings));
model->external_owner_ = model;
*out = reinterpret_cast<OgaModel*>(model.get());
return nullptr;
OGA_CATCH
}

OgaResult* OGA_API_CALL OgaCreateModel(const char* config_path, OgaModel** out) {
return OgaCreateModelWithRuntimeSettings(config_path, nullptr, out);
}

OgaResult* OGA_API_CALL OgaCreateGeneratorParams(const OgaModel* model, OgaGeneratorParams** out) {
OGA_TRY
auto params = std::make_shared<Generators::GeneratorParams>(*reinterpret_cast<const Generators::Model*>(model));
Expand All @@ -152,6 +164,14 @@ OgaResult* OGA_API_CALL OgaCreateGeneratorParams(const OgaModel* model, OgaGener
OGA_CATCH
}

OgaResult* OGA_API_CALL OgaRuntimeSettingsSetHandle(OgaRuntimeSettings* settings, const char* handle_name, void* handle) {
OGA_TRY
Generators::RuntimeSettings* settings_ = reinterpret_cast<Generators::RuntimeSettings*>(settings);
settings_->handles_[handle_name] = handle;
return nullptr;
OGA_CATCH
}

OgaResult* OGA_API_CALL OgaGeneratorParamsSetSearchNumber(OgaGeneratorParams* generator_params, const char* name, double value) {
OGA_TRY
Generators::SetSearchNumber(reinterpret_cast<Generators::GeneratorParams*>(generator_params)->search, name, value);
Expand Down Expand Up @@ -623,4 +643,8 @@ void OGA_API_CALL OgaDestroyNamedTensors(OgaNamedTensors* p) {
void OGA_API_CALL OgaDestroyAdapters(OgaAdapters* p) {
reinterpret_cast<Generators::Adapters*>(p)->external_owner_ = nullptr;
}

void OGA_API_CALL OgaDestroyRuntimeSettings(OgaRuntimeSettings* p) {
delete reinterpret_cast<Generators::RuntimeSettings*>(p);
}
}
32 changes: 32 additions & 0 deletions src/ort_genai_c.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ typedef enum OgaElementType {
typedef struct OgaResult OgaResult;
typedef struct OgaGeneratorParams OgaGeneratorParams;
typedef struct OgaGenerator OgaGenerator;
typedef struct OgaRuntimeSettings OgaRuntimeSettings;
typedef struct OgaModel OgaModel;
// OgaSequences is an array of token arrays where the number of token arrays can be obtained using
// OgaSequencesCount and the number of tokens in each token array can be obtained using OgaSequencesGetSequenceCount.
Expand Down Expand Up @@ -149,6 +150,27 @@ OGA_EXPORT OgaResult* OGA_API_CALL OgaLoadAudios(const OgaStringArray* audio_pat

OGA_EXPORT void OGA_API_CALL OgaDestroyAudios(OgaAudios* audios);

/*
* \brief Creates a runtime settings instance to be used to create a model.
* \param[out] out The created runtime settings.
* \return OgaResult containing the error message if the creation of the runtime settings failed.
*/
OGA_EXPORT OgaResult* OGA_API_CALL OgaCreateRuntimeSettings(OgaRuntimeSettings** out);
/*
* \brief Destroys the given runtime settings.
* \param[in] settings The runtime settings to be destroyed.
*/
OGA_EXPORT void OGA_API_CALL OgaDestroyRuntimeSettings(OgaRuntimeSettings* settings);

/*
* \brief Sets a specific runtime handle for the runtime settings.
* \param[in] settings The runtime settings to set the device type.
* \param[in] handle_name The name of the handle to set for the runtime settings.
* \param[in] handle The value of handle to set for the runtime settings.
* \return OgaResult containing the error message if the setting of the device type failed.
*/
OGA_EXPORT OgaResult* OGA_API_CALL OgaRuntimeSettingsSetHandle(OgaRuntimeSettings* settings, const char* handle_name, void* handle);

/*
* \brief Creates a model from the given configuration directory and device type.
* \param[in] config_path The path to the model configuration directory. The path is expected to be encoded in UTF-8.
Expand All @@ -158,6 +180,16 @@ OGA_EXPORT void OGA_API_CALL OgaDestroyAudios(OgaAudios* audios);
*/
OGA_EXPORT OgaResult* OGA_API_CALL OgaCreateModel(const char* config_path, OgaModel** out);

/*
* \brief Creates a model from the given configuration directory, runtime settings and device type.
* \param[in] config_path The path to the model configuration directory. The path is expected to be encoded in UTF-8.
* \param[in] settings The runtime settings to use for the model.
* \param[in] device_type The device type to use for the model.
* \param[out] out The created model.
* \return OgaResult containing the error message if the model creation failed.
*/
OGA_EXPORT OgaResult* OGA_API_CALL OgaCreateModelWithRuntimeSettings(const char* config_path, const OgaRuntimeSettings* settings, OgaModel** out);

/*
* \brief Destroys the given model.
* \param[in] model The model to be destroyed.
Expand Down
43 changes: 43 additions & 0 deletions src/runtime_settings.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
#include "runtime_settings.h"

namespace Generators {

std::unique_ptr<RuntimeSettings> CreateRuntimeSettings() {
return std::make_unique<RuntimeSettings>();
}

std::string RuntimeSettings::GenerateConfigOverlay() const {
// #if USE_WEBGPU
constexpr std::string_view webgpu_overlay_pre = R"({
"model": {
"decoder": {
"session_options": {
"provider_options": [
{
"webgpu": {
"dawnProcTable": ")";
constexpr std::string_view webgpu_overlay_post = R"("
}
}
]
}
}
}
}
)";

auto it = handles_.find("dawnProcTable");
if (it != handles_.end()) {
void* dawn_proc_table_handle = it->second;
std::string overlay;
overlay.reserve(webgpu_overlay_pre.size() + webgpu_overlay_post.size() + 20);
fs-eire marked this conversation as resolved.
Show resolved Hide resolved
overlay += webgpu_overlay_pre;
overlay += std::to_string((size_t)(dawn_proc_table_handle));
overlay += webgpu_overlay_post;
return overlay;
}

return {};
}

} // namespace Generators
20 changes: 20 additions & 0 deletions src/runtime_settings.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
#pragma once

#include <string>
#include <memory>
#include <unordered_map>

namespace Generators {

// This struct should only be used for runtime settings that are not able to be put into config.
struct RuntimeSettings {
RuntimeSettings() = default;

std::string GenerateConfigOverlay() const;

std::unordered_map<std::string, void*> handles_;
};

std::unique_ptr<RuntimeSettings> CreateRuntimeSettings();

} // namespace Generators
Loading