diff --git a/private_gpt/components/embedding/embedding_component.py b/private_gpt/components/embedding/embedding_component.py
index 2967c38b9b..77e8c3d46e 100644
--- a/private_gpt/components/embedding/embedding_component.py
+++ b/private_gpt/components/embedding/embedding_component.py
@@ -70,7 +70,7 @@ def __init__(self, settings: Settings) -> None:
                 ollama_settings = settings.ollama
                 self.embedding_model = OllamaEmbedding(
                     model_name=ollama_settings.embedding_model,
-                    base_url=ollama_settings.api_base,
+                    base_url=ollama_settings.embedding_api_base,
                 )
             case "azopenai":
                 try:
diff --git a/private_gpt/components/llm/llm_component.py b/private_gpt/components/llm/llm_component.py
index 4e46c250b7..dae997cc28 100644
--- a/private_gpt/components/llm/llm_component.py
+++ b/private_gpt/components/llm/llm_component.py
@@ -1,4 +1,6 @@
 import logging
+from collections.abc import Callable
+from typing import Any
 
 from injector import inject, singleton
 from llama_index.core.llms import LLM, MockLLM
@@ -133,6 +135,24 @@ def __init__(self, settings: Settings) -> None:
                     additional_kwargs=settings_kwargs,
                     request_timeout=ollama_settings.request_timeout,
                 )
+
+                if (
+                    ollama_settings.keep_alive
+                    != ollama_settings.model_fields["keep_alive"].default
+                ):
+                    # Modify Ollama methods to use the "keep_alive" field.
+                    def add_keep_alive(func: Callable[..., Any]) -> Callable[..., Any]:
+                        def wrapper(*args: Any, **kwargs: Any) -> Any:
+                            kwargs["keep_alive"] = ollama_settings.keep_alive
+                            return func(*args, **kwargs)
+
+                        return wrapper
+
+                    Ollama.chat = add_keep_alive(Ollama.chat)
+                    Ollama.stream_chat = add_keep_alive(Ollama.stream_chat)
+                    Ollama.complete = add_keep_alive(Ollama.complete)
+                    Ollama.stream_complete = add_keep_alive(Ollama.stream_complete)
+
             case "azopenai":
                 try:
                     from llama_index.llms.azure_openai import (  # type: ignore
diff --git a/private_gpt/settings/settings.py b/private_gpt/settings/settings.py
index bc03e30aa2..7ec84a7b48 100644
--- a/private_gpt/settings/settings.py
+++ b/private_gpt/settings/settings.py
@@ -209,6 +209,10 @@ class OllamaSettings(BaseModel):
         "http://localhost:11434",
         description="Base URL of Ollama API. Example: 'https://localhost:11434'.",
     )
+    embedding_api_base: str = Field(
+        api_base,  # default is same as api_base, unless specified differently
+        description="Base URL of Ollama embedding API. Defaults to the same value as api_base",
+    )
     llm_model: str = Field(
         None,
         description="Model to use. Example: 'llama2-uncensored'.",
@@ -217,6 +221,10 @@
         None,
         description="Model to use. Example: 'nomic-embed-text'.",
     )
+    keep_alive: str = Field(
+        "5m",
+        description="Time the model will stay loaded in memory after a request. examples: 5m, 5h, '-1' ",
+    )
     tfs_z: float = Field(
         1.0,
         description="Tail free sampling is used to reduce the impact of less probable tokens from the output. A higher value (e.g., 2.0) will reduce the impact more, while a value of 1.0 disables this setting.",
diff --git a/settings-ollama.yaml b/settings-ollama.yaml
index d7e1a12ca0..4f0be4ffc8 100644
--- a/settings-ollama.yaml
+++ b/settings-ollama.yaml
@@ -14,6 +14,8 @@ ollama:
   llm_model: mistral
   embedding_model: nomic-embed-text
   api_base: http://localhost:11434
+  keep_alive: 5m
+  # embedding_api_base: http://ollama_embedding:11434 # uncomment if your embedding model runs on another ollama
   tfs_z: 1.0 # Tail free sampling is used to reduce the impact of less probable tokens from the output. A higher value (e.g., 2.0) will reduce the impact more, while a value of 1.0 disables this setting.
   top_k: 40 # Reduces the probability of generating nonsense. A higher value (e.g. 100) will give more diverse answers, while a lower value (e.g. 10) will be more conservative. (Default: 40)
   top_p: 0.9 # Works together with top-k. A higher value (e.g., 0.95) will lead to more diverse text, while a lower value (e.g., 0.5) will generate more focused and conservative text. (Default: 0.9)
diff --git a/settings.yaml b/settings.yaml
index ce6a2b9faf..11c3c42f67 100644
--- a/settings.yaml
+++ b/settings.yaml
@@ -99,6 +99,8 @@ ollama:
   llm_model: llama2
   embedding_model: nomic-embed-text
   api_base: http://localhost:11434
+  keep_alive: 5m
+  # embedding_api_base: http://ollama_embedding:11434 # uncomment if your embedding model runs on another ollama
   request_timeout: 120.0
 
 azopenai:
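
The `llm_component.py` change applies `keep_alive` by wrapping the Ollama client methods at runtime rather than passing the value through a constructor argument. Below is a minimal, self-contained sketch of that wrapping pattern; the `request` function and the `KEEP_ALIVE` constant are hypothetical stand-ins for the real `Ollama.chat`/`Ollama.complete` methods and the `ollama_settings.keep_alive` value, not part of the diff itself.

```python
# Sketch of the keep_alive monkey-patching pattern (illustrative only).
from collections.abc import Callable
from typing import Any

KEEP_ALIVE = "5m"  # stand-in for ollama_settings.keep_alive


def add_keep_alive(func: Callable[..., Any]) -> Callable[..., Any]:
    """Return a wrapper that injects keep_alive into every call."""

    def wrapper(*args: Any, **kwargs: Any) -> Any:
        kwargs["keep_alive"] = KEEP_ALIVE
        return func(*args, **kwargs)

    return wrapper


def request(prompt: str, **kwargs: Any) -> dict[str, Any]:
    # Hypothetical stand-in for a client method; echoes what it received.
    return {"prompt": prompt, **kwargs}


request = add_keep_alive(request)
print(request("hello"))  # {'prompt': 'hello', 'keep_alive': '5m'}
```

Because the wrapper only adds the keyword argument, the patched methods keep their original signatures, and the patch is skipped entirely when `keep_alive` is left at its default.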