diff --git a/examples/llm-writer-app/data/main.ui b/examples/llm-writer-app/data/main.ui index e04e1d7..f6ae016 100644 --- a/examples/llm-writer-app/data/main.ui +++ b/examples/llm-writer-app/data/main.ui @@ -6,22 +6,43 @@ 480 640 - - True - vertical + + true - + true - true - true - fill + vertical - + true true true fill - word + + + true + true + true + fill + word + + + + + + + + + true + true + end + end + + + false + 0.0 + Starting Download + true diff --git a/examples/llm-writer-app/src/main.js b/examples/llm-writer-app/src/main.js index 6049df1..4471866 100644 --- a/examples/llm-writer-app/src/main.js +++ b/examples/llm-writer-app/src/main.js @@ -39,16 +39,137 @@ const STATE_TEXT_EDITOR = 0; const STATE_PREDICTING = 1; const STATE_WAITING = 2; +const list_store_from_rows = (rows) => { + const list_store = Gtk.ListStore.new(rows[0].map(() => GObject.TYPE_STRING)); + + rows.forEach(columns => { + const iter = list_store.append(); + columns.forEach((c, i) => { + list_store.set_value(iter, i, c) + }); + }); + + return list_store; +}; + +const load_model = (model, cancellable, callback, progress_callback) => { + const istream = GGML.LanguageModel.stream_from_cache(model); + + if (progress_callback) { + istream.set_download_progress_callback(progress_callback); + } + + GGML.LanguageModel.load_defined_from_istream_async( + model, + istream, + cancellable, + (src, res) => { + try { + callback(GGML.LanguageModel.load_defined_from_istream_finish(res)); + } catch (e) { + if (e.code === Gio.IOErrorEnum.CANCELLED) { + return; + } + logError(e); + } + } + ); +}; + +const COMBOBOX_ID_TO_LANGUAGE_MODEL_ENUM = Object.keys(GGML.DefinedLanguageModel).map(k => GGML.DefinedLanguageModel[k]); + +class ModelLoader { + constructor() { + this._model_enum = null; + this._model = null; + this._pending_load = null; + } + + /** + * with_model: + * @model_enum: A #GGMLModelDescription + * @cancellable: A #GCancellable + * @callback: A callback to invoke once the model is done loading + 
 * + * Does some action with a model. Also accepts a @cancellable - + * if the action is cancelled, then @callback won't be invoked, but + * the model will still be downloaded if the download is in progress. + */ + with_model(model_enum, cancellable, callback, progress_callback) { + if (this._model_enum === model_enum) { + return callback(this._model) + } + + if (this._pending_load) { + /* We only do the most recent callback once the model is loaded + * and discard other ones */ + if (this._pending_load.model_enum !== model_enum) { + /* Cancel the existing pending load and start over again */ + this._pending_load.load_cancellable.cancel(); + } else { + /* Don't cancel the pending load operation, but change the callback */ + this._pending_load = { + model_enum: model_enum, + callback: callback, + load_cancellable: this._pending_load.load_cancellable, + action_cancellable: cancellable + }; + return; + } + } + + /* Create a pending load and load the model */ + this._pending_load = { + model_enum: model_enum, + callback: callback, + load_cancellable: new Gio.Cancellable(), + action_cancellable: cancellable + }; + + load_model(model_enum, this._pending_load.load_cancellable, model => { + const { callback, action_cancellable } = this._pending_load; + + if (action_cancellable === null || !action_cancellable.is_cancelled()) { + this._model_enum = model_enum; + this._model = model; + + System.gc(); + return callback(this._model); + } + }, progress_callback); + } +} + const LLMWriterAppMainWindow = GObject.registerClass({ Template: `${RESOURCE_PATH}/main.ui`, Children: [ 'content-view', - 'text-view' + 'text-view', + 'progress-bar' ] }, class LLMWriterAppMainWindow extends Gtk.ApplicationWindow { _init(params) { super._init(params); + this._model_loader = new ModelLoader(); + + const resetProgress = () => { + this.progress_bar.set_visible(false); + this.progress_bar.set_text("Starting Download"); + }; + const progressCallback = (received_bytes, total_bytes) => { + if 
(received_bytes === -1) { + resetProgress(); + return; + } + + const fraction = received_bytes / total_bytes; + + this.progress_bar.set_visible(true); + this.progress_bar.set_fraction(fraction); + this.progress_bar.set_text(`Downloading ${Math.trunc(fraction * 100)}%`); + }; + const header = new Gtk.HeaderBar({ visible: true, title: GLib.get_application_name(), @@ -57,10 +178,33 @@ const LLMWriterAppMainWindow = GObject.registerClass({ this._spinner = new Gtk.Spinner({ visible: true }); + const combobox = Gtk.ComboBox.new_with_model( + list_store_from_rows([ + ['GPT2 117M'], + ['GPT2 345M'], + ['GPT2 774M'], + ['GPT2 1558M'], + ]) + ); + const renderer = new Gtk.CellRendererText(); + combobox.pack_start(renderer, true); + combobox.add_attribute(renderer, 'text', 0); + combobox.set_active(0); + combobox.connect('changed', () => { + resetProgress(); + this._model_loader.with_model( + COMBOBOX_ID_TO_LANGUAGE_MODEL_ENUM[combobox.active], + null, + () => this._spinner.stop(), + progressCallback + ); + }); + combobox.show(); + + header.pack_start(combobox); header.pack_end(this._spinner); this.set_titlebar(header); - this._languageModel = null; this._textBufferState = STATE_TEXT_EDITOR; this._predictionsStartedAt = -1; this._cancellable = null; @@ -101,7 +245,6 @@ const LLMWriterAppMainWindow = GObject.registerClass({ if (currentPosition > 0 && currentPosition === this._lastCursorOffset && count > 0 && - this._languageModel !== null && this._textBufferState === STATE_TEXT_EDITOR) { const text = buffer.get_text( buffer.get_start_iter(), @@ -109,48 +252,58 @@ const LLMWriterAppMainWindow = GObject.registerClass({ false ); - this.text_view.set_editable(false); + /* Reset state immediately if the operation is cancelled */ this._cancellable = new Gio.Cancellable({}); + this._cancellable.connect(() => resetState()); + this._textBufferState = STATE_PREDICTING; this._candidateText = ''; this._spinner.start(); buffer.create_mark("predictions-start", buffer.get_end_iter(), true); 
- this._languageModel.complete_async( - text, - 10, - 2, + + this._model_loader.with_model( + COMBOBOX_ID_TO_LANGUAGE_MODEL_ENUM[combobox.active], this._cancellable, - (src, res) => { - let part, is_complete, is_complete_eos; - try { - [part, is_complete, is_complete_eos] = this._languageModel.complete_finish(res); - } catch (e) { - if (e.code == Gio.IOErrorEnum.CANCELLED) { - resetState(); + model => { + model.complete_async( + text, + 10, + 2, + this._cancellable, + (src, res) => { + let part, is_complete, is_complete_eos; + try { + [part, is_complete, is_complete_eos] = model.complete_finish(res); + } catch (e) { + if (e.code == Gio.IOErrorEnum.CANCELLED) { + return; + } + logError(e); + return; + } + + if (part === text) { + return; + } + + if (is_complete) { + this._cancellable = null; + this._textBufferState = STATE_WAITING; + this._spinner.stop(); + } + + this._candidateText += part; + const markup = `${GLib.markup_escape_text(part, part.length)}` + buffer.insert_markup(buffer.get_end_iter(), markup, markup.length); + System.gc(); } - return; - } - - if (part === text) { - return; - } - - if (is_complete) { - this._cancellable = null; - this._textBufferState = STATE_WAITING; - this._spinner.stop(); - } - - this._candidateText += part; - const markup = `${GLib.markup_escape_text(part, part.length)}` - buffer.insert_markup(buffer.get_end_iter(), markup, markup.length); - System.gc(); - } + ); + }, + progressCallback ); } else if (currentPosition > 0 && currentPosition === this._lastCursorOffset && count > 0 && - this._languageModel !== null && this._textBufferState === STATE_WAITING) { // Delete the gray text and substitute the real text. 
removePredictedText(); @@ -184,18 +337,6 @@ const LLMWriterAppMainWindow = GObject.registerClass({ } vfunc_show() { - this._spinner.start(); - const istream = GGML.LanguageModel.stream_from_cache(GGML.DefinedLanguageModel.GPT2); - GGML.LanguageModel.load_defined_from_istream_async( - GGML.DefinedLanguageModel.GPT2, - istream, - null, - (src, res) => { - this._languageModel = GGML.LanguageModel.load_defined_from_istream_finish(res); - this._spinner.stop(); - } - ); - super.vfunc_show(); } }); diff --git a/ggml-gobject/ggml-cached-model.c b/ggml-gobject/ggml-cached-model.c index 325cbfd..c21de7f 100644 --- a/ggml-gobject/ggml-cached-model.c +++ b/ggml-gobject/ggml-cached-model.c @@ -24,6 +24,8 @@ #include #include #include +#include +#include struct _GGMLCachedModelIstream { @@ -36,6 +38,11 @@ typedef struct size_t remote_content_length; char *remote_url; char *local_path; + uint32_t progress_indicator_source_id; + GAsyncQueue *progress_indicator_queue; + GFileProgressCallback progress_callback; + gpointer progress_callback_data; + GDestroyNotify progress_callback_data_destroy; } GGMLCachedModelIstreamPrivate; enum { @@ -50,10 +57,91 @@ G_DEFINE_TYPE_WITH_CODE (GGMLCachedModelIstream, G_TYPE_FILE_INPUT_STREAM, G_ADD_PRIVATE (GGMLCachedModelIstream)) +static gboolean +ggml_download_progress_async_queue_monitor_callback (gpointer message, + gpointer user_data) +{ + GGMLCachedModelIstream *cached_model = user_data; + GGMLCachedModelIstreamPrivate *priv = ggml_cached_model_istream_get_instance_private (cached_model); + + /* We progress the message by sending it to the progress callback */ + goffset progressed_bytes = GPOINTER_TO_INT (message); + + /* We unconditionally send to the progress callback, even if it is + * the sentinel value - the progress callback has to be able to handle this */ + (*priv->progress_callback) (progressed_bytes, + (goffset) priv->remote_content_length, + priv->progress_callback_data); + + + /* Sentinel message */ + if (progressed_bytes == -1) + 
{ + return G_SOURCE_REMOVE; + } + + return G_SOURCE_CONTINUE; +} + +static void +ggml_cached_model_push_download_progress_to_queue (goffset progressed_bytes, + goffset total_bytes, + gpointer user_data) +{ + GGMLCachedModelIstream *cached_model = user_data; + GGMLCachedModelIstreamPrivate *priv = ggml_cached_model_istream_get_instance_private (cached_model); + + /* Only push a message if one hasn't been consumed yet - otherwise + * we run the risk of flooding the main progress */ + if (g_async_queue_length (priv->progress_indicator_queue) == 0) + { + g_async_queue_push (priv->progress_indicator_queue, GINT_TO_POINTER (progressed_bytes)); + g_main_context_wakeup (g_main_context_default ()); + } +} + +/** + * ggml_cached_model_istream_set_download_progress_callback: + * @callback: A #GFileProgressCallback with progress about the download operation. + * @user_data: (closure callback): A closure for @callback + * @user_data_destroy: (destroy callback): A #GDestroyNotify for @callback + * + * Set a progress-monitor callback for @cached_model, which will be called with + * download progress in case a model is being downloaded. The application can use + * the callback to update state, for example a progress bar. + * + * This function should be called from the main thread. It will handle situations where + * the download IO operation happens on a separate thread. 
+ */ +void +ggml_cached_model_istream_set_download_progress_callback (GGMLCachedModelIstream *cached_model, + GFileProgressCallback callback, + gpointer user_data, + GDestroyNotify user_data_destroy) +{ + GGMLCachedModelIstreamPrivate *priv = ggml_cached_model_istream_get_instance_private (cached_model); + + g_clear_pointer (&priv->progress_callback_data, priv->progress_callback_data_destroy); + + if (priv->progress_indicator_source_id != 0) + { + g_source_remove (priv->progress_indicator_source_id); + priv->progress_indicator_source_id = 0; + } + + g_clear_pointer (&priv->progress_indicator_queue, g_async_queue_unref); + + priv->progress_callback = callback; + priv->progress_callback_data = user_data; + priv->progress_callback_data_destroy = user_data_destroy; +} + static GFileInputStream * -ggml_cached_model_istream_ensure_stream (GGMLCachedModelIstream *cached_model, - GCancellable *cancellable, - GError **error) +ggml_cached_model_istream_ensure_stream (GGMLCachedModelIstream *cached_model, + GFileProgressCallback progress_callback, + gpointer progress_callback_data, + GCancellable *cancellable, + GError **error) { GGMLCachedModelIstreamPrivate *priv = ggml_cached_model_istream_get_instance_private (cached_model); @@ -104,16 +192,53 @@ ggml_cached_model_istream_ensure_stream (GGMLCachedModelIstream *cached_model, SoupMessageHeaders *response_headers = soup_message_get_response_headers (message); priv->remote_content_length = soup_message_headers_get_content_length (response_headers); - if (!g_output_stream_splice (G_OUTPUT_STREAM (output_stream), - in_stream, - G_OUTPUT_STREAM_SPLICE_CLOSE_SOURCE | - G_OUTPUT_STREAM_SPLICE_CLOSE_TARGET, - cancellable, - error)) + g_autoptr(GGMLProgressIstream) progress_istream = ggml_progress_istream_new (in_stream, + priv->remote_content_length); + + if (priv->progress_callback != NULL) + { + priv->progress_indicator_queue = g_async_queue_new (); + + GSource *monitor_source = ggml_async_queue_source_new 
(priv->progress_indicator_queue, + ggml_download_progress_async_queue_monitor_callback, + g_object_ref (cached_model), + (GDestroyNotify) g_object_unref, + cancellable); + g_source_attach (g_steal_pointer (&monitor_source), NULL); + + ggml_progress_istream_set_callback (progress_istream, + ggml_cached_model_push_download_progress_to_queue, + g_object_ref (cached_model), + g_object_unref); + } + + if (g_output_stream_splice (G_OUTPUT_STREAM (output_stream), + G_INPUT_STREAM (progress_istream), + G_OUTPUT_STREAM_SPLICE_CLOSE_SOURCE | + G_OUTPUT_STREAM_SPLICE_CLOSE_TARGET, + cancellable, + error) == -1) { + /* We send the sentinel value to the progress callback on the error + * case too, so that it can clean up */ + if (priv->progress_callback != NULL) + { + ggml_cached_model_push_download_progress_to_queue (-1, + priv->remote_content_length, + cached_model); + } + return NULL; } + /* Once we're done, send the sentinel message to the queue */ + if (priv->progress_callback != NULL) + { + ggml_cached_model_push_download_progress_to_queue (-1, + priv->remote_content_length, + cached_model); + } + /* After that, we have to move the temporary file into the right place. */ g_autoptr(GFile) output_directory = g_file_get_parent (local_file); @@ -139,25 +264,43 @@ ggml_cached_model_istream_ensure_stream (GGMLCachedModelIstream *cached_model, } /* We call the same function again, now that the cached file is in place. 
 */ - return ggml_cached_model_istream_ensure_stream (cached_model, cancellable, error); + return ggml_cached_model_istream_ensure_stream (cached_model, + progress_callback, + progress_callback_data, + cancellable, + error); } static void -ggml_cached_model_istream_finalize (GObject *object) +ggml_cached_model_istream_dispose (GObject *object) { GGMLCachedModelIstream *cached_model = GGML_CACHED_MODEL_ISTREAM (object); GGMLCachedModelIstreamPrivate *priv = ggml_cached_model_istream_get_instance_private (cached_model); - g_clear_pointer (&priv->local_path, g_free); - g_clear_pointer (&priv->remote_url, g_free); + + g_clear_pointer (&priv->progress_indicator_queue, g_async_queue_unref); + g_clear_pointer (&priv->progress_callback_data, priv->progress_callback_data_destroy); + + /* If for some reason the source is still there, drop it */ + if (priv->progress_indicator_source_id != 0) + { + g_source_remove (priv->progress_indicator_source_id); + priv->progress_indicator_source_id = 0; + } - G_OBJECT_CLASS (ggml_cached_model_istream_parent_class)->finalize (object); + /* dispose must chain up to the parent class's dispose, not finalize */ + G_OBJECT_CLASS (ggml_cached_model_istream_parent_class)->dispose (object); } static void -ggml_cached_model_istream_dispose (GObject *object) +ggml_cached_model_istream_finalize (GObject *object) { - G_OBJECT_CLASS (ggml_cached_model_istream_parent_class)->dispose (object); + GGMLCachedModelIstream *cached_model = GGML_CACHED_MODEL_ISTREAM (object); + GGMLCachedModelIstreamPrivate *priv = ggml_cached_model_istream_get_instance_private (cached_model); + + g_clear_pointer (&priv->local_path, g_free); + g_clear_pointer (&priv->remote_url, g_free); + + G_OBJECT_CLASS (ggml_cached_model_istream_parent_class)->finalize (object); } static void @@ -220,7 +363,11 @@ ggml_cached_model_istream_read_fn (GInputStream *stream, if (priv->current_stream == NULL) { - priv->current_stream = ggml_cached_model_istream_ensure_stream (cached_model, cancellable, error); + priv->current_stream = ggml_cached_model_istream_ensure_stream (cached_model, + priv->progress_callback, + priv->progress_callback_data, 
+ cancellable, + error); if (priv->current_stream == NULL) { @@ -242,7 +389,11 @@ ggml_cached_model_istream_skip (GInputStream *stream, if (priv->current_stream == NULL) { - priv->current_stream = ggml_cached_model_istream_ensure_stream (cached_model, cancellable, error); + priv->current_stream = ggml_cached_model_istream_ensure_stream (cached_model, + priv->progress_callback, + priv->progress_callback_data, + cancellable, + error); if (priv->current_stream == NULL) { @@ -304,7 +455,11 @@ ggml_cached_model_istream_query_info (GFileInputStream *stream, * but that's the only way to get what the user is asking for. */ if (priv->current_stream == NULL) { - priv->current_stream = ggml_cached_model_istream_ensure_stream (cached_model, cancellable, error); + priv->current_stream = ggml_cached_model_istream_ensure_stream (cached_model, + priv->progress_callback, + priv->progress_callback_data, + cancellable, + error); if (priv->current_stream == NULL) { @@ -364,12 +519,12 @@ ggml_cached_model_istream_class_init (GGMLCachedModelIstreamClass *klass) G_PARAM_CONSTRUCT)); } -GFileInputStream * +GGMLCachedModelIstream * ggml_cached_model_istream_new (const char *remote_url, const char *local_path) { - return G_FILE_INPUT_STREAM (g_object_new (GGML_TYPE_CACHED_MODEL_ISTREAM, - "remote-url", remote_url, - "local-path", local_path, - NULL)); -} \ No newline at end of file + return GGML_CACHED_MODEL_ISTREAM (g_object_new (GGML_TYPE_CACHED_MODEL_ISTREAM, + "remote-url", remote_url, + "local-path", local_path, + NULL)); +} diff --git a/ggml-gobject/ggml-cached-model.h b/ggml-gobject/ggml-cached-model.h index 7cf804d..0ad3100 100644 --- a/ggml-gobject/ggml-cached-model.h +++ b/ggml-gobject/ggml-cached-model.h @@ -28,6 +28,10 @@ G_BEGIN_DECLS #define GGML_TYPE_CACHED_MODEL_ISTREAM (ggml_cached_model_istream_get_type ()) G_DECLARE_FINAL_TYPE (GGMLCachedModelIstream, ggml_cached_model_istream, GGML, CACHED_MODEL_ISTREAM, GFileInputStream) -GFileInputStream * ggml_cached_model_istream_new 
(const char *remote_url, const char *local_path); +GGMLCachedModelIstream * ggml_cached_model_istream_new (const char *remote_url, const char *local_path); +void ggml_cached_model_istream_set_download_progress_callback (GGMLCachedModelIstream *cached_model, + GFileProgressCallback callback, + gpointer user_data, + GDestroyNotify user_data_destroy); -G_END_DECLS \ No newline at end of file +G_END_DECLS diff --git a/ggml-gobject/ggml-language-model.c b/ggml-gobject/ggml-language-model.c index 7e1b6df..5b41c04 100644 --- a/ggml-gobject/ggml-language-model.c +++ b/ggml-gobject/ggml-language-model.c @@ -945,7 +945,22 @@ static struct GGMLLanguageModelDefinitions { GGMLModelDescFromHyperparametersFunc model_desc_from_hyperparameters_func; GGMLModelForwardFunc forward_func; } ggml_language_model_definitions[] = { - /* GGML_DEFINED_MODEL_GPT2 */ + /* GGML_DEFINED_MODEL_GPT2P117M */ + { + .model_desc_from_hyperparameters_func = (GGMLModelDescFromHyperparametersFunc) ggml_create_gpt2_model_desc_from_hyperparameters, + .forward_func = ggml_gpt_model_forward_pass + }, + /* GGML_DEFINED_MODEL_GPT2P345M */ + { + .model_desc_from_hyperparameters_func = (GGMLModelDescFromHyperparametersFunc) ggml_create_gpt2_model_desc_from_hyperparameters, + .forward_func = ggml_gpt_model_forward_pass + }, + /* GGML_DEFINED_MODEL_GPT2P774M */ + { + .model_desc_from_hyperparameters_func = (GGMLModelDescFromHyperparametersFunc) ggml_create_gpt2_model_desc_from_hyperparameters, + .forward_func = ggml_gpt_model_forward_pass + }, + /* GGML_DEFINED_MODEL_GPT2P1558M */ { .model_desc_from_hyperparameters_func = (GGMLModelDescFromHyperparametersFunc) ggml_create_gpt2_model_desc_from_hyperparameters, .forward_func = ggml_gpt_model_forward_pass @@ -1284,7 +1299,10 @@ ggml_language_model_load_defined_from_istream_async (GGMLDefinedLanguageModel } static const char *ggml_language_model_urls[] = { - "https://huggingface.co/ggerganov/ggml/resolve/main/ggml-model-gpt-2-117M.bin" + 
"https://huggingface.co/ggerganov/ggml/resolve/main/ggml-model-gpt-2-117M.bin", + "https://huggingface.co/ggerganov/ggml/resolve/main/ggml-model-gpt-2-345M.bin", + "https://huggingface.co/ggerganov/ggml/resolve/main/ggml-model-gpt-2-774M.bin", + "https://huggingface.co/ggerganov/ggml/resolve/main/ggml-model-gpt-2-1558M.bin" }; #define GGML_GOBJECT_MODELS_VERSION "0" @@ -1294,12 +1312,12 @@ static const char *ggml_language_model_urls[] = { * @defined_model: A #GGMLDefinedLanguageModel * @error: A #GError * - * Creates a new #GFileInputStream which will either download the model upon the first + * Creates a new #GGMLCachedModelIstream which will either download the model upon the first * read, or return a cached version from the disk. * - * Returns: (transfer full): A #GFileInputStream on success, %NULL with @error set on failure. + * Returns: (transfer full): A #GGMLCachedModelIstream on success, %NULL with @error set on failure. */ -GFileInputStream * +GGMLCachedModelIstream * ggml_language_model_stream_from_cache (GGMLDefinedLanguageModel defined_model, GError **error) { diff --git a/ggml-gobject/ggml-language-model.h b/ggml-gobject/ggml-language-model.h index a9ab5ba..5db7816 100644 --- a/ggml-gobject/ggml-language-model.h +++ b/ggml-gobject/ggml-language-model.h @@ -53,7 +53,10 @@ void ggml_language_model_consume_istream_magic_async (GInputStream *istr gpointer user_data); typedef enum { - GGML_DEFINED_LANGUAGE_MODEL_GPT2 + GGML_DEFINED_LANGUAGE_MODEL_GPT2P117M, + GGML_DEFINED_LANGUAGE_MODEL_GPT2P345M, + GGML_DEFINED_LANGUAGE_MODEL_GPT2P774M, + GGML_DEFINED_LANGUAGE_MODEL_GPT2P1558M, } GGMLDefinedLanguageModel; GGMLLanguageModel *ggml_language_model_load_from_istream (GInputStream *istream, @@ -91,8 +94,8 @@ void ggml_language_model_load_defined_from_istream_async (GGMLDefinedLanguageMod GGMLLanguageModel *ggml_language_model_load_defined_from_istream_finish (GAsyncResult *result, GError **error); -GFileInputStream *ggml_language_model_stream_from_cache 
(GGMLDefinedLanguageModel defined_model, - GError **error); +GGMLCachedModelIstream *ggml_language_model_stream_from_cache (GGMLDefinedLanguageModel defined_model, + GError **error); char * ggml_language_model_complete (GGMLLanguageModel *language_model, const char *prompt, diff --git a/ggml-gobject/internal/ggml-progress-istream.c b/ggml-gobject/internal/ggml-progress-istream.c new file mode 100644 index 0000000..4fcf6bd --- /dev/null +++ b/ggml-gobject/internal/ggml-progress-istream.c @@ -0,0 +1,212 @@ +/* + * ggml-gobject/ggml-progress-istream.c + * + * Library code for ggml-progress-istream + * + * Copyright (C) 2023 Sam Spilsbury. + * + * ggml-gobject is free software; you can redistribute it and/or modify + * it under the terms of the GNU Lesser General Public License as published by + * the Free Software Foundation; either version 2.1 of the License, or + * (at your option) any later version. + * + * ggml-gobject is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public License along + * with ggml-gobject; if not, write to the Free Software Foundation, Inc., + * 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. 
+ */ + +#include +#include +#include + +struct _GGMLProgressIstream +{ + GFilterInputStream parent_instance; +}; + +typedef struct +{ + size_t bytes_consumed; + size_t expected_size; + GFileProgressCallback progress_callback; + gpointer progress_callback_data; + GDestroyNotify progress_callback_data_destroy; +} GGMLProgressIstreamPrivate; + +enum { + PROP_0, + PROP_EXPECTED_SIZE, + PROP_N +}; + +G_DEFINE_TYPE_WITH_CODE (GGMLProgressIstream, + ggml_progress_istream, + G_TYPE_FILTER_INPUT_STREAM, + G_ADD_PRIVATE (GGMLProgressIstream)) + +void +ggml_progress_istream_set_callback (GGMLProgressIstream *istream, + GFileProgressCallback callback, + gpointer user_data, + GDestroyNotify user_data_destroy) +{ + GGMLProgressIstream *progress_istream = istream; + GGMLProgressIstreamPrivate *priv = ggml_progress_istream_get_instance_private (progress_istream); + + g_clear_pointer (&priv->progress_callback_data, priv->progress_callback_data_destroy); + + priv->progress_callback = callback; + priv->progress_callback_data = user_data; + priv->progress_callback_data_destroy = user_data_destroy; +} + +static void +ggml_progress_istream_dispose (GObject *object) +{ + GGMLProgressIstream *progress_istream = GGML_PROGRESS_ISTREAM (object); + GGMLProgressIstreamPrivate *priv = ggml_progress_istream_get_instance_private (progress_istream); + + g_clear_pointer (&priv->progress_callback_data, priv->progress_callback_data_destroy); + + G_OBJECT_CLASS (ggml_progress_istream_parent_class)->dispose (object); +} + +static void +ggml_progress_istream_get_property (GObject *object, + uint32_t property_id, + GValue *value, + GParamSpec *pspec) +{ + GGMLProgressIstream *progress_istream = GGML_PROGRESS_ISTREAM (object); + GGMLProgressIstreamPrivate *priv = ggml_progress_istream_get_instance_private (progress_istream); + + switch (property_id) + { + case PROP_EXPECTED_SIZE: + g_value_set_uint (value, priv->expected_size); + break; + default: + G_OBJECT_WARN_INVALID_PROPERTY_ID (object, property_id, 
pspec); + break; + } +} + +static void +ggml_progress_istream_set_property (GObject *object, + uint32_t property_id, + const GValue *value, + GParamSpec *pspec) +{ + GGMLProgressIstream *progress_istream = GGML_PROGRESS_ISTREAM (object); + GGMLProgressIstreamPrivate *priv = ggml_progress_istream_get_instance_private (progress_istream); + + switch (property_id) + { + case PROP_EXPECTED_SIZE: + priv->expected_size = g_value_get_uint (value); + break; + default: + G_OBJECT_WARN_INVALID_PROPERTY_ID (object, property_id, pspec); + break; + } +} + +static ssize_t +ggml_progress_istream_read_fn (GInputStream *stream, + void *buffer, + gsize count, + GCancellable *cancellable, + GError **error) +{ + GGMLProgressIstream *progress_istream = GGML_PROGRESS_ISTREAM (stream); + GGMLProgressIstreamPrivate *priv = ggml_progress_istream_get_instance_private (progress_istream); + + ssize_t result = G_INPUT_STREAM_CLASS (ggml_progress_istream_parent_class)->read_fn (stream, buffer, count, cancellable, error); + + /* Report back to the progress callback that we read result bytes */ + if (result != -1) + { + priv->bytes_consumed += result; + + if (priv->progress_callback != NULL) + { + (*priv->progress_callback) (priv->bytes_consumed, + priv->expected_size, + priv->progress_callback_data); + } + } + + return result; +} + +static ssize_t +ggml_progress_istream_skip (GInputStream *stream, + size_t count, + GCancellable *cancellable, + GError **error) +{ + GGMLProgressIstream *progress_istream = GGML_PROGRESS_ISTREAM (stream); + GGMLProgressIstreamPrivate *priv = ggml_progress_istream_get_instance_private (progress_istream); + + ssize_t result = G_INPUT_STREAM_CLASS (ggml_progress_istream_parent_class)->skip (stream, count, cancellable, error); + + /* Report back to the progress callback that we read result bytes */ + if (result != -1) + { + priv->bytes_consumed += result; + + if (priv->progress_callback != NULL) + { + (*priv->progress_callback) (priv->bytes_consumed, + 
priv->expected_size, + priv->progress_callback_data); + } + } + + return result; +} + +static void +ggml_progress_istream_init (GGMLProgressIstream *progress_istream) +{ +} + +static void +ggml_progress_istream_class_init (GGMLProgressIstreamClass *klass) +{ + GObjectClass *object_class = G_OBJECT_CLASS (klass); + GInputStreamClass *stream_class = G_INPUT_STREAM_CLASS (klass); + + object_class->dispose = ggml_progress_istream_dispose; + object_class->set_property = ggml_progress_istream_set_property; + object_class->get_property = ggml_progress_istream_get_property; + + stream_class->read_fn = ggml_progress_istream_read_fn; + stream_class->skip = ggml_progress_istream_skip; + + g_object_class_install_property (object_class, + PROP_EXPECTED_SIZE, + g_param_spec_uint ("expected-size", + "Expected Size", + "Expected Size", + 1, + G_MAXUINT, + 1, + G_PARAM_READWRITE | + G_PARAM_CONSTRUCT)); +} + +GGMLProgressIstream * +ggml_progress_istream_new (GInputStream *base_stream, + size_t expected_size) +{ + return GGML_PROGRESS_ISTREAM (g_object_new (GGML_TYPE_PROGRESS_ISTREAM, + "base-stream", base_stream, + "expected-size", expected_size, + NULL)); +} diff --git a/ggml-gobject/internal/ggml-progress-istream.h b/ggml-gobject/internal/ggml-progress-istream.h new file mode 100644 index 0000000..6e40885 --- /dev/null +++ b/ggml-gobject/internal/ggml-progress-istream.h @@ -0,0 +1,38 @@ +/* + * ggml-gobject/ggml-progress-istream.h + * + * Library code for ggml-progress-istream + * + * Copyright (C) 2023 Sam Spilsbury. + * + * ggml-gobject is free software; you can redistribute it and/or modify + * it under the terms of the GNU Lesser General Public License as published by + * the Free Software Foundation; either version 2.1 of the License, or + * (at your option) any later version. + * + * ggml-gobject is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. 
See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public License along + * with ggml-gobject; if not, write to the Free Software Foundation, Inc., + * 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. + */ + +#include +#include + +G_BEGIN_DECLS + +#define GGML_TYPE_PROGRESS_ISTREAM (ggml_progress_istream_get_type ()) +G_DECLARE_FINAL_TYPE (GGMLProgressIstream, ggml_progress_istream, GGML, PROGRESS_ISTREAM, GFilterInputStream) + +GGMLProgressIstream * ggml_progress_istream_new (GInputStream *base_stream, + size_t expected_length); +void ggml_progress_istream_set_callback (GGMLProgressIstream *progress_istream, + GFileProgressCallback callback, + gpointer user_data, + GDestroyNotify user_data_destroy); + +G_END_DECLS diff --git a/ggml-gobject/meson.build b/ggml-gobject/meson.build index 340e1fe..6298769 100644 --- a/ggml-gobject/meson.build +++ b/ggml-gobject/meson.build @@ -37,11 +37,13 @@ ggml_gobject_toplevel_introspectable_sources = files([ ]) ggml_gobject_toplevel_internal_sources = files([ 'internal/ggml-async-queue-source.c', + 'internal/ggml-progress-istream.c', 'internal/ggml-stream-internal.c', ]) ggml_gobject_toplevel_internal_headers = files([ 'internal/ggml-async-queue-source.h', 'internal/ggml-context-internal.h', + 'internal/ggml-progress-istream.h', 'internal/ggml-stream-internal.h', 'internal/ggml-tensor-internal.h', ]) diff --git a/tests/js/testLoadGPT2.js b/tests/js/testLoadGPT2.js index f90ce38..22a22f8 100644 --- a/tests/js/testLoadGPT2.js +++ b/tests/js/testLoadGPT2.js @@ -371,7 +371,7 @@ describe('GGML GPT2', function() { const model_desc = createModelDescGPT2(n_inp, d_model, d_ff, n_layers, n_ctx); }); it('can load the GPT2 weights from a bin file', function() { - const istream = GGML.LanguageModel.stream_from_cache(GGML.DefinedLanguageModel.GPT2); + const istream = GGML.LanguageModel.stream_from_cache(GGML.DefinedLanguageModel.GPT2P117M); const language_model = 
GGML.LanguageModel.load_from_istream( istream, @@ -387,7 +387,7 @@ describe('GGML GPT2', function() { ); }); it('can load the GPT2 weights from a bin file asynchronously', function(done) { - const istream = GGML.LanguageModel.stream_from_cache(GGML.DefinedLanguageModel.GPT2); + const istream = GGML.LanguageModel.stream_from_cache(GGML.DefinedLanguageModel.GPT2P117M); GGML.LanguageModel.load_from_istream_async( istream, @@ -407,10 +407,10 @@ describe('GGML GPT2', function() { ); }); it('can load the GPT2 weights from a bin file asynchronously (defined)', function(done) { - const istream = GGML.LanguageModel.stream_from_cache(GGML.DefinedLanguageModel.GPT2); + const istream = GGML.LanguageModel.stream_from_cache(GGML.DefinedLanguageModel.GPT2P117M); GGML.LanguageModel.load_defined_from_istream_async( - GGML.DefinedLanguageModel.GPT2, + GGML.DefinedLanguageModel.GPT2P117M, istream, null, (src, res) => { @@ -420,10 +420,10 @@ describe('GGML GPT2', function() { ); }); it('can do a forward pass through some data', function() { - const istream = GGML.LanguageModel.stream_from_cache(GGML.DefinedLanguageModel.GPT2); + const istream = GGML.LanguageModel.stream_from_cache(GGML.DefinedLanguageModel.GPT2P117M); const language_model = GGML.LanguageModel.load_defined_from_istream( - GGML.DefinedLanguageModel.GPT2, + GGML.DefinedLanguageModel.GPT2P117M, istream, null ); @@ -433,10 +433,10 @@ describe('GGML GPT2', function() { ); }); it('can handle cancellation', function() { - const istream = GGML.LanguageModel.stream_from_cache(GGML.DefinedLanguageModel.GPT2); + const istream = GGML.LanguageModel.stream_from_cache(GGML.DefinedLanguageModel.GPT2P117M); const language_model = GGML.LanguageModel.load_defined_from_istream( - GGML.DefinedLanguageModel.GPT2, + GGML.DefinedLanguageModel.GPT2P117M, istream, null ); @@ -448,10 +448,10 @@ describe('GGML GPT2', function() { expect(() => language_model.complete('The meaning of life is:', 7, cancellable)).toThrow(); }); it('can do a forward 
pass through some data asynchronously', function(done) { - const istream = GGML.LanguageModel.stream_from_cache(GGML.DefinedLanguageModel.GPT2); + const istream = GGML.LanguageModel.stream_from_cache(GGML.DefinedLanguageModel.GPT2P117M); const language_model = GGML.LanguageModel.load_defined_from_istream( - GGML.DefinedLanguageModel.GPT2, + GGML.DefinedLanguageModel.GPT2P117M, istream, null ); @@ -478,10 +478,10 @@ describe('GGML GPT2', function() { thread.join(); }); it('can handle cancellation on asynchronous completions', (done) => { - const istream = GGML.LanguageModel.stream_from_cache(GGML.DefinedLanguageModel.GPT2); + const istream = GGML.LanguageModel.stream_from_cache(GGML.DefinedLanguageModel.GPT2P117M); const language_model = GGML.LanguageModel.load_defined_from_istream( - GGML.DefinedLanguageModel.GPT2, + GGML.DefinedLanguageModel.GPT2P117M, istream, null ); @@ -512,7 +512,7 @@ describe('GGML GPT2', function() { thread.join(); }); it('can do a forward pass defined in JS through some data', function() { - const istream = GGML.LanguageModel.stream_from_cache(GGML.DefinedLanguageModel.GPT2); + const istream = GGML.LanguageModel.stream_from_cache(GGML.DefinedLanguageModel.GPT2P117M); const language_model = GGML.LanguageModel.load_from_istream( istream,