From 6751aba55635cba0e67db1c8ce86da4bbfe65835 Mon Sep 17 00:00:00 2001
From: Sam Spilsbury
Date: Fri, 21 Jul 2023 23:53:35 +0300
Subject: [PATCH 1/6] ggml-gobject: Add GGMLProgressIstream

This is a helper class which can be used to monitor the read progress
of an input stream, so that the user can see how far along a
long-running read operation is.
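
A rough usage sketch (the base_stream, expected_size and on_progress
names here are illustrative, not part of this patch; the
GGMLProgressIstream calls are the API added below):

    static void
    on_progress (goffset current_bytes, goffset total_bytes, gpointer user_data)
    {
      /* Hypothetical handler - just report the running total */
      g_print ("read %" G_GOFFSET_FORMAT " of %" G_GOFFSET_FORMAT " bytes\n",
               current_bytes, total_bytes);
    }

    g_autoptr(GGMLProgressIstream) progress_istream =
      ggml_progress_istream_new (base_stream, expected_size);
    ggml_progress_istream_set_callback (progress_istream, on_progress,
                                        NULL, NULL);

Every read or skip through progress_istream then invokes on_progress
with the cumulative number of bytes consumed and the expected total,
which can be used to drive a progress indicator.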
---
 ggml-gobject/internal/ggml-progress-istream.c | 212 ++++++++++++++++++
 ggml-gobject/internal/ggml-progress-istream.h |  38 ++++
 ggml-gobject/meson.build                      |   2 +
 3 files changed, 252 insertions(+)
 create mode 100644 ggml-gobject/internal/ggml-progress-istream.c
 create mode 100644 ggml-gobject/internal/ggml-progress-istream.h

diff --git a/ggml-gobject/internal/ggml-progress-istream.c b/ggml-gobject/internal/ggml-progress-istream.c
new file mode 100644
index 0000000..4fcf6bd
--- /dev/null
+++ b/ggml-gobject/internal/ggml-progress-istream.c
@@ -0,0 +1,212 @@
+/*
+ * ggml-gobject/ggml-progress-istream.c
+ *
+ * Library code for ggml-progress-istream
+ *
+ * Copyright (C) 2023 Sam Spilsbury.
+ *
+ * ggml-gobject is free software; you can redistribute it and/or modify
+ * it under the terms of the GNU Lesser General Public License as published by
+ * the Free Software Foundation; either version 2.1 of the License, or
+ * (at your option) any later version.
+ *
+ * ggml-gobject is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU Lesser General Public License for more details.
+ *
+ * You should have received a copy of the GNU Lesser General Public License along
+ * with ggml-gobject; if not, write to the Free Software Foundation, Inc.,
+ * 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
+ */
+
+#include <gio/gio.h>
+#include <glib-object.h>
+#include <ggml-gobject/internal/ggml-progress-istream.h>
+
+struct _GGMLProgressIstream
+{
+  GFilterInputStream parent_instance;
+};
+
+typedef struct
+{
+  size_t bytes_consumed;
+  size_t expected_size;
+  GFileProgressCallback progress_callback;
+  gpointer progress_callback_data;
+  GDestroyNotify progress_callback_data_destroy;
+} GGMLProgressIstreamPrivate;
+
+enum {
+  PROP_0,
+  PROP_EXPECTED_SIZE,
+  PROP_N
+};
+
+G_DEFINE_TYPE_WITH_CODE (GGMLProgressIstream,
+                         ggml_progress_istream,
+                         G_TYPE_FILTER_INPUT_STREAM,
+                         G_ADD_PRIVATE (GGMLProgressIstream))
+
+void
+ggml_progress_istream_set_callback (GGMLProgressIstream   *istream,
+                                    GFileProgressCallback  callback,
+                                    gpointer               user_data,
+                                    GDestroyNotify         user_data_destroy)
+{
+  GGMLProgressIstream *progress_istream = istream;
+  GGMLProgressIstreamPrivate *priv = ggml_progress_istream_get_instance_private (progress_istream);
+
+  g_clear_pointer (&priv->progress_callback_data, priv->progress_callback_data_destroy);
+
+  priv->progress_callback = callback;
+  priv->progress_callback_data = user_data;
+  priv->progress_callback_data_destroy = user_data_destroy;
+}
+
+static void
+ggml_progress_istream_dispose (GObject *object)
+{
+  GGMLProgressIstream *progress_istream = GGML_PROGRESS_ISTREAM (object);
+  GGMLProgressIstreamPrivate *priv = ggml_progress_istream_get_instance_private (progress_istream);
+
+  g_clear_pointer (&priv->progress_callback_data, priv->progress_callback_data_destroy);
+
+  G_OBJECT_CLASS (ggml_progress_istream_parent_class)->dispose (object);
+}
+
+static void
+ggml_progress_istream_get_property (GObject    *object,
+                                    uint32_t    property_id,
+                                    GValue     *value,
+                                    GParamSpec *pspec)
+{
+  GGMLProgressIstream *progress_istream = GGML_PROGRESS_ISTREAM (object);
+  GGMLProgressIstreamPrivate *priv = ggml_progress_istream_get_instance_private (progress_istream);
+
+  switch (property_id)
+    {
+      case PROP_EXPECTED_SIZE:
+        g_value_set_uint (value, priv->expected_size);
+        break;
+      default:
+        G_OBJECT_WARN_INVALID_PROPERTY_ID (object, property_id, pspec);
+        break;
+    }
+}
+
+static void
+ggml_progress_istream_set_property (GObject      *object,
+                                    uint32_t      property_id,
+                                    const GValue *value,
+                                    GParamSpec   *pspec)
+{
+  GGMLProgressIstream *progress_istream = GGML_PROGRESS_ISTREAM (object);
+  GGMLProgressIstreamPrivate *priv = ggml_progress_istream_get_instance_private (progress_istream);
+
+  switch (property_id)
+    {
+      case PROP_EXPECTED_SIZE:
+        priv->expected_size = g_value_get_uint (value);
+        break;
+      default:
+        G_OBJECT_WARN_INVALID_PROPERTY_ID (object, property_id, pspec);
+        break;
+    }
+}
+
+static ssize_t
+ggml_progress_istream_read_fn (GInputStream *stream,
+                               void         *buffer,
+                               gsize         count,
+                               GCancellable *cancellable,
+                               GError      **error)
+{
+  GGMLProgressIstream *progress_istream = GGML_PROGRESS_ISTREAM (stream);
+  GGMLProgressIstreamPrivate *priv = ggml_progress_istream_get_instance_private (progress_istream);
+
+  ssize_t result = G_INPUT_STREAM_CLASS (ggml_progress_istream_parent_class)->read_fn (stream, buffer, count, cancellable, error);
+
+  /* Report back to the progress callback that we read result bytes */
+  if (result != -1)
+    {
+      priv->bytes_consumed += result;
+
+      if (priv->progress_callback != NULL)
+        {
+          (*priv->progress_callback) (priv->bytes_consumed,
+                                      priv->expected_size,
+                                      priv->progress_callback_data);
+        }
+    }
+
+  return result;
+}
+
+static ssize_t
+ggml_progress_istream_skip (GInputStream *stream,
+                            size_t        count,
+                            GCancellable *cancellable,
+                            GError      **error)
+{
+  GGMLProgressIstream *progress_istream = GGML_PROGRESS_ISTREAM (stream);
+  GGMLProgressIstreamPrivate *priv = ggml_progress_istream_get_instance_private (progress_istream);
+
+  ssize_t result = G_INPUT_STREAM_CLASS (ggml_progress_istream_parent_class)->skip (stream, count, cancellable, error);
+
+  /* Report back to the progress callback that we skipped result bytes */
+  if (result != -1)
+    {
+      priv->bytes_consumed += result;
+
+      if (priv->progress_callback != NULL)
+        {
+          (*priv->progress_callback) (priv->bytes_consumed,
+                                      priv->expected_size,
+                                      priv->progress_callback_data);
+        }
+    }
+
+  return result;
+}
+
+static void
+ggml_progress_istream_init (GGMLProgressIstream *progress_istream)
+{
+}
+
+static void
+ggml_progress_istream_class_init (GGMLProgressIstreamClass *klass)
+{
+  GObjectClass *object_class = G_OBJECT_CLASS (klass);
+  GInputStreamClass *stream_class = G_INPUT_STREAM_CLASS (klass);
+
+  object_class->dispose = ggml_progress_istream_dispose;
+  object_class->set_property = ggml_progress_istream_set_property;
+  object_class->get_property = ggml_progress_istream_get_property;
+
+  stream_class->read_fn = ggml_progress_istream_read_fn;
+  stream_class->skip = ggml_progress_istream_skip;
+
+  g_object_class_install_property (object_class,
+                                   PROP_EXPECTED_SIZE,
+                                   g_param_spec_uint ("expected-size",
+                                                      "Expected Size",
+                                                      "Expected Size",
+                                                      1,
+                                                      G_MAXUINT,
+                                                      1,
+                                                      G_PARAM_READWRITE |
+                                                      G_PARAM_CONSTRUCT));
+}
+
+GGMLProgressIstream *
+ggml_progress_istream_new (GInputStream *base_stream,
+                           size_t        expected_size)
+{
+  return GGML_PROGRESS_ISTREAM (g_object_new (GGML_TYPE_PROGRESS_ISTREAM,
+                                              "base-stream", base_stream,
+                                              "expected-size", expected_size,
+                                              NULL));
+}
diff --git a/ggml-gobject/internal/ggml-progress-istream.h b/ggml-gobject/internal/ggml-progress-istream.h
new file mode 100644
index 0000000..6e40885
--- /dev/null
+++ b/ggml-gobject/internal/ggml-progress-istream.h
@@ -0,0 +1,38 @@
+/*
+ * ggml-gobject/ggml-progress-istream.h
+ *
+ * Library code for ggml-progress-istream
+ *
+ * Copyright (C) 2023 Sam Spilsbury.
+ *
+ * ggml-gobject is free software; you can redistribute it and/or modify
+ * it under the terms of the GNU Lesser General Public License as published by
+ * the Free Software Foundation; either version 2.1 of the License, or
+ * (at your option) any later version.
+ *
+ * ggml-gobject is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU Lesser General Public License for more details.
+ *
+ * You should have received a copy of the GNU Lesser General Public License along
+ * with ggml-gobject; if not, write to the Free Software Foundation, Inc.,
+ * 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
+ */
+
+#include <gio/gio.h>
+#include <glib-object.h>
+
+G_BEGIN_DECLS
+
+#define GGML_TYPE_PROGRESS_ISTREAM (ggml_progress_istream_get_type ())
+G_DECLARE_FINAL_TYPE (GGMLProgressIstream, ggml_progress_istream, GGML, PROGRESS_ISTREAM, GFilterInputStream)
+
+GGMLProgressIstream * ggml_progress_istream_new (GInputStream *base_stream,
+                                                 size_t        expected_size);
+void ggml_progress_istream_set_callback (GGMLProgressIstream   *progress_istream,
+                                         GFileProgressCallback  callback,
+                                         gpointer               user_data,
+                                         GDestroyNotify         user_data_destroy);
+
+G_END_DECLS
diff --git a/ggml-gobject/meson.build b/ggml-gobject/meson.build
index 340e1fe..6298769 100644
--- a/ggml-gobject/meson.build
+++ b/ggml-gobject/meson.build
@@ -37,11 +37,13 @@ ggml_gobject_toplevel_introspectable_sources = files([
 ])
 ggml_gobject_toplevel_internal_sources = files([
   'internal/ggml-async-queue-source.c',
+  'internal/ggml-progress-istream.c',
   'internal/ggml-stream-internal.c',
 ])
 ggml_gobject_toplevel_internal_headers = files([
   'internal/ggml-async-queue-source.h',
   'internal/ggml-context-internal.h',
+  'internal/ggml-progress-istream.h',
   'internal/ggml-stream-internal.h',
   'internal/ggml-tensor-internal.h',
 ])

From 37e7adccb594a4ae66eec88fb9565c971d298b47 Mon Sep 17 00:00:00 2001
From: Sam Spilsbury
Date: Fri, 21 Jul 2023 23:55:11 +0300
Subject: [PATCH 2/6] ggml-cached-model: Add ggml_cached_model_istream_set_download_progress_callback

This uses GGMLProgressIstream to monitor the progress of the
GInputStream for the download, so that it can be displayed in a
friendly way to the user (instead of the read operation just stalling
forever).

It also means that ggml_cached_model_istream_new now returns a
GGMLCachedModelIstream as opposed to a GFileInputStream, so that the
user can invoke the set_download_progress_callback method.
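
A rough usage sketch (the URL, cache path and handler name are
illustrative only; the calls are the API added by this patch):

    static void
    on_download_progress (goffset current_bytes, goffset total_bytes, gpointer user_data)
    {
      /* Hypothetical handler. Note that a sentinel value of -1 is also
       * delivered once the download is done, so it must be handled */
      if (current_bytes == -1)
        return;

      g_print ("downloaded %" G_GOFFSET_FORMAT " / %" G_GOFFSET_FORMAT "\n",
               current_bytes, total_bytes);
    }

    GGMLCachedModelIstream *stream =
      ggml_cached_model_istream_new ("https://example.com/model.bin",
                                     "/path/to/cache/model.bin");
    ggml_cached_model_istream_set_download_progress_callback (stream,
                                                              on_download_progress,
                                                              NULL, NULL);

The callback must be set from the main thread; any subsequent read that
triggers a download then reports its progress through it, even when the
download IO happens on a worker thread.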
---
 ggml-gobject/ggml-cached-model.c | 184 +++++++++++++++++++++++++++----
 ggml-gobject/ggml-cached-model.h |   8 +-
 2 files changed, 171 insertions(+), 21 deletions(-)

diff --git a/ggml-gobject/ggml-cached-model.c b/ggml-gobject/ggml-cached-model.c
index 325cbfd..6d6b420 100644
--- a/ggml-gobject/ggml-cached-model.c
+++ b/ggml-gobject/ggml-cached-model.c
@@ -24,6 +24,8 @@
 #include <gio/gio.h>
 #include <libsoup/soup.h>
 #include <ggml-gobject/ggml-cached-model.h>
+#include <ggml-gobject/internal/ggml-async-queue-source.h>
+#include <ggml-gobject/internal/ggml-progress-istream.h>
 
 struct _GGMLCachedModelIstream
 {
@@ -36,6 +38,11 @@ typedef struct
   size_t remote_content_length;
   char *remote_url;
   char *local_path;
+  uint32_t progress_indicator_source_id;
+  GAsyncQueue *progress_indicator_queue;
+  GFileProgressCallback progress_callback;
+  gpointer progress_callback_data;
+  GDestroyNotify progress_callback_data_destroy;
 } GGMLCachedModelIstreamPrivate;
 
 enum {
@@ -50,10 +57,91 @@ G_DEFINE_TYPE_WITH_CODE (GGMLCachedModelIstream,
                          G_TYPE_FILE_INPUT_STREAM,
                          G_ADD_PRIVATE (GGMLCachedModelIstream))
 
+static gboolean
+ggml_download_progress_async_queue_monitor_callback (gpointer message,
+                                                     gpointer user_data)
+{
+  GGMLCachedModelIstream *cached_model = user_data;
+  GGMLCachedModelIstreamPrivate *priv = ggml_cached_model_istream_get_instance_private (cached_model);
+
+  /* We process the message by sending it to the progress callback */
+  goffset progressed_bytes = GPOINTER_TO_INT (message);
+
+  /* We unconditionally send to the progress callback, even if it is
+   * the sentinel value - the progress callback has to be able to handle this */
+  (*priv->progress_callback) (progressed_bytes,
+                              (goffset) priv->remote_content_length,
+                              priv->progress_callback_data);
+
+
+  /* Sentinel message */
+  if (progressed_bytes == -1)
+    {
+      return G_SOURCE_REMOVE;
+    }
+
+  return G_SOURCE_CONTINUE;
+}
+
+static void
+ggml_cached_model_push_download_progress_to_queue (goffset  progressed_bytes,
+                                                   goffset  total_bytes,
+                                                   gpointer user_data)
+{
+  GGMLCachedModelIstream *cached_model = user_data;
+  GGMLCachedModelIstreamPrivate *priv = ggml_cached_model_istream_get_instance_private (cached_model);
+
+  /* Only push a message if one hasn't been consumed yet - otherwise
+   * we run the risk of flooding the main loop */
+  if (g_async_queue_length (priv->progress_indicator_queue) == 0)
+    {
+      g_async_queue_push (priv->progress_indicator_queue, GINT_TO_POINTER (progressed_bytes));
+      g_main_context_wakeup (g_main_context_default ());
+    }
+}
+
+/**
+ * ggml_cached_model_istream_set_download_progress_callback:
+ * @callback: A #GFileProgressCallback with progress about the download operation.
+ * @user_data: (closure callback): A closure for @callback
+ * @user_data_destroy: (destroy callback): A #GDestroyNotify for @callback
+ *
+ * Set a progress-monitor callback for @cached_model, which will be called with
+ * download progress in case a model is being downloaded. The application can use
+ * the callback to update state, for example a progress bar.
+ *
+ * This function should be called from the main thread. It will handle situations where
+ * the download IO operation happens on a separate thread.
+ */
+void
+ggml_cached_model_istream_set_download_progress_callback (GGMLCachedModelIstream *cached_model,
+                                                          GFileProgressCallback   callback,
+                                                          gpointer                user_data,
+                                                          GDestroyNotify          user_data_destroy)
+{
+  GGMLCachedModelIstreamPrivate *priv = ggml_cached_model_istream_get_instance_private (cached_model);
+
+  g_clear_pointer (&priv->progress_callback_data, priv->progress_callback_data_destroy);
+
+  if (priv->progress_indicator_source_id != 0)
+    {
+      g_source_remove (priv->progress_indicator_source_id);
+      priv->progress_indicator_source_id = 0;
+    }
+
+  g_clear_pointer (&priv->progress_indicator_queue, g_async_queue_unref);
+
+  priv->progress_callback = callback;
+  priv->progress_callback_data = user_data;
+  priv->progress_callback_data_destroy = user_data_destroy;
+}
+
 static GFileInputStream *
-ggml_cached_model_istream_ensure_stream (GGMLCachedModelIstream  *cached_model,
-                                         GCancellable            *cancellable,
-                                         GError                 **error)
+ggml_cached_model_istream_ensure_stream (GGMLCachedModelIstream *cached_model,
+                                         GFileProgressCallback   progress_callback,
+                                         gpointer                progress_callback_data,
+                                         GCancellable           *cancellable,
+                                         GError                **error)
 {
   GGMLCachedModelIstreamPrivate *priv = ggml_cached_model_istream_get_instance_private (cached_model);
 
@@ -104,8 +192,28 @@ ggml_cached_model_istream_ensure_stream (GGMLCachedModelIstream *cached_model,
       SoupMessageHeaders *response_headers = soup_message_get_response_headers (message);
       priv->remote_content_length = soup_message_headers_get_content_length (response_headers);
 
+      g_autoptr(GGMLProgressIstream) progress_istream = ggml_progress_istream_new (in_stream,
+                                                                                   priv->remote_content_length);
+
+      if (priv->progress_callback != NULL)
+        {
+          priv->progress_indicator_queue = g_async_queue_new ();
+
+          GSource *monitor_source = ggml_async_queue_source_new (priv->progress_indicator_queue,
+                                                                 ggml_download_progress_async_queue_monitor_callback,
+                                                                 g_object_ref (cached_model),
+                                                                 (GDestroyNotify) g_object_unref,
+                                                                 cancellable);
+          priv->progress_indicator_source_id = g_source_attach (g_steal_pointer (&monitor_source), NULL);
+
+          ggml_progress_istream_set_callback (progress_istream,
+                                              ggml_cached_model_push_download_progress_to_queue,
+                                              g_object_ref (cached_model),
+                                              g_object_unref);
+        }
+
       if (!g_output_stream_splice (G_OUTPUT_STREAM (output_stream),
-                                   in_stream,
+                                   G_INPUT_STREAM (progress_istream),
                                    G_OUTPUT_STREAM_SPLICE_CLOSE_SOURCE |
                                    G_OUTPUT_STREAM_SPLICE_CLOSE_TARGET,
                                    cancellable,
@@ -114,6 +222,14 @@ ggml_cached_model_istream_ensure_stream (GGMLCachedModelIstream *cached_model,
           return NULL;
         }
 
+      /* Once we're done, send the sentinel message to the queue */
+      if (priv->progress_callback != NULL)
+        {
+          ggml_cached_model_push_download_progress_to_queue (-1,
+                                                             priv->remote_content_length,
+                                                             cached_model);
+        }
+
       /* After that, we have to move the temporary file into the right place. */
       g_autoptr(GFile) output_directory = g_file_get_parent (local_file);
@@ -139,25 +255,43 @@ ggml_cached_model_istream_ensure_stream (GGMLCachedModelIstream *cached_model,
     }
 
   /* We call the same function again, now that the cached file is in place.
    */
-  return ggml_cached_model_istream_ensure_stream (cached_model, cancellable, error);
+  return ggml_cached_model_istream_ensure_stream (cached_model,
+                                                  progress_callback,
+                                                  progress_callback_data,
+                                                  cancellable,
+                                                  error);
 }
 
 static void
-ggml_cached_model_istream_finalize (GObject *object)
+ggml_cached_model_istream_dispose (GObject *object)
 {
   GGMLCachedModelIstream *cached_model = GGML_CACHED_MODEL_ISTREAM (object);
   GGMLCachedModelIstreamPrivate *priv = ggml_cached_model_istream_get_instance_private (cached_model);
 
-  g_clear_pointer (&priv->local_path, g_free);
-  g_clear_pointer (&priv->remote_url, g_free);
+
+  g_clear_pointer (&priv->progress_indicator_queue, g_async_queue_unref);
+  g_clear_pointer (&priv->progress_callback_data, priv->progress_callback_data_destroy);
+
+  /* If for some reason the source is still there, drop it */
+  if (priv->progress_indicator_source_id != 0)
+    {
+      g_source_remove (priv->progress_indicator_source_id);
+      priv->progress_indicator_source_id = 0;
+    }
 
-  G_OBJECT_CLASS (ggml_cached_model_istream_parent_class)->finalize (object);
+  G_OBJECT_CLASS (ggml_cached_model_istream_parent_class)->dispose (object);
 }
 
 static void
-ggml_cached_model_istream_dispose (GObject *object)
+ggml_cached_model_istream_finalize (GObject *object)
 {
-  G_OBJECT_CLASS (ggml_cached_model_istream_parent_class)->dispose (object);
+  GGMLCachedModelIstream *cached_model = GGML_CACHED_MODEL_ISTREAM (object);
+  GGMLCachedModelIstreamPrivate *priv = ggml_cached_model_istream_get_instance_private (cached_model);
+
+  g_clear_pointer (&priv->local_path, g_free);
+  g_clear_pointer (&priv->remote_url, g_free);
+
+  G_OBJECT_CLASS (ggml_cached_model_istream_parent_class)->finalize (object);
 }
 
 static void
@@ -220,7 +354,11 @@ ggml_cached_model_istream_read_fn (GInputStream *stream,
 
   if (priv->current_stream == NULL)
     {
-      priv->current_stream = ggml_cached_model_istream_ensure_stream (cached_model, cancellable, error);
+      priv->current_stream = ggml_cached_model_istream_ensure_stream (cached_model,
+                                                                      priv->progress_callback,
+                                                                      priv->progress_callback_data,
+                                                                      cancellable,
+                                                                      error);
 
       if (priv->current_stream == NULL)
         {
@@ -242,7 +380,11 @@ ggml_cached_model_istream_skip (GInputStream *stream,
 
   if (priv->current_stream == NULL)
     {
-      priv->current_stream = ggml_cached_model_istream_ensure_stream (cached_model, cancellable, error);
+      priv->current_stream = ggml_cached_model_istream_ensure_stream (cached_model,
+                                                                      priv->progress_callback,
+                                                                      priv->progress_callback_data,
+                                                                      cancellable,
+                                                                      error);
 
       if (priv->current_stream == NULL)
         {
@@ -304,7 +446,11 @@ ggml_cached_model_istream_query_info (GFileInputStream *stream,
    * but that's the only way to get what the user is asking for.
*/ if (priv->current_stream == NULL) { - priv->current_stream = ggml_cached_model_istream_ensure_stream (cached_model, cancellable, error); + priv->current_stream = ggml_cached_model_istream_ensure_stream (cached_model, + priv->progress_callback, + priv->progress_callback_data, + cancellable, + error); if (priv->current_stream == NULL) { @@ -364,12 +510,12 @@ ggml_cached_model_istream_class_init (GGMLCachedModelIstreamClass *klass) G_PARAM_CONSTRUCT)); } -GFileInputStream * +GGMLCachedModelIstream * ggml_cached_model_istream_new (const char *remote_url, const char *local_path) { - return G_FILE_INPUT_STREAM (g_object_new (GGML_TYPE_CACHED_MODEL_ISTREAM, - "remote-url", remote_url, - "local-path", local_path, - NULL)); -} \ No newline at end of file + return GGML_CACHED_MODEL_ISTREAM (g_object_new (GGML_TYPE_CACHED_MODEL_ISTREAM, + "remote-url", remote_url, + "local-path", local_path, + NULL)); +} diff --git a/ggml-gobject/ggml-cached-model.h b/ggml-gobject/ggml-cached-model.h index 7cf804d..0ad3100 100644 --- a/ggml-gobject/ggml-cached-model.h +++ b/ggml-gobject/ggml-cached-model.h @@ -28,6 +28,10 @@ G_BEGIN_DECLS #define GGML_TYPE_CACHED_MODEL_ISTREAM (ggml_cached_model_istream_get_type ()) G_DECLARE_FINAL_TYPE (GGMLCachedModelIstream, ggml_cached_model_istream, GGML, CACHED_MODEL_ISTREAM, GFileInputStream) -GFileInputStream * ggml_cached_model_istream_new (const char *remote_url, const char *local_path); +GGMLCachedModelIstream * ggml_cached_model_istream_new (const char *remote_url, const char *local_path); +void ggml_cached_model_istream_set_download_progress_callback (GGMLCachedModelIstream *cached_model, + GFileProgressCallback callback, + gpointer user_data, + GDestroyNotify user_data_destroy); -G_END_DECLS \ No newline at end of file +G_END_DECLS From f1b5ab534be32514309ba79674c80ce1f3901cb9 Mon Sep 17 00:00:00 2001 From: Sam Spilsbury Date: Fri, 21 Jul 2023 23:57:24 +0300 Subject: [PATCH 3/6] ggml-language-model: stream_from_cache now returns GGMLCachedModelIstream --- ggml-gobject/ggml-language-model.c | 6 +++--- ggml-gobject/ggml-language-model.h | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/ggml-gobject/ggml-language-model.c b/ggml-gobject/ggml-language-model.c index 7e1b6df..2111a70 100644 --- a/ggml-gobject/ggml-language-model.c +++ b/ggml-gobject/ggml-language-model.c @@ -1294,12 +1294,12 @@ static const char *ggml_language_model_urls[] = { * @defined_model: A #GGMLDefinedLanguageModel * @error: A #GError * - * Creates a new #GFileInputStream which will either download the model upon the first + * Creates a new #GGMLCachedModelIstream which will either download the model upon the first * read, or return a cached version from the disk. * - * Returns: (transfer full): A #GFileInputStream on success, %NULL with @error set on failure. + * Returns: (transfer full): A #GGMLCachedModelIstream on success, %NULL with @error set on failure. 
*/ -GFileInputStream * +GGMLCachedModelIstream * ggml_language_model_stream_from_cache (GGMLDefinedLanguageModel defined_model, GError **error) { diff --git a/ggml-gobject/ggml-language-model.h b/ggml-gobject/ggml-language-model.h index a9ab5ba..52ad9b0 100644 --- a/ggml-gobject/ggml-language-model.h +++ b/ggml-gobject/ggml-language-model.h @@ -91,8 +91,8 @@ void ggml_language_model_load_defined_from_istream_async (GGMLDefinedLanguageMod GGMLLanguageModel *ggml_language_model_load_defined_from_istream_finish (GAsyncResult *result, GError **error); -GFileInputStream *ggml_language_model_stream_from_cache (GGMLDefinedLanguageModel defined_model, - GError **error); +GGMLCachedModelIstream *ggml_language_model_stream_from_cache (GGMLDefinedLanguageModel defined_model, + GError **error); char * ggml_language_model_complete (GGMLLanguageModel *language_model, const char *prompt, From eb7a62627ccf8868fbc4ba2998945a462b13e3f3 Mon Sep 17 00:00:00 2001 From: Sam Spilsbury Date: Fri, 21 Jul 2023 23:58:07 +0300 Subject: [PATCH 4/6] ggml-language-model: Add enum definitions for other model sizes --- ggml-gobject/ggml-language-model.c | 22 ++++++++++++++++++++-- ggml-gobject/ggml-language-model.h | 5 ++++- tests/js/testLoadGPT2.js | 26 +++++++++++++------------- 3 files changed, 37 insertions(+), 16 deletions(-) diff --git a/ggml-gobject/ggml-language-model.c b/ggml-gobject/ggml-language-model.c index 2111a70..5b41c04 100644 --- a/ggml-gobject/ggml-language-model.c +++ b/ggml-gobject/ggml-language-model.c @@ -945,7 +945,22 @@ static struct GGMLLanguageModelDefinitions { GGMLModelDescFromHyperparametersFunc model_desc_from_hyperparameters_func; GGMLModelForwardFunc forward_func; } ggml_language_model_definitions[] = { - /* GGML_DEFINED_MODEL_GPT2 */ + /* GGML_DEFINED_MODEL_GPT2P117M */ + { + .model_desc_from_hyperparameters_func = (GGMLModelDescFromHyperparametersFunc) ggml_create_gpt2_model_desc_from_hyperparameters, + .forward_func = ggml_gpt_model_forward_pass + }, + /* GGML_DEFINED_MODEL_GPT2P345M */ + { + .model_desc_from_hyperparameters_func = (GGMLModelDescFromHyperparametersFunc) ggml_create_gpt2_model_desc_from_hyperparameters, + .forward_func = ggml_gpt_model_forward_pass + }, + /* GGML_DEFINED_MODEL_GPT2P774M */ + { + .model_desc_from_hyperparameters_func = (GGMLModelDescFromHyperparametersFunc) ggml_create_gpt2_model_desc_from_hyperparameters, + .forward_func = ggml_gpt_model_forward_pass + }, + /* GGML_DEFINED_MODEL_GPT2P1558M */ { .model_desc_from_hyperparameters_func = (GGMLModelDescFromHyperparametersFunc) ggml_create_gpt2_model_desc_from_hyperparameters, .forward_func = ggml_gpt_model_forward_pass @@ -1284,7 +1299,10 @@ ggml_language_model_load_defined_from_istream_async (GGMLDefinedLanguageModel } static const char *ggml_language_model_urls[] = { - "https://huggingface.co/ggerganov/ggml/resolve/main/ggml-model-gpt-2-117M.bin" + "https://huggingface.co/ggerganov/ggml/resolve/main/ggml-model-gpt-2-117M.bin", + "https://huggingface.co/ggerganov/ggml/resolve/main/ggml-model-gpt-2-345M.bin", + "https://huggingface.co/ggerganov/ggml/resolve/main/ggml-model-gpt-2-774M.bin", + "https://huggingface.co/ggerganov/ggml/resolve/main/ggml-model-gpt-2-1558M.bin" }; #define GGML_GOBJECT_MODELS_VERSION "0" diff --git a/ggml-gobject/ggml-language-model.h b/ggml-gobject/ggml-language-model.h index 52ad9b0..5db7816 100644 --- a/ggml-gobject/ggml-language-model.h +++ b/ggml-gobject/ggml-language-model.h @@ -53,7 +53,10 @@ void ggml_language_model_consume_istream_magic_async (GInputStream *istr gpointer 
                                                          user_data);
 
 typedef enum {
-  GGML_DEFINED_LANGUAGE_MODEL_GPT2
+  GGML_DEFINED_LANGUAGE_MODEL_GPT2P117M,
+  GGML_DEFINED_LANGUAGE_MODEL_GPT2P345M,
+  GGML_DEFINED_LANGUAGE_MODEL_GPT2P774M,
+  GGML_DEFINED_LANGUAGE_MODEL_GPT2P1558M,
 } GGMLDefinedLanguageModel;
 
 GGMLLanguageModel *ggml_language_model_load_from_istream (GInputStream *istream,
diff --git a/tests/js/testLoadGPT2.js b/tests/js/testLoadGPT2.js
index f90ce38..22a22f8 100644
--- a/tests/js/testLoadGPT2.js
+++ b/tests/js/testLoadGPT2.js
@@ -371,7 +371,7 @@ describe('GGML GPT2', function() {
     const model_desc = createModelDescGPT2(n_inp, d_model, d_ff, n_layers, n_ctx);
   });
   it('can load the GPT2 weights from a bin file', function() {
-    const istream = GGML.LanguageModel.stream_from_cache(GGML.DefinedLanguageModel.GPT2);
+    const istream = GGML.LanguageModel.stream_from_cache(GGML.DefinedLanguageModel.GPT2P117M);
 
     const language_model = GGML.LanguageModel.load_from_istream(
       istream,
@@ -387,7 +387,7 @@ describe('GGML GPT2', function() {
     );
   });
   it('can load the GPT2 weights from a bin file asynchronously', function(done) {
-    const istream = GGML.LanguageModel.stream_from_cache(GGML.DefinedLanguageModel.GPT2);
+    const istream = GGML.LanguageModel.stream_from_cache(GGML.DefinedLanguageModel.GPT2P117M);
 
     GGML.LanguageModel.load_from_istream_async(
       istream,
@@ -407,10 +407,10 @@ describe('GGML GPT2', function() {
     );
   });
   it('can load the GPT2 weights from a bin file asynchronously (defined)', function(done) {
-    const istream = GGML.LanguageModel.stream_from_cache(GGML.DefinedLanguageModel.GPT2);
+    const istream = GGML.LanguageModel.stream_from_cache(GGML.DefinedLanguageModel.GPT2P117M);
 
     GGML.LanguageModel.load_defined_from_istream_async(
-      GGML.DefinedLanguageModel.GPT2,
+      GGML.DefinedLanguageModel.GPT2P117M,
       istream,
       null,
       (src, res) => {
@@ -420,10 +420,10 @@ describe('GGML GPT2', function() {
     );
   });
   it('can do a forward pass through some data', function() {
-    const istream = GGML.LanguageModel.stream_from_cache(GGML.DefinedLanguageModel.GPT2);
+    const istream = GGML.LanguageModel.stream_from_cache(GGML.DefinedLanguageModel.GPT2P117M);
 
     const language_model = GGML.LanguageModel.load_defined_from_istream(
-      GGML.DefinedLanguageModel.GPT2,
+      GGML.DefinedLanguageModel.GPT2P117M,
       istream,
       null
     );
@@ -433,10 +433,10 @@ describe('GGML GPT2', function() {
     );
   });
   it('can handle cancellation', function() {
-    const istream = GGML.LanguageModel.stream_from_cache(GGML.DefinedLanguageModel.GPT2);
+    const istream = GGML.LanguageModel.stream_from_cache(GGML.DefinedLanguageModel.GPT2P117M);
 
     const language_model = GGML.LanguageModel.load_defined_from_istream(
-      GGML.DefinedLanguageModel.GPT2,
+      GGML.DefinedLanguageModel.GPT2P117M,
      istream,
       null
     );
@@ -448,10 +448,10 @@ describe('GGML GPT2', function() {
 
     expect(() => language_model.complete('The meaning of life is:', 7, cancellable)).toThrow();
   });
   it('can do a forward pass through some data asynchronously', function(done) {
-    const istream = GGML.LanguageModel.stream_from_cache(GGML.DefinedLanguageModel.GPT2);
+    const istream = GGML.LanguageModel.stream_from_cache(GGML.DefinedLanguageModel.GPT2P117M);
 
     const language_model = GGML.LanguageModel.load_defined_from_istream(
-      GGML.DefinedLanguageModel.GPT2,
+      GGML.DefinedLanguageModel.GPT2P117M,
       istream,
       null
     );
@@ -478,10 +478,10 @@ describe('GGML GPT2', function() {
     thread.join();
   });
   it('can handle cancellation on asynchronous completions', (done) => {
-    const istream = GGML.LanguageModel.stream_from_cache(GGML.DefinedLanguageModel.GPT2);
+    const istream = GGML.LanguageModel.stream_from_cache(GGML.DefinedLanguageModel.GPT2P117M);
 
     const language_model = GGML.LanguageModel.load_defined_from_istream(
-      GGML.DefinedLanguageModel.GPT2,
+      GGML.DefinedLanguageModel.GPT2P117M,
       istream,
       null
     );
@@ -512,7 +512,7 @@ describe('GGML GPT2', function() {
     thread.join();
   });
   it('can do a forward pass defined in JS through some data', function() {
-    const istream = GGML.LanguageModel.stream_from_cache(GGML.DefinedLanguageModel.GPT2);
+    const istream = GGML.LanguageModel.stream_from_cache(GGML.DefinedLanguageModel.GPT2P117M);
 
     const language_model = GGML.LanguageModel.load_from_istream(
       istream,

From f24ecb026a599edff8a36402e4e37a84dafee4b0 Mon Sep 17 00:00:00 2001
From: Sam Spilsbury
Date: Fri, 21 Jul 2023 23:59:12 +0300
Subject: [PATCH 5/6] llm-writer-app: Add selection of models and download
 feedback

When selecting a new model, we download it straight away - we also
start downloading the model on the first user interaction which
requests a completion.

Because the download can take a while, a progress bar is displayed in
an overlay in the bottom right hand corner.

The download doesn't get cancelled if the user cancels the prediction
operation, but it will get cancelled if they change the model type.
---
 examples/llm-writer-app/data/main.ui |  39 +++--
 examples/llm-writer-app/src/main.js  | 233 +++++++++++++++++++++------
 2 files changed, 217 insertions(+), 55 deletions(-)

diff --git a/examples/llm-writer-app/data/main.ui b/examples/llm-writer-app/data/main.ui
index e04e1d7..f6ae016 100644
--- a/examples/llm-writer-app/data/main.ui
+++ b/examples/llm-writer-app/data/main.ui
@@ -6,22 +6,43 @@
     480
     640
-
-    True
-    vertical
+
+    true
 
-
+    true
-    true
-    true
-    fill
+    vertical
 
-
     true
     true
     true
     fill
-    word
+
+        true
+        true
+        true
+        fill
+        word
+
+
+
+
+
+    true
+    true
+    end
+    end
+
+
+    false
+    0.0
+    Starting Download
+    true
diff --git a/examples/llm-writer-app/src/main.js b/examples/llm-writer-app/src/main.js
index 6049df1..4471866 100644
--- a/examples/llm-writer-app/src/main.js
+++ b/examples/llm-writer-app/src/main.js
@@ -39,16 +39,137 @@ const STATE_TEXT_EDITOR = 0;
 const STATE_PREDICTING = 1;
 const STATE_WAITING = 2;
 
+const list_store_from_rows = (rows) => {
+  const list_store = Gtk.ListStore.new(rows[0].map(() => GObject.TYPE_STRING));
+
+  rows.forEach(columns => {
+    const iter = list_store.append();
+    columns.forEach((c, i) => {
+      list_store.set_value(iter, i, c)
+    });
+  });
+
+  return list_store;
+};
+
+const load_model = (model, cancellable, callback, progress_callback) => {
+  const istream = GGML.LanguageModel.stream_from_cache(model);
+
+  if (progress_callback) {
+    istream.set_download_progress_callback(progress_callback);
+  }
+
+  GGML.LanguageModel.load_defined_from_istream_async(
+    model,
+    istream,
+    cancellable,
+    (src, res) => {
+      try {
+        callback(GGML.LanguageModel.load_defined_from_istream_finish(res));
+      } catch (e) {
+        if (e.code === Gio.IOErrorEnum.CANCELLED) {
+          return;
+        }
+        logError(e);
+      }
+    }
+  );
+};
+
+const COMBOBOX_ID_TO_LANGUAGE_MODEL_ENUM = Object.keys(GGML.DefinedLanguageModel).map(k => GGML.DefinedLanguageModel[k]);
+
+class ModelLoader {
+  constructor() {
+    this._model_enum = null;
+    this._model = null;
+    this._pending_load = null;
+  }
+
+  /**
+   * with_model:
+   * @model_enum: A #GGMLDefinedLanguageModel
+   * @cancellable: A #GCancellable
+   * @callback: A callback to invoke once the model is done loading
+   *
+   * Does some action with a model. Also accepts a @cancellable -
+   * if the action is cancelled, then @callback won't be invoked, but
+   * the model will still be downloaded if the download is in progress.
+   */
+  with_model(model_enum, cancellable, callback, progress_callback) {
+    if (this._model_enum === model_enum) {
+      return callback(this._model)
+    }
+
+    if (this._pending_load) {
+      /* We only invoke the most recent callback once the model is loaded
+       * and discard the other ones */
+      if (this._pending_load.model_enum !== model_enum) {
+        /* Cancel the existing pending load and start over again */
+        this._pending_load.load_cancellable.cancel();
+      } else {
+        /* Don't cancel the pending load operation, but change the callback */
+        this._pending_load = {
+          model_enum: model_enum,
+          callback: callback,
+          load_cancellable: this._pending_load.load_cancellable,
+          action_cancellable: cancellable
+        };
+        return;
+      }
+    }
+
+    /* Create a pending load and load the model */
+    this._pending_load = {
+      model_enum: model_enum,
+      callback: callback,
+      load_cancellable: new Gio.Cancellable(),
+      action_cancellable: cancellable
+    };
+
+    load_model(model_enum, this._pending_load.load_cancellable, model => {
+      const { callback, action_cancellable } = this._pending_load;
+
+      if (action_cancellable === null || !action_cancellable.is_cancelled()) {
+        this._model_enum = model_enum;
+        this._model = model;
+
+        System.gc();
+        return callback(this._model);
+      }
+    }, progress_callback);
+  }
+}
+
 const LLMWriterAppMainWindow = GObject.registerClass({
   Template: `${RESOURCE_PATH}/main.ui`,
   Children: [
     'content-view',
-    'text-view'
+    'text-view',
+    'progress-bar'
   ]
 }, class LLMWriterAppMainWindow extends Gtk.ApplicationWindow {
   _init(params) {
     super._init(params);
 
+    this._model_loader = new ModelLoader();
+
+    const resetProgress = () => {
+      this.progress_bar.set_visible(false);
+      this.progress_bar.set_text("Starting Download");
+    };
+    const progressCallback = (received_bytes, total_bytes) => {
+      if (received_bytes === -1) {
+        resetProgress();
+        return;
+      }
+
+      const fraction = received_bytes / total_bytes;
+
+      this.progress_bar.set_visible(true);
+      this.progress_bar.set_fraction(fraction);
+      this.progress_bar.set_text(`Downloading ${Math.trunc(fraction * 100)}%`);
+    };
+
     const header = new Gtk.HeaderBar({
       visible: true,
       title: GLib.get_application_name(),
@@ -57,10 +178,33 @@ const LLMWriterAppMainWindow = GObject.registerClass({
     this._spinner = new Gtk.Spinner({
       visible: true
     });
+    const combobox = Gtk.ComboBox.new_with_model(
+      list_store_from_rows([
+        ['GPT2 117M'],
+        ['GPT2 345M'],
+        ['GPT2 774M'],
+        ['GPT2 1558M'],
+      ])
+    );
+    const renderer = new Gtk.CellRendererText();
+    combobox.pack_start(renderer, true);
+    combobox.add_attribute(renderer, 'text', 0);
+    combobox.set_active(0);
+    combobox.connect('changed', () => {
+      resetProgress();
+      this._model_loader.with_model(
+        COMBOBOX_ID_TO_LANGUAGE_MODEL_ENUM[combobox.active],
+        null,
+        () => this._spinner.stop(),
+        progressCallback
+      );
+    });
+    combobox.show();
+
+    header.pack_start(combobox);
     header.pack_end(this._spinner);
     this.set_titlebar(header);
 
-    this._languageModel = null;
     this._textBufferState = STATE_TEXT_EDITOR;
     this._predictionsStartedAt = -1;
     this._cancellable = null;
@@ -101,7 +245,6 @@ const LLMWriterAppMainWindow = GObject.registerClass({
     if (currentPosition > 0 &&
         currentPosition === this._lastCursorOffset &&
         count > 0 &&
-        this._languageModel !== null &&
         this._textBufferState === STATE_TEXT_EDITOR) {
       const text = buffer.get_text(
         buffer.get_start_iter(),
         buffer.get_end_iter(),
@@ -109,48 +252,58 @@ const LLMWriterAppMainWindow = GObject.registerClass({
         false
       );
 
-      this.text_view.set_editable(false);
+      /* Reset state immediately if the operation is cancelled */
       this._cancellable = new Gio.Cancellable({});
+      this._cancellable.connect(() => resetState());
+
       this._textBufferState = STATE_PREDICTING;
       this._candidateText = '';
       this._spinner.start();
       buffer.create_mark("predictions-start", buffer.get_end_iter(), true);
-      this._languageModel.complete_async(
-        text,
-        10,
-        2,
+
+      this._model_loader.with_model(
+        COMBOBOX_ID_TO_LANGUAGE_MODEL_ENUM[combobox.active],
         this._cancellable,
-        (src, res) => {
-          let part, is_complete, is_complete_eos;
-          try {
-            [part, is_complete, is_complete_eos] = this._languageModel.complete_finish(res);
-          } catch (e) {
-            if (e.code == Gio.IOErrorEnum.CANCELLED) {
-              resetState();
+        model => {
+          model.complete_async(
+            text,
+            10,
+            2,
+            this._cancellable,
+            (src, res) => {
+              let part, is_complete, is_complete_eos;
+              try {
+                [part, is_complete, is_complete_eos] = model.complete_finish(res);
+              } catch (e) {
+                if (e.code == Gio.IOErrorEnum.CANCELLED) {
+                  return;
+                }
+                logError(e);
+                return;
+              }
+
+              if (part === text) {
+                return;
+              }
+
+              if (is_complete) {
+                this._cancellable = null;
+                this._textBufferState = STATE_WAITING;
+                this._spinner.stop();
+              }
+
+              this._candidateText += part;
+              const markup = `<span foreground="gray">${GLib.markup_escape_text(part, part.length)}</span>`
+              buffer.insert_markup(buffer.get_end_iter(), markup, markup.length);
+              System.gc();
             }
-            return;
-          }
-
-          if (part === text) {
-            return;
-          }
-
-          if (is_complete) {
-            this._cancellable = null;
-            this._textBufferState = STATE_WAITING;
-            this._spinner.stop();
-          }
-
-          this._candidateText += part;
-          const markup = `<span foreground="gray">${GLib.markup_escape_text(part, part.length)}</span>`
-          buffer.insert_markup(buffer.get_end_iter(), markup, markup.length);
-          System.gc();
-        }
+          );
+        },
+        progressCallback
       );
     } else if (currentPosition > 0 &&
                currentPosition === this._lastCursorOffset &&
               count > 0 &&
-               this._languageModel !== null &&
                this._textBufferState === STATE_WAITING) {
       // Delete the gray text and substitute the real text.
removePredictedText(); @@ -184,18 +337,6 @@ const LLMWriterAppMainWindow = GObject.registerClass({ } vfunc_show() { - this._spinner.start(); - const istream = GGML.LanguageModel.stream_from_cache(GGML.DefinedLanguageModel.GPT2); - GGML.LanguageModel.load_defined_from_istream_async( - GGML.DefinedLanguageModel.GPT2, - istream, - null, - (src, res) => { - this._languageModel = GGML.LanguageModel.load_defined_from_istream_finish(res); - this._spinner.stop(); - } - ); - super.vfunc_show(); } }); From 4f5ee70b4a77710c9b3c210d77452b9c8f561636 Mon Sep 17 00:00:00 2001 From: Sam Spilsbury Date: Sat, 22 Jul 2023 00:20:38 +0300 Subject: [PATCH 6/6] ggml-cached-model: Also send sentinel value in error case --- ggml-gobject/ggml-cached-model.c | 21 +++++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) diff --git a/ggml-gobject/ggml-cached-model.c b/ggml-gobject/ggml-cached-model.c index 6d6b420..c21de7f 100644 --- a/ggml-gobject/ggml-cached-model.c +++ b/ggml-gobject/ggml-cached-model.c @@ -212,13 +212,22 @@ ggml_cached_model_istream_ensure_stream (GGMLCachedModelIstream *cached_mode g_object_unref); } - if (!g_output_stream_splice (G_OUTPUT_STREAM (output_stream), - G_INPUT_STREAM (progress_istream), - G_OUTPUT_STREAM_SPLICE_CLOSE_SOURCE | - G_OUTPUT_STREAM_SPLICE_CLOSE_TARGET, - cancellable, - error)) + if (g_output_stream_splice (G_OUTPUT_STREAM (output_stream), + G_INPUT_STREAM (progress_istream), + G_OUTPUT_STREAM_SPLICE_CLOSE_SOURCE | + G_OUTPUT_STREAM_SPLICE_CLOSE_TARGET, + cancellable, + error) == -1) { + /* We send the sentinel value to the progress callback on the error + * case too, so that it can clean up */ + if (priv->progress_callback != NULL) + { + ggml_cached_model_push_download_progress_to_queue (-1, + priv->remote_content_length, + cached_model); + } + return NULL; }