Multi-token unicode character support #95

Merged: 5 commits, merged on Feb 26, 2024
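Context for the change: a single character such as an emoji may be encoded as several tokens, so converting each token to text independently yields invalid UTF-8 whenever a character's bytes straddle a token boundary. A minimal sketch of the failure and of the buffer-then-decode behavior this PR switches to (the byte split here is hypothetical, not the crate's real tokenizer output):

```rust
// Minimal sketch (hypothetical byte split, not the crate's tokenizer):
// '🦀' is one character but four UTF-8 bytes, here split across two tokens.
fn main() {
    let pieces: [&[u8]; 2] = [&[0xF0, 0x9F], &[0xA6, 0x80]];

    // Decoding each token on its own, as the old per-token loop effectively
    // did, produces replacement characters: each half is invalid UTF-8.
    let broken: String = pieces
        .iter()
        .map(|p| String::from_utf8_lossy(p).into_owned())
        .collect();
    assert!(broken.chars().all(|c| c == '\u{FFFD}')); // every piece is garbage

    // Buffering the bytes across tokens before decoding recovers the
    // character, which is what the new TokensToStrings-based path achieves.
    let joined: Vec<u8> = pieces.concat();
    assert_eq!(String::from_utf8(joined).unwrap(), "🦀");
}
```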
4 changes: 2 additions & 2 deletions Cargo.lock (generated file; diff not rendered)

126 changes: 22 additions & 104 deletions crates/edgen_rt_llama_cpp/src/lib.rs
@@ -19,12 +19,13 @@ use std::task::{Context, Poll};
 use blake3::Hasher;
 use dashmap::DashMap;
 use futures::executor::block_on;
-use futures::{Future, Stream};
+use futures::Stream;
 use llama_cpp::standard_sampler::StandardSampler;
-use llama_cpp::{CompletionHandle, LlamaModel, LlamaParams, LlamaSession, SessionParams, Token};
-use smol::future::FutureExt;
+use llama_cpp::{
+    CompletionHandle, LlamaModel, LlamaParams, LlamaSession, SessionParams, TokensToStrings,
+};
+use smol::stream::StreamExt;
 use tokio::sync::mpsc::{unbounded_channel, UnboundedSender};
-use tokio::sync::Mutex;
 use tokio::task::JoinHandle;
 use tokio::time::{interval, MissedTickBehavior};
 use tokio::{select, spawn};
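The import swap mirrors the new design: the hand-rolled future machinery (`futures::Future`, `smol::future::FutureExt`, the `tokio::sync::Mutex` around the handle) goes away, while `TokensToStrings` and `smol::stream::StreamExt` come in. The latter is presumably what supplies the by-reference `poll_next` used in the rewritten `Stream` impl further down; in `futures-lite`, which smol re-exports, `StreamExt::poll_next` takes `&mut self` for `Unpin` streams.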
@@ -240,23 +241,13 @@ impl UnloadingModel {
                 .map_err(move |e| LLMEndpointError::Advance(e.to_string()))?;
 
             let sampler = StandardSampler::default();
-            let mut handle = session.start_completing_with(sampler, SINGLE_MESSAGE_LIMIT);
-
-            let mut res = String::default();
-            while let Some(token) = handle.next_token_async().await {
-                if token == model_guard.eos() {
-                    break;
-                }
-
-                let piece = model_guard.token_to_piece(token);
-                res += &piece;
-            }
+            let handle = session.start_completing_with(sampler, SINGLE_MESSAGE_LIMIT);
 
-            Ok(res)
+            Ok(model_guard.decode_tokens(handle))
         } else {
             let (session, mut id, new_context) = self.take_chat_session(&args.prompt).await;
 
-            let (_session_signal, mut handle) = {
+            let (_session_signal, handle) = {
                 let (session_signal, mut session_guard) =
                     get_or_init_session(&session, model_guard.clone()).await?;
@@ -272,16 +263,7 @@
                 (session_signal, handle)
             };
 
-            let mut res = String::default();
-            while let Some(token) = handle.next_token_async().await {
-                if token == model_guard.eos() {
-                    break;
-                }
-
-                let piece = model_guard.token_to_piece(token);
-                res += &piece;
-                id.advance(&piece);
-            }
+            let res = model_guard.decode_tokens(handle);
 
             self.sessions.insert(id, session);
 
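Both non-streaming paths now delegate token-to-text conversion to `model_guard.decode_tokens(handle)`, whose body is not part of this diff. As a rough mental model only (an assumption based on the surrounding code, not the llama_cpp crate's actual implementation), such a helper boils down to draining a stream of already-assembled UTF-8 pieces into one buffer:

```rust
use futures::{Stream, StreamExt};

// Rough mental model of a decode_tokens-style helper (assumed, not the
// llama_cpp crate's code): drain a stream of complete UTF-8 string pieces
// into a single String.
async fn collect_pieces<S>(mut pieces: S) -> String
where
    S: Stream<Item = String> + Unpin,
{
    let mut out = String::new();
    while let Some(piece) = pieces.next().await {
        out.push_str(&piece);
    }
    out
}
```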
@@ -314,14 +296,7 @@ impl UnloadingModel {
             let sampler = StandardSampler::default();
 
             Ok(Box::new(
-                CompletionStream::new_oneshot(
-                    session,
-                    &args.prompt,
-                    model_guard.clone(),
-                    model_signal,
-                    sampler,
-                )
-                .await?,
+                CompletionStream::new_oneshot(session, &args.prompt, model_signal, sampler).await?,
             ))
         } else {
             let (session, id, new_context) = self.take_chat_session(&args.prompt).await;
@@ -524,20 +499,8 @@ fn find_any(text: &str, patterns: &[&str]) -> Option<usize> {
 
 /// A [`Stream`] of [`Token`]s returned by a [`LlamaCppSession::stream_complete`] call.
 struct CompletionStream {
-    /// The [`LlamaModel`] used to call [`LlamaModel::token_to_piece`].
-    model: LlamaModel,
-
-    // TODO look better into this implementation, could try sending this across channels instead
-    /// Handle to the model completions, needs to be an [`Arc`] so it can be cloned in
-    /// [`Stream::poll_next`], or else it would be referencing a vanishing `self`.
-    handle: Arc<Mutex<CompletionHandle>>,
-
-    //TODO i dont know if it is possible to do this without a box
-    /// An [`Option`] potentially containing the result of a [`CompletionHandle::next_token_async`] call.
-    next: Option<Pin<Box<dyn Future<Output = Option<Token>> + Send>>>,
-
-    /// The *end of sequence* [`Token`] of the [`LlamaModel`] this [`CompletionStream`] is associated with.
-    end_token: Token,
+    /// Handle to the model's completions, converted into a stream of [`String`] pieces.
+    handle: TokensToStrings<CompletionHandle>,
 
     /// The session used for generation completions.
     session: SessionOption,
@@ -576,9 +539,6 @@ impl CompletionStream {
         sampler: StandardSampler,
         finished_tx: UnboundedSender<(SessionId, Perishable<LlamaSession>)>,
     ) -> Result<Self, LLMEndpointError> {
-        let model_clone = model.clone();
-        let end_token = model.eos();
-
         let (session_signal, handle) = {
             let (session_signal, mut session_guard) = get_or_init_session(&session, model).await?;
 
@@ -595,10 +555,7 @@
         };
 
         Ok(Self {
-            model: model_clone,
-            handle: Arc::new(Mutex::new(handle)),
-            next: None,
-            end_token,
+            handle: handle.into_strings(),
             session: SessionOption::Perishable(session),
             session_id: Some(session_id),
             finished_tx: Some(finished_tx),
@@ -610,56 +567,24 @@
     async fn new_oneshot(
         mut session: LlamaSession,
         new_context: &str,
-        model: LlamaModel,
         model_signal: ActiveSignal,
         sampler: StandardSampler,
     ) -> Result<Self, LLMEndpointError> {
-        let model_clone = model.clone();
-        let end_token = model.eos();
-
         session
             .advance_context_async(new_context)
             .await
             .map_err(move |e| LLMEndpointError::Advance(e.to_string()))?;
         let handle = session.start_completing_with(sampler, SINGLE_MESSAGE_LIMIT);
 
         Ok(Self {
-            model: model_clone,
-            handle: Arc::new(Mutex::new(handle)),
-            next: None,
-            end_token,
+            handle: handle.into_strings(),
             session: SessionOption::OneShot(session),
             session_id: None,
             finished_tx: None,
             _model_signal: model_signal,
             _session_signal: None,
         })
     }
 
-    /// Helper function that captures the provided [`CompletionHandle`] handle [`Arc`] clone and
-    /// calls [`CompletionHandle::next_token_async`].
-    async fn get_next(handle: Arc<Mutex<CompletionHandle>>) -> Option<Token> {
-        handle.lock().await.next_token_async().await
-    }
-
-    /// Small helper function that takes in an acquired [`Poll::Ready`] value and returns the [`String`]
-    /// representation of the contained [`Token`]. If the [`Token`] is either not present or the *end of sequence*
-    /// [`Token`], return [`Option::None`].
-    fn poll_result(&mut self, result: Option<Token>) -> Poll<Option<String>> {
-        if let Some(token) = result {
-            if token != self.end_token {
-                let piece = self.model.token_to_piece(token);
-                if let Some(ref mut id) = &mut self.session_id {
-                    id.advance(&piece);
-                }
-                Poll::Ready(Some(piece))
-            } else {
-                Poll::Ready(None)
-            }
-        } else {
-            Poll::Ready(None)
-        }
-    }
 }
 
 impl Stream for CompletionStream {
@@ -668,22 +593,15 @@
     fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
         let stream = std::ops::DerefMut::deref_mut(&mut self);
 
-        if let Some(next) = stream.next.as_mut() {
-            if let Poll::Ready(val) = next.poll(cx) {
-                stream.next = None;
-                stream.poll_result(val)
-            } else {
-                Poll::Pending
-            }
-        } else {
-            let mut fut = Box::pin(Self::get_next(stream.handle.clone()));
-
-            if let Poll::Ready(val) = fut.poll(cx) {
-                stream.poll_result(val)
-            } else {
-                stream.next = Some(fut);
-                Poll::Pending
-            }
-        }
+        match stream.handle.poll_next(cx) {
+            Poll::Ready(Some(val)) => {
+                if let Some(id) = &mut stream.session_id {
+                    id.advance(&val);
+                }
+                Poll::Ready(Some(val))
+            }
+            Poll::Ready(None) => Poll::Ready(None),
+            Poll::Pending => Poll::Pending,
         }
     }
 }
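The rewritten `poll_next` is a flat `match` because `TokensToStrings` only ever yields complete characters; the end-of-sequence check and token-to-text conversion that `poll_result` used to do now live inside the adapter. Its internals are not shown in this PR, but the standard technique looks like the sketch below (an illustration, not the crate's actual code): emit the longest valid UTF-8 prefix of a byte buffer and hold back a trailing partial character until more bytes arrive.

```rust
// Illustration of incremental UTF-8 assembly (not the llama_cpp crate's
// actual TokensToStrings internals). Assumes the byte stream is valid UTF-8
// once complete: emit the longest valid prefix, buffer the partial tail.
fn drain_complete_utf8(buf: &mut Vec<u8>) -> String {
    let valid = match std::str::from_utf8(buf) {
        Ok(_) => buf.len(),
        Err(e) => e.valid_up_to(), // bytes before the incomplete sequence
    };
    let out = String::from_utf8(buf[..valid].to_vec()).unwrap();
    buf.drain(..valid);
    out
}

fn main() {
    let mut buf = Vec::new();
    buf.extend_from_slice(&[0xF0, 0x9F]); // first half of '🦀'
    assert_eq!(drain_complete_utf8(&mut buf), ""); // incomplete, held back
    buf.extend_from_slice(&[0xA6, 0x80]); // second half arrives
    assert_eq!(drain_complete_utf8(&mut buf), "🦀");
}
```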