Skip to content

Commit

Permalink
Merge pull request #9 from smspillaz/language-model-cursor
Browse files Browse the repository at this point in the history
language-model: Rework into "cursors" model
  • Loading branch information
smspillaz authored Jul 31, 2023
2 parents e2ab81a + 6173d41 commit 6927861
Show file tree
Hide file tree
Showing 11 changed files with 1,002 additions and 417 deletions.
107 changes: 59 additions & 48 deletions examples/llm-writer-app/src/main.js
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,7 @@ const LLMWriterAppMainWindow = GObject.registerClass({
super._init(params);

this._model_loader = new ModelLoader();
this._cursor = null;

const resetProgress = () => {
this.progress_bar.set_visible(false);
Expand Down Expand Up @@ -224,7 +225,10 @@ const LLMWriterAppMainWindow = GObject.registerClass({
COMBOBOX_ID_TO_LANGUAGE_MODEL_ENUM[modelCombobox.active],
COMBOBOX_ID_TO_QUANTIZATION_LEVEL_ENUM[quantizationCombobox.active],
null,
() => this._spinner.stop(),
() => {
this._spinner.stop();
this._cursor = null;
},
progressCallback
);
};
Expand Down Expand Up @@ -271,6 +275,8 @@ const LLMWriterAppMainWindow = GObject.registerClass({
this._candidateText = '';
this.text_view.set_editable(true);
this._spinner.stop();
this._cursor = null;
System.gc();
};
const maybeAbortPrediction = () => {
if (this._textBufferState === STATE_PREDICTING) {
Expand All @@ -283,6 +289,38 @@ const LLMWriterAppMainWindow = GObject.registerClass({
resetState();
}
};
const predictFunc = (cursor, n_tokens, prompt, textBuffer) => {
cursor.exec_stream_async(
n_tokens,
2,
this._cancellable,
(part, is_complete_eos) => {
if (part === prompt) {
return;
}

this._candidateText += part;
const markup = `<span foreground="gray">${GLib.markup_escape_text(part, part.length)}</span>`
textBuffer.insert_markup(textBuffer.get_end_iter(), markup, markup.length);
System.gc();
},
(src, res) => {
try {
cursor.exec_stream_finish(res);
} catch (e) {
if (e.code == Gio.IOErrorEnum.CANCELLED) {
return;
}
logError(e);
return;
}

this._cancellable = null;
this._textBufferState = STATE_WAITING;
this._spinner.stop();
}
);
};

this.text_view.connect('move-cursor', (obj, step, count, extend_selection) => {
const currentPosition = buffer.cursor_position;
Expand All @@ -292,62 +330,35 @@ const LLMWriterAppMainWindow = GObject.registerClass({
currentPosition === this._lastCursorOffset &&
count > 0 &&
this._textBufferState === STATE_TEXT_EDITOR) {
const text = buffer.get_text(
buffer.get_start_iter(),
buffer.get_end_iter(),
false
);

/* Reset state immediately if the operation is cancelled */
this._cancellable = new Gio.Cancellable({});
this._cancellable.connect(() => resetState());

this._textBufferState = STATE_PREDICTING;
this._candidateText = '';
this._spinner.start();
buffer.create_mark("predictions-start", buffer.get_end_iter(), true);

this._model_loader.with_model(
COMBOBOX_ID_TO_LANGUAGE_MODEL_ENUM[modelCombobox.active],
COMBOBOX_ID_TO_QUANTIZATION_LEVEL_ENUM[quantizationCombobox.active],
this._cancellable,
model => {
model.complete_async(
text,
10,
2,
this._cancellable,
(src, res) => {
let part, is_complete, is_complete_eos;
try {
[part, is_complete, is_complete_eos] = model.complete_finish(res);
} catch (e) {
if (e.code == Gio.IOErrorEnum.CANCELLED) {
return;
}
logError(e);
return;
}

if (part === text) {
return;
}

if (is_complete) {
this._cancellable = null;
this._textBufferState = STATE_WAITING;
this._spinner.stop();
}

this._candidateText += part;
const markup = `<span foreground="gray">${GLib.markup_escape_text(part, part.length)}</span>`
buffer.insert_markup(buffer.get_end_iter(), markup, markup.length);
System.gc();
}
);
},
progressCallback
);
if (this._cursor !== null) {
predictFunc(this._cursor, 10, null, buffer);
} else {
const text = buffer.get_text(
buffer.get_start_iter(),
buffer.get_end_iter(),
false
);

this._model_loader.with_model(
COMBOBOX_ID_TO_LANGUAGE_MODEL_ENUM[modelCombobox.active],
COMBOBOX_ID_TO_QUANTIZATION_LEVEL_ENUM[quantizationCombobox.active],
this._cancellable,
model => {
this._cursor = model.create_completion(text, 256);
predictFunc(this._cursor, 10, text, buffer);
},
progressCallback
);
}
} else if (currentPosition > 0 &&
currentPosition === this._lastCursorOffset &&
count > 0 &&
Expand Down
2 changes: 2 additions & 0 deletions ggml-gobject/ggml-cached-model.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
* 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
*/

#pragma once

#include <gio/gio.h>
#include <glib-object.h>

Expand Down
2 changes: 2 additions & 0 deletions ggml-gobject/ggml-gobject.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

#pragma once

#include <ggml-gobject/ggml-cached-model.h>
#include <ggml-gobject/ggml-compute-graph.h>
#include <ggml-gobject/ggml-context.h>
#include <ggml-gobject/ggml-gpt.h>
Expand All @@ -30,6 +31,7 @@
#include <ggml-gobject/ggml-model-desc.h>
#include <ggml-gobject/ggml-model.h>
#include <ggml-gobject/ggml-ops.h>
#include <ggml-gobject/ggml-quantize.h>
#include <ggml-gobject/ggml-tensor.h>
#include <ggml-gobject/ggml-token-dictionary.h>
#include <ggml-gobject/ggml-types.h>
Expand Down
Loading

0 comments on commit 6927861

Please sign in to comment.