diff --git a/Cargo.lock b/Cargo.lock
index 36e6e22..9517c8b 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -2883,7 +2883,7 @@ checksum = "01cda141df6706de531b6c46c3a33ecca755538219bd484262fa09410c13539c"
 [[package]]
 name = "llama_cpp"
 version = "0.3.0"
-source = "git+https://github.com/edgenai/llama_cpp-rs?branch=main#6b599bd9a3a9b5e82d680cd2df20691d0252095f"
+source = "git+https://github.com/edgenai/llama_cpp-rs?branch=main#0bdae533e6575685a6c2add27a2222a2dca350d2"
 dependencies = [
  "derive_more",
  "futures",
@@ -2897,7 +2897,7 @@ dependencies = [
 [[package]]
 name = "llama_cpp_sys"
 version = "0.3.0"
-source = "git+https://github.com/edgenai/llama_cpp-rs?branch=main#6b599bd9a3a9b5e82d680cd2df20691d0252095f"
+source = "git+https://github.com/edgenai/llama_cpp-rs?branch=main#0bdae533e6575685a6c2add27a2222a2dca350d2"
 dependencies = [
  "ash",
  "bindgen",
diff --git a/crates/edgen_rt_llama_cpp/src/lib.rs b/crates/edgen_rt_llama_cpp/src/lib.rs
index 5760d8a..33d1c04 100644
--- a/crates/edgen_rt_llama_cpp/src/lib.rs
+++ b/crates/edgen_rt_llama_cpp/src/lib.rs
@@ -19,12 +19,13 @@ use std::task::{Context, Poll};
 use blake3::Hasher;
 use dashmap::DashMap;
 use futures::executor::block_on;
-use futures::{Future, Stream};
+use futures::Stream;
 use llama_cpp::standard_sampler::StandardSampler;
-use llama_cpp::{CompletionHandle, LlamaModel, LlamaParams, LlamaSession, SessionParams, Token};
-use smol::future::FutureExt;
+use llama_cpp::{
+    CompletionHandle, LlamaModel, LlamaParams, LlamaSession, SessionParams, TokensToStrings,
+};
+use smol::stream::StreamExt;
 use tokio::sync::mpsc::{unbounded_channel, UnboundedSender};
-use tokio::sync::Mutex;
 use tokio::task::JoinHandle;
 use tokio::time::{interval, MissedTickBehavior};
 use tokio::{select, spawn};
@@ -240,23 +241,13 @@ impl UnloadingModel {
                 .map_err(move |e| LLMEndpointError::Advance(e.to_string()))?;
 
             let sampler = StandardSampler::default();
-            let mut handle = session.start_completing_with(sampler, SINGLE_MESSAGE_LIMIT);
-
-            let mut res = String::default();
-            while let Some(token) = handle.next_token_async().await {
-                if token == model_guard.eos() {
-                    break;
-                }
-
-                let piece = model_guard.token_to_piece(token);
-                res += &piece;
-            }
+            let handle = session.start_completing_with(sampler, SINGLE_MESSAGE_LIMIT);
 
-            Ok(res)
+            Ok(model_guard.decode_tokens(handle))
         } else {
             let (session, mut id, new_context) = self.take_chat_session(&args.prompt).await;
 
-            let (_session_signal, mut handle) = {
+            let (_session_signal, handle) = {
                 let (session_signal, mut session_guard) =
                     get_or_init_session(&session, model_guard.clone()).await?;
 
@@ -272,16 +263,7 @@ impl UnloadingModel {
                 (session_signal, handle)
             };
 
-            let mut res = String::default();
-            while let Some(token) = handle.next_token_async().await {
-                if token == model_guard.eos() {
-                    break;
-                }
-
-                let piece = model_guard.token_to_piece(token);
-                res += &piece;
-                id.advance(&piece);
-            }
+            let res = model_guard.decode_tokens(handle);
 
             self.sessions.insert(id, session);
 
@@ -314,14 +296,7 @@ impl UnloadingModel {
             let sampler = StandardSampler::default();
 
             Ok(Box::new(
-                CompletionStream::new_oneshot(
-                    session,
-                    &args.prompt,
-                    model_guard.clone(),
-                    model_signal,
-                    sampler,
-                )
-                .await?,
+                CompletionStream::new_oneshot(session, &args.prompt, model_signal, sampler).await?,
             ))
         } else {
             let (session, id, new_context) = self.take_chat_session(&args.prompt).await;
@@ -524,20 +499,8 @@ fn find_any(text: &str, patterns: &[&str]) -> Option<usize> {
 
 /// A [`Stream`] of [`Token`]s returned by a [`LlamaCppSession::stream_complete`] call.
 struct CompletionStream {
-    /// The [`LlamaModel`] used to call [`LlamaModel::token_to_piece`].
-    model: LlamaModel,
-
-    // TODO look better into this implementation, could try sending this across channels instead
-    /// Handle to the model completions, needs to be an [`Arc`] so it can be cloned in
-    /// [`Stream::poll_next`], or else it would be referencing a vanishing `self`.
-    handle: Arc<Mutex<CompletionHandle>>,
-
-    //TODO i dont know if it is possible to do this without a box
-    /// An [`Option`] potentially containing the result of a [`CompletionHandle::next_token_async`] call.
-    next: Option<Pin<Box<dyn Future<Output = Option<Token>> + Send>>>,
-
-    /// The *end of sequence* [`Token`] of the [`LlamaModel`] this [`CompletionStream`] is associated with.
-    end_token: Token,
+    /// Handle to the model's completion stream.
+    handle: TokensToStrings<CompletionHandle>,
 
     /// The session used for generation completions.
     session: SessionOption,
@@ -576,9 +539,6 @@ impl CompletionStream {
        sampler: StandardSampler,
        finished_tx: UnboundedSender<(SessionId, Perishable<LlamaSession>)>,
    ) -> Result<Self, LLMEndpointError> {
-        let model_clone = model.clone();
-        let end_token = model.eos();
-
        let (session_signal, handle) = {
            let (session_signal, mut session_guard) =
                get_or_init_session(&session, model).await?;
@@ -595,10 +555,7 @@ impl CompletionStream {
        };
 
        Ok(Self {
-            model: model_clone,
-            handle: Arc::new(Mutex::new(handle)),
-            next: None,
-            end_token,
+            handle: handle.into_strings(),
            session: SessionOption::Perishable(session),
            session_id: Some(session_id),
            finished_tx: Some(finished_tx),
@@ -610,13 +567,9 @@ impl CompletionStream {
    async fn new_oneshot(
        mut session: LlamaSession,
        new_context: &str,
-        model: LlamaModel,
        model_signal: ActiveSignal,
        sampler: StandardSampler,
    ) -> Result<Self, LLMEndpointError> {
-        let model_clone = model.clone();
-        let end_token = model.eos();
-
        session
            .advance_context_async(new_context)
            .await
@@ -624,10 +577,7 @@ impl CompletionStream {
        let handle = session.start_completing_with(sampler, SINGLE_MESSAGE_LIMIT);
 
        Ok(Self {
-            model: model_clone,
-            handle: Arc::new(Mutex::new(handle)),
-            next: None,
-            end_token,
+            handle: handle.into_strings(),
            session: SessionOption::OneShot(session),
            session_id: None,
            finished_tx: None,
@@ -635,31 +585,6 @@ impl CompletionStream {
            _session_signal: None,
        })
    }
-
-    /// Helper function that captures the provided [`CompletionHandle`] handle [`Arc`] clone and
-    /// calls [`CompletionHandle::next_token_async`].
-    async fn get_next(handle: Arc<Mutex<CompletionHandle>>) -> Option<Token> {
-        handle.lock().await.next_token_async().await
-    }
-
-    /// Small helper function that takes in an acquired [`Poll::Ready`] value and returns the [`String`]
-    /// representation of the contained [`Token`]. If the [`Token`] is either not present or the *end of sequence*
-    /// [`Token`], return [`Option::None`].
-    fn poll_result(&mut self, result: Option<Token>) -> Poll<Option<String>> {
-        if let Some(token) = result {
-            if token != self.end_token {
-                let piece = self.model.token_to_piece(token);
-                if let Some(ref mut id) = &mut self.session_id {
-                    id.advance(&piece);
-                }
-                Poll::Ready(Some(piece))
-            } else {
-                Poll::Ready(None)
-            }
-        } else {
-            Poll::Ready(None)
-        }
-    }
 }
 
 impl Stream for CompletionStream {
@@ -668,22 +593,15 @@ impl Stream for CompletionStream {
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let stream = std::ops::DerefMut::deref_mut(&mut self);
 
-        if let Some(next) = stream.next.as_mut() {
-            if let Poll::Ready(val) = next.poll(cx) {
-                stream.next = None;
-                stream.poll_result(val)
-            } else {
-                Poll::Pending
-            }
-        } else {
-            let mut fut = Box::pin(Self::get_next(stream.handle.clone()));
-
-            if let Poll::Ready(val) = fut.poll(cx) {
-                stream.poll_result(val)
-            } else {
-                stream.next = Some(fut);
-                Poll::Pending
+        match stream.handle.poll_next(cx) {
+            Poll::Ready(Some(val)) => {
+                if let Some(id) = &mut stream.session_id {
+                    id.advance(&val);
+                }
+                Poll::Ready(Some(val))
            }
+            Poll::Ready(None) => Poll::Ready(None),
+            Poll::Pending => Poll::Pending,
        }
    }
 }
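
Reviewer note, not part of the patch: the change replaces the hand-rolled `next_token_async`/`token_to_piece` polling loop (with its explicit `eos()` check) by the `TokensToStrings` adapter from `llama_cpp-rs`, obtained via `CompletionHandle::into_strings()`, which yields already-decoded `String` pieces and ends the stream at the end-of-sequence token. A minimal consumption sketch follows, assuming the pinned `llama_cpp-rs` revision; the `complete` helper, its `max_tokens` parameter, and the error handling are illustrative only.

```rust
use llama_cpp::standard_sampler::StandardSampler;
use llama_cpp::{LlamaModel, SessionParams};
use smol::stream::StreamExt; // same extension trait the patch imports; provides `next()`/`poll_next()`

// Hypothetical helper for illustration; exact `create_session` signature
// may differ between llama_cpp-rs revisions.
async fn complete(model: &LlamaModel, prompt: &str, max_tokens: usize) -> String {
    let mut session = model
        .create_session(SessionParams::default())
        .expect("failed to create session");
    session
        .advance_context_async(prompt)
        .await
        .expect("failed to advance context");

    // `start_completing_with` returns a `CompletionHandle`; `into_strings()`
    // wraps it in `TokensToStrings`, a `Stream` of decoded `String` pieces,
    // so no manual `eos()` check or `token_to_piece` call is needed.
    let mut pieces = session
        .start_completing_with(StandardSampler::default(), max_tokens)
        .into_strings();

    let mut out = String::new();
    while let Some(piece) = pieces.next().await {
        out += &piece;
    }
    out
}
```

This is also why `CompletionStream::poll_next` above collapses to a single `match`: the adapter already owns the handle, so the `Arc<Mutex<...>>` clone and the boxed pending future disappear, and the stream only has to advance the session id on each yielded piece.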