Skip to content

Commit

Permalink
allow setting runtime handles (alternative) (#1002)
Browse files Browse the repository at this point in the history
This PR is an alternative way of
#997 to implement
allowing users to set runtime handles.

The basic idea is to create a JSON string representing the `Config`
class from the `RuntimeSettings`, and the `Config` class will merge
them.

Currently this PR is not working because of the following:
- The existing JSON parser does not support multiple entries in
`provider_options`
  The following JSON will fail to parse.
  ```json
  {
    "model": {
        "decoder": {
            "session_options": {
                "provider_options": [
                    {
                        "webgpu": { }
                    },
                    {
                        "dml": {}
                    }
                ]
            },
        },
     },
  }

  ```

- When trying to specify JSON overlay, two "webgpu" items does not
merge.
  genai_config.json:
  ```json
  {
    "model": {
        "decoder": {
            "session_options": {
                "provider_options": [
                    {
                        "webgpu": { "abc": "123" }
                    }
                ]
            },
        },
     },
  }

  ```

  generated overlay config:
  ```
  {
    "model": {
      "decoder": {
        "session_options": {
          "provider_options": [
            {
              "webgpu": {
                "dawnProcTable": "12345678"
              }
            }
          ]
        }
      }
    }
  }
  ```

  Result:
  

![image](https://github.com/user-attachments/assets/adf649bd-8d7b-4b2e-a8f8-6141d04eb7ee)

This PR depends on a code change to the Config class to support the
expected parsing behaviors.

---------

Co-authored-by: Ryan Hill <[email protected]>
  • Loading branch information
fs-eire and RyanUnderhill authored Oct 24, 2024
1 parent 147a311 commit d58daf0
Show file tree
Hide file tree
Showing 9 changed files with 172 additions and 10 deletions.
17 changes: 14 additions & 3 deletions src/config.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "generators.h"
#include "runtime_settings.h"
#include "json.h"
#include <fstream>
#include <sstream>
Expand Down Expand Up @@ -640,7 +641,7 @@ struct RootObject_Element : JSON::Element {
JSON::Element& t_;
};

void ParseConfig(const fs::path& filename, Config& config) {
void ParseConfig(const fs::path& filename, std::string_view json_overlay, Config& config) {
std::ifstream file = filename.open(std::ios::binary | std::ios::ate);
if (!file.is_open()) {
throw std::runtime_error("Error opening " + filename.string());
Expand All @@ -662,10 +663,20 @@ void ParseConfig(const fs::path& filename, Config& config) {
oss << "Error encountered while parsing '" << filename.string() << "' " << message.what();
throw std::runtime_error(oss.str());
}

if (!json_overlay.empty()) {
try {
JSON::Parse(root_object, json_overlay);
} catch (const std::exception& message) {
std::ostringstream oss;
oss << "Error encountered while parsing config overlay: " << message.what();
throw std::runtime_error(oss.str());
}
}
}

Config::Config(const fs::path& path) : config_path{path} {
ParseConfig(path / "genai_config.json", *this);
Config::Config(const fs::path& path, std::string_view json_overlay) : config_path{path} {
ParseConfig(path / "genai_config.json", json_overlay, *this);

if (model.context_length == 0)
throw std::runtime_error("model context_length is 0 or was not set. It must be greater than 0");
Expand Down
4 changes: 3 additions & 1 deletion src/config.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,11 @@

namespace Generators {

struct RuntimeSettings;

struct Config {
Config() = default;
Config(const fs::path& path);
Config(const fs::path& path, std::string_view json_overlay);

struct Defaults {
static constexpr std::string_view InputIdsName = "input_ids";
Expand Down
3 changes: 2 additions & 1 deletion src/generators.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ using cudaStream_t = void*;
#include "models/debugging.h"
#include "config.h"
#include "logging.h"
#include "runtime_settings.h"
#include "tensor.h"

namespace Generators {
Expand Down Expand Up @@ -135,7 +136,7 @@ std::unique_ptr<OrtGlobals>& GetOrtGlobals();
void Shutdown(); // Do this once at exit, Ort code will fail after this call
OrtEnv& GetOrtEnv();

std::shared_ptr<Model> CreateModel(OrtEnv& ort_env, const char* config_path);
std::shared_ptr<Model> CreateModel(OrtEnv& ort_env, const char* config_path, const RuntimeSettings* settings = nullptr);
std::shared_ptr<GeneratorParams> CreateGeneratorParams(const Model& model);
std::shared_ptr<GeneratorParams> CreateGeneratorParams(const Config& config); // For benchmarking purposes only
std::unique_ptr<Generator> CreateGenerator(const Model& model, const GeneratorParams& params);
Expand Down
13 changes: 10 additions & 3 deletions src/models/model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -471,9 +471,12 @@ void Model::CreateSessionOptionsFromConfig(const Config::SessionOptions& config_
opts.emplace(option.first, option.second);
}
session_options.AppendExecutionProvider("QNN", opts);
} else if (provider_options.name == "web") {
} else if (provider_options.name == "webgpu") {
device_type_ = DeviceType::WEBGPU;
std::unordered_map<std::string, std::string> opts;
for (auto& option : provider_options.options) {
opts.emplace(option.first, option.second);
}
session_options.AppendExecutionProvider("WebGPU", opts);
} else
throw std::runtime_error("Unknown provider type: " + provider_options.name);
Expand Down Expand Up @@ -510,8 +513,12 @@ std::shared_ptr<MultiModalProcessor> Model::CreateMultiModalProcessor() const {
return std::make_shared<MultiModalProcessor>(*config_, *session_info_);
}

std::shared_ptr<Model> CreateModel(OrtEnv& ort_env, const char* config_path) {
auto config = std::make_unique<Config>(fs::path(config_path));
std::shared_ptr<Model> CreateModel(OrtEnv& ort_env, const char* config_path, const RuntimeSettings* settings /*= nullptr*/) {
std::string config_overlay;
if (settings) {
config_overlay = settings->GenerateConfigOverlay();
}
auto config = std::make_unique<Config>(fs::path(config_path), config_overlay);

if (config->model.type == "gpt2")
return std::make_shared<Gpt_Model>(std::move(config), ort_env);
Expand Down
22 changes: 22 additions & 0 deletions src/ort_genai.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,12 +59,34 @@ inline void OgaCheckResult(OgaResult* result) {
}
}

struct OgaRuntimeSettings : OgaAbstract {
static std::unique_ptr<OgaRuntimeSettings> Create() {
OgaRuntimeSettings* p;
OgaCheckResult(OgaCreateRuntimeSettings(&p));
return std::unique_ptr<OgaRuntimeSettings>(p);
}

void SetHandle(const char* name, void* handle) {
OgaCheckResult(OgaRuntimeSettingsSetHandle(this, name, handle));
}
void SetHandle(const std::string& name, void* handle) {
SetHandle(name.c_str(), handle);
}

static void operator delete(void* p) { OgaDestroyRuntimeSettings(reinterpret_cast<OgaRuntimeSettings*>(p)); }
};

struct OgaModel : OgaAbstract {
static std::unique_ptr<OgaModel> Create(const char* config_path) {
OgaModel* p;
OgaCheckResult(OgaCreateModel(config_path, &p));
return std::unique_ptr<OgaModel>(p);
}
static std::unique_ptr<OgaModel> Create(const char* config_path, const OgaRuntimeSettings& settings) {
OgaModel* p;
OgaCheckResult(OgaCreateModelWithRuntimeSettings(config_path, &settings, &p));
return std::unique_ptr<OgaModel>(p);
}

std::unique_ptr<OgaSequences> Generate(const OgaGeneratorParams& params) const {
OgaSequences* p;
Expand Down
28 changes: 26 additions & 2 deletions src/ort_genai_c.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include "ort_genai_c.h"
#include "generators.h"
#include "models/model.h"
#include "runtime_settings.h"
#include "search.h"

namespace Generators {
Expand Down Expand Up @@ -134,15 +135,26 @@ OgaResult* OGA_API_CALL OgaLoadAudios(const OgaStringArray* audio_paths, OgaAudi
OGA_CATCH
}

OgaResult* OGA_API_CALL OgaCreateModel(const char* config_path, OgaModel** out) {
OgaResult* OGA_API_CALL OgaCreateRuntimeSettings(OgaRuntimeSettings** out) {
OGA_TRY
*out = reinterpret_cast<OgaRuntimeSettings*>(Generators::CreateRuntimeSettings().release());
return nullptr;
OGA_CATCH
}

OgaResult* OGA_API_CALL OgaCreateModelWithRuntimeSettings(const char* config_path, const OgaRuntimeSettings* settings, OgaModel** out) {
OGA_TRY
auto model = Generators::CreateModel(Generators::GetOrtEnv(), config_path);
auto model = Generators::CreateModel(Generators::GetOrtEnv(), config_path, reinterpret_cast<const Generators::RuntimeSettings*>(settings));
model->external_owner_ = model;
*out = reinterpret_cast<OgaModel*>(model.get());
return nullptr;
OGA_CATCH
}

OgaResult* OGA_API_CALL OgaCreateModel(const char* config_path, OgaModel** out) {
return OgaCreateModelWithRuntimeSettings(config_path, nullptr, out);
}

OgaResult* OGA_API_CALL OgaCreateGeneratorParams(const OgaModel* model, OgaGeneratorParams** out) {
OGA_TRY
auto params = std::make_shared<Generators::GeneratorParams>(*reinterpret_cast<const Generators::Model*>(model));
Expand All @@ -152,6 +164,14 @@ OgaResult* OGA_API_CALL OgaCreateGeneratorParams(const OgaModel* model, OgaGener
OGA_CATCH
}

OgaResult* OGA_API_CALL OgaRuntimeSettingsSetHandle(OgaRuntimeSettings* settings, const char* handle_name, void* handle) {
OGA_TRY
Generators::RuntimeSettings* settings_ = reinterpret_cast<Generators::RuntimeSettings*>(settings);
settings_->handles_[handle_name] = handle;
return nullptr;
OGA_CATCH
}

OgaResult* OGA_API_CALL OgaGeneratorParamsSetSearchNumber(OgaGeneratorParams* generator_params, const char* name, double value) {
OGA_TRY
Generators::SetSearchNumber(reinterpret_cast<Generators::GeneratorParams*>(generator_params)->search, name, value);
Expand Down Expand Up @@ -623,4 +643,8 @@ void OGA_API_CALL OgaDestroyNamedTensors(OgaNamedTensors* p) {
void OGA_API_CALL OgaDestroyAdapters(OgaAdapters* p) {
reinterpret_cast<Generators::Adapters*>(p)->external_owner_ = nullptr;
}

void OGA_API_CALL OgaDestroyRuntimeSettings(OgaRuntimeSettings* p) {
delete reinterpret_cast<Generators::RuntimeSettings*>(p);
}
}
32 changes: 32 additions & 0 deletions src/ort_genai_c.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ typedef enum OgaElementType {
typedef struct OgaResult OgaResult;
typedef struct OgaGeneratorParams OgaGeneratorParams;
typedef struct OgaGenerator OgaGenerator;
typedef struct OgaRuntimeSettings OgaRuntimeSettings;
typedef struct OgaModel OgaModel;
// OgaSequences is an array of token arrays where the number of token arrays can be obtained using
// OgaSequencesCount and the number of tokens in each token array can be obtained using OgaSequencesGetSequenceCount.
Expand Down Expand Up @@ -149,6 +150,27 @@ OGA_EXPORT OgaResult* OGA_API_CALL OgaLoadAudios(const OgaStringArray* audio_pat

OGA_EXPORT void OGA_API_CALL OgaDestroyAudios(OgaAudios* audios);

/*
* \brief Creates a runtime settings instance to be used to create a model.
* \param[out] out The created runtime settings.
* \return OgaResult containing the error message if the creation of the runtime settings failed.
*/
OGA_EXPORT OgaResult* OGA_API_CALL OgaCreateRuntimeSettings(OgaRuntimeSettings** out);
/*
* \brief Destroys the given runtime settings.
* \param[in] settings The runtime settings to be destroyed.
*/
OGA_EXPORT void OGA_API_CALL OgaDestroyRuntimeSettings(OgaRuntimeSettings* settings);

/*
* \brief Sets a specific runtime handle for the runtime settings.
* \param[in] settings The runtime settings to set the device type.
* \param[in] handle_name The name of the handle to set for the runtime settings.
* \param[in] handle The value of handle to set for the runtime settings.
* \return OgaResult containing the error message if the setting of the device type failed.
*/
OGA_EXPORT OgaResult* OGA_API_CALL OgaRuntimeSettingsSetHandle(OgaRuntimeSettings* settings, const char* handle_name, void* handle);

/*
* \brief Creates a model from the given configuration directory and device type.
* \param[in] config_path The path to the model configuration directory. The path is expected to be encoded in UTF-8.
Expand All @@ -158,6 +180,16 @@ OGA_EXPORT void OGA_API_CALL OgaDestroyAudios(OgaAudios* audios);
*/
OGA_EXPORT OgaResult* OGA_API_CALL OgaCreateModel(const char* config_path, OgaModel** out);

/*
* \brief Creates a model from the given configuration directory, runtime settings and device type.
* \param[in] config_path The path to the model configuration directory. The path is expected to be encoded in UTF-8.
* \param[in] settings The runtime settings to use for the model.
* \param[in] device_type The device type to use for the model.
* \param[out] out The created model.
* \return OgaResult containing the error message if the model creation failed.
*/
OGA_EXPORT OgaResult* OGA_API_CALL OgaCreateModelWithRuntimeSettings(const char* config_path, const OgaRuntimeSettings* settings, OgaModel** out);

/*
* \brief Destroys the given model.
* \param[in] model The model to be destroyed.
Expand Down
43 changes: 43 additions & 0 deletions src/runtime_settings.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
#include "runtime_settings.h"

namespace Generators {

std::unique_ptr<RuntimeSettings> CreateRuntimeSettings() {
return std::make_unique<RuntimeSettings>();
}

std::string RuntimeSettings::GenerateConfigOverlay() const {
// #if USE_WEBGPU
constexpr std::string_view webgpu_overlay_pre = R"({
"model": {
"decoder": {
"session_options": {
"provider_options": [
{
"webgpu": {
"dawnProcTable": ")";
constexpr std::string_view webgpu_overlay_post = R"("
}
}
]
}
}
}
}
)";

auto it = handles_.find("dawnProcTable");
if (it != handles_.end()) {
void* dawn_proc_table_handle = it->second;
std::string overlay;
overlay.reserve(webgpu_overlay_pre.size() + webgpu_overlay_post.size() + 20); // Optional small optimization of buffer size
overlay += webgpu_overlay_pre;
overlay += std::to_string((size_t)(dawn_proc_table_handle));
overlay += webgpu_overlay_post;
return overlay;
}

return {};
}

} // namespace Generators
20 changes: 20 additions & 0 deletions src/runtime_settings.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
#pragma once

#include <string>
#include <memory>
#include <unordered_map>

namespace Generators {

// This struct should only be used for runtime settings that are not able to be put into config.
struct RuntimeSettings {
RuntimeSettings() = default;

std::string GenerateConfigOverlay() const;

std::unordered_map<std::string, void*> handles_;
};

std::unique_ptr<RuntimeSettings> CreateRuntimeSettings();

} // namespace Generators

0 comments on commit d58daf0

Please sign in to comment.