diff --git a/lib/chat_models/chat_ollama_ai.ex b/lib/chat_models/chat_ollama_ai.ex
new file mode 100644
index 0000000..0b4a1b1
--- /dev/null
+++ b/lib/chat_models/chat_ollama_ai.ex
@@ -0,0 +1,370 @@
+defmodule LangChain.ChatModels.ChatOllamaAI do
+  @moduledoc """
+  Represents the [Ollama AI Chat model](https://github.com/jmorganca/ollama/blob/main/docs/api.md#generate-a-chat-completion)
+
+  Parses and validates inputs for making requests to the Ollama Chat API.
+
+  Converts responses into more specialized `LangChain` data structures.
+
+  The module's functionalities include:
+
+  - Initializing a new `ChatOllamaAI` struct with defaults or specific attributes.
+  - Validating and casting input data to fit the expected schema.
+  - Preparing and sending requests to the Ollama AI service API.
+  - Managing both streaming and non-streaming API responses.
+  - Processing API responses to convert them into suitable message formats.
+
+  The `ChatOllamaAI` struct has fields to configure the AI, including but not limited to:
+
+  - `endpoint`: URL of the Ollama AI service.
+  - `model`: The AI model used, e.g., "llama2:latest".
+  - `receive_timeout`: Max wait time for AI service responses.
+  - `temperature`: Influences the AI's response creativity.
+
+  For detailed info on all other parameters, see the documentation here:
+  https://github.com/jmorganca/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values
+
+  This module is for use within LangChain and follows the `ChatModel` behavior,
+  outlining callbacks AI chat models must implement.
+
+  Usage examples and more details are in the LangChain documentation or the
+  module's function docs.
+  """
+  use Ecto.Schema
+  require Logger
+  import Ecto.Changeset
+  alias __MODULE__
+  alias LangChain.ChatModels.ChatModel
+  alias LangChain.Message
+  alias LangChain.MessageDelta
+  alias LangChain.LangChainError
+  alias LangChain.ForOpenAIApi
+  alias LangChain.Utils
+
+  @behaviour ChatModel
+
+  @type t :: %ChatOllamaAI{}
+
+  @create_fields [
+    :endpoint,
+    :mirostat,
+    :mirostat_eta,
+    :mirostat_tau,
+    :model,
+    :num_ctx,
+    :num_gqa,
+    :num_gpu,
+    :num_predict,
+    :num_thread,
+    :receive_timeout,
+    :repeat_last_n,
+    :repeat_penalty,
+    :seed,
+    :stop,
+    :stream,
+    :temperature,
+    :tfs_z,
+    :top_k,
+    :top_p
+  ]
+
+  @required_fields [:endpoint, :model]
+
+  @receive_timeout 60_000 * 5
+
+  @primary_key false
+  embedded_schema do
+    field :endpoint, :string, default: "http://localhost:11434/api/chat"
+
+    # Enable Mirostat sampling for controlling perplexity.
+    # (Default: 0, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0)
+    field :mirostat, :integer, default: 0
+
+    # Influences how quickly the algorithm responds to feedback from the generated text. A lower learning rate
+    # will result in slower adjustments, while a higher learning rate will make the algorithm more responsive.
+    # (Default: 0.1)
+    field :mirostat_eta, :float, default: 0.1
+
+    # Controls the balance between coherence and diversity of the output. A lower value will result in more focused
+    # and coherent text. (Default: 5.0)
+    field :mirostat_tau, :float, default: 5.0
+
+    field :model, :string, default: "llama2:latest"
+
+    # Sets the size of the context window used to generate the next token. (Default: 2048)
+    field :num_ctx, :integer, default: 2048
+
+    # The number of GQA groups in the transformer layer. Required for some models, for example it is 8 for llama2:70b.
+    field :num_gqa, :integer
+
+    # The number of layers to send to the GPU(s). On macOS it defaults to 1 to enable Metal support, 0 to disable.
+    field :num_gpu, :integer
+
+    # Maximum number of tokens to predict when generating text. (Default: 128, -1 = infinite generation, -2 = fill context)
+    field :num_predict, :integer, default: 128
+
+    # Sets the number of threads to use during computation. By default, Ollama will detect this for optimal
+    # performance. It is recommended to set this value to the number of physical CPU cores your system has (as
+    # opposed to the logical number of cores).
+    field :num_thread, :integer
+
+    # Duration in milliseconds to wait for the response to be received. When streaming a very
+    # lengthy response, a longer time limit may be required. However, when a response
+    # runs very long, the model also tends to hallucinate more.
+    # The Ollama default appears to be 5 minutes: https://github.com/jmorganca/ollama/pull/1257
+    field :receive_timeout, :integer, default: @receive_timeout
+
+    # Sets how far back the model looks to prevent repetition. (Default: 64, 0 = disabled, -1 = num_ctx)
+    field :repeat_last_n, :integer, default: 64
+
+    # Sets how strongly to penalize repetitions. A higher value (e.g., 1.5) will penalize repetitions more strongly,
+    # while a lower value (e.g., 0.9) will be more lenient. (Default: 1.1)
+    field :repeat_penalty, :float, default: 1.1
+
+    # Sets the random number seed to use for generation. Setting this to a specific number will make the
+    # model generate the same text for the same prompt. (Default: 0)
+    field :seed, :integer, default: 0
+
+    # Sets the stop sequences to use. When this pattern is encountered the LLM will stop generating text and return.
+    # Multiple stop patterns may be set by specifying multiple separate stop parameters in a modelfile.
+    field :stop, :string
+
+    field :stream, :boolean, default: false
+
+    # The temperature of the model. Increasing the temperature will make the model
+    # answer more creatively. (Default: 0.8)
+    field :temperature, :float, default: 0.8
+
+    # Tail free sampling is used to reduce the impact of less probable tokens from the output. A higher value (e.g., 2.0)
+    # will reduce the impact more, while a value of 1.0 disables this setting. (Default: 1)
+    field :tfs_z, :float, default: 1.0
+
+    # Reduces the probability of generating nonsense. A higher value (e.g. 100) will give more diverse answers,
+    # while a lower value (e.g. 10) will be more conservative. (Default: 40)
+    field :top_k, :integer, default: 40
+
+    # Works together with top-k. A higher value (e.g., 0.95) will lead to more diverse text,
+    # while a lower value (e.g., 0.5) will generate more focused and conservative text. (Default: 0.9)
+    field :top_p, :float, default: 0.9
+  end
+
+  @doc """
+  Creates a new `ChatOllamaAI` struct with the given attributes.
+  """
+  @spec new(attrs :: map()) :: {:ok, t} | {:error, Ecto.Changeset.t()}
+  def new(%{} = attrs \\ %{}) do
+    %ChatOllamaAI{}
+    |> cast(attrs, @create_fields, empty_values: [""])
+    |> common_validation()
+    |> apply_action(:insert)
+  end
+
+  @doc """
+  Creates a new `ChatOllamaAI` struct with the given attributes. Will raise an error if the changeset is invalid.
+  """
+  @spec new!(attrs :: map()) :: t() | no_return()
+  def new!(attrs \\ %{}) do
+    case new(attrs) do
+      {:ok, chain} ->
+        chain
+
+      {:error, changeset} ->
+        raise LangChainError, changeset
+    end
+  end
+
+  defp common_validation(changeset) do
+    changeset
+    |> validate_required(@required_fields)
+    |> validate_number(:temperature, greater_than_or_equal_to: 0.0, less_than_or_equal_to: 1.0)
+    |> validate_number(:mirostat_eta, greater_than_or_equal_to: 0.0, less_than_or_equal_to: 1.0)
+  end
+
+  @doc """
+  Return the params formatted for an API request.
+  """
+  def for_api(%ChatOllamaAI{} = model, messages, _functions) do
+    %{
+      model: model.model,
+      temperature: model.temperature,
+      messages: messages |> Enum.map(&ForOpenAIApi.for_api/1),
+      stream: model.stream,
+      seed: model.seed,
+      num_ctx: model.num_ctx,
+      num_predict: model.num_predict,
+      repeat_last_n: model.repeat_last_n,
+      repeat_penalty: model.repeat_penalty,
+      mirostat: model.mirostat,
+      mirostat_eta: model.mirostat_eta,
+      mirostat_tau: model.mirostat_tau,
+      num_gqa: model.num_gqa,
+      num_gpu: model.num_gpu,
+      num_thread: model.num_thread,
+      receive_timeout: model.receive_timeout,
+      stop: model.stop,
+      tfs_z: model.tfs_z,
+      top_k: model.top_k,
+      top_p: model.top_p
+    }
+  end
+
+  @doc """
+  Calls the Ollama chat completion API, passing the `ChatOllamaAI` struct with its
+  configuration, plus either a simple message or the list of messages to act as the prompt.
+
+  **NOTE:** This API does not currently support functions. More
+  information here: https://github.com/jmorganca/ollama/issues/1729
+
+  **NOTE:** This function *can* be used directly, but the primary interface
+  should be through `LangChain.Chains.LLMChain`. The `ChatOllamaAI` module is more focused on
+  translating the `LangChain` data structures to and from the Ollama API.
+
+  Another benefit of using `LangChain.Chains.LLMChain` is that it combines the
+  storage of messages, adding functions, adding custom context that should be
+  passed to functions, and automatically applying `LangChain.MessageDelta`
+  structs as they are received, then converting those to the full
+  `LangChain.Message` once fully complete.
+  """
+
+  @impl ChatModel
+  def call(ollama_ai, prompt, functions \\ [], callback_fn \\ nil)
+
+  def call(%ChatOllamaAI{} = ollama_ai, prompt, functions, callback_fn) when is_binary(prompt) do
+    messages = [
+      Message.new_system!(),
+      Message.new_user!(prompt)
+    ]
+
+    call(ollama_ai, messages, functions, callback_fn)
+  end
+
+  def call(%ChatOllamaAI{} = ollama_ai, messages, functions, callback_fn)
+      when is_list(messages) do
+    try do
+      case do_api_request(ollama_ai, messages, functions, callback_fn) do
+        {:error, reason} ->
+          {:error, reason}
+
+        parsed_data ->
+          {:ok, parsed_data}
+      end
+    rescue
+      err in LangChainError ->
+        {:error, err.message}
+    end
+  end
+
+  # Make the API request to the Ollama server.
+  #
+  # The result of the function is:
+  #
+  # - `result` - where `result` is a data-structure like a list or map.
+  # - `{:error, reason}` - where `reason` is a string explanation of what went wrong.
+  #
+  # **NOTE:** The callback function is IGNORED for Ollama AI.
+  #
+  # If `stream: false`, the completed message is returned.
+  #
+  # If `stream: true`, the completed message is returned after the MessageDeltas.
+  #
+  # Retries the request up to 3 times on transient errors, with an increasing delay between attempts.
+  @doc false
+  @spec do_api_request(t(), [Message.t()], [Function.t()], (any() -> any())) ::
+          list() | struct() | {:error, String.t()}
+  def do_api_request(ollama_ai, messages, functions, callback_fn, retry_count \\ 3)
+
+  def do_api_request(_ollama_ai, _messages, _functions, _callback_fn, 0) do
+    raise LangChainError, "Retries exceeded. Connection failed."
+  end
+
+  def do_api_request(%ChatOllamaAI{stream: false} = ollama_ai, messages, functions, callback_fn, retry_count) do
+    req =
+      Req.new(
+        url: ollama_ai.endpoint,
+        json: for_api(ollama_ai, messages, functions),
+        receive_timeout: ollama_ai.receive_timeout,
+        retry: :transient,
+        max_retries: 3,
+        retry_delay: fn attempt -> 300 * attempt end
+      )
+
+    req
+    |> Req.post()
+    |> case do
+      {:ok, %Req.Response{body: data}} ->
+        case do_process_response(data) do
+          {:error, reason} ->
+            {:error, reason}
+
+          result ->
+            result
+        end
+
+      {:error, %Mint.TransportError{reason: :timeout}} ->
+        {:error, "Request timed out"}
+
+      {:error, %Mint.TransportError{reason: :closed}} ->
+        # Force a retry by making a recursive call decrementing the counter
+        Logger.debug(fn -> "Mint connection closed: retry count = #{inspect(retry_count)}" end)
+        do_api_request(ollama_ai, messages, functions, callback_fn, retry_count - 1)
+
+      other ->
+        Logger.error("Unexpected and unhandled API response! #{inspect(other)}")
+        other
+    end
+  end
+
+  def do_api_request(%ChatOllamaAI{stream: true} = ollama_ai, messages, functions, callback_fn, retry_count) do
+    Req.new(
+      url: ollama_ai.endpoint,
+      json: for_api(ollama_ai, messages, functions),
+      receive_timeout: ollama_ai.receive_timeout
+    )
+    |> Req.post(into: Utils.handle_stream_fn(ollama_ai, &do_process_response/1, callback_fn))
+    |> case do
+      {:ok, %Req.Response{body: data}} ->
+        data
+
+      {:error, %LangChainError{message: reason}} ->
+        {:error, reason}
+
+      {:error, %Mint.TransportError{reason: :timeout}} ->
+        {:error, "Request timed out"}
+
+      {:error, %Mint.TransportError{reason: :closed}} ->
+        # Force a retry by making a recursive call decrementing the counter
+        Logger.debug(fn -> "Mint connection closed: retry count = #{inspect(retry_count)}" end)
+        do_api_request(ollama_ai, messages, functions, callback_fn, retry_count - 1)

+      other ->
+        Logger.error(
+          "Unhandled and unexpected response from streamed post call. #{inspect(other)}"
+        )
+
+        {:error, "Unexpected response"}
+    end
+  end
+
+  def do_process_response(%{"message" => message, "done" => true}) do
+    create_message(message, :complete, Message)
+  end
+
+  def do_process_response(%{"message" => message, "done" => _other}) do
+    create_message(message, :incomplete, MessageDelta)
+  end
+
+  def do_process_response(%{"error" => reason}) do
+    Logger.error("Received error from API: #{inspect(reason)}")
+    {:error, reason}
+  end
+
+  defp create_message(message, status, message_type) do
+    case message_type.new(Map.merge(message, %{"status" => status})) do
+      {:ok, new_message} ->
+        new_message
+
+      {:error, changeset} ->
+        {:error, Utils.changeset_error_to_string(changeset)}
+    end
+  end
+end
diff --git a/test/chat_models/chat_ollama_ai_test.exs b/test/chat_models/chat_ollama_ai_test.exs
new file mode 100644
index 0000000..91c5dd8
--- /dev/null
+++ b/test/chat_models/chat_ollama_ai_test.exs
@@ -0,0 +1,230 @@
+defmodule ChatModels.ChatOllamaAITest do
+  use LangChain.BaseCase
+
+  doctest LangChain.ChatModels.ChatOllamaAI
+  alias LangChain.ChatModels.ChatOllamaAI
+
+  describe "new/1" do
+    test "works with minimal attributes" do
+      assert {:ok, %ChatOllamaAI{} = ollama_ai} = ChatOllamaAI.new(%{"model" => "llama2:latest"})
+      assert ollama_ai.model == "llama2:latest"
+      assert ollama_ai.endpoint == "http://localhost:11434/api/chat"
+    end
+
+    test "returns errors given invalid attributes" do
+      assert {:error, changeset} =
+               ChatOllamaAI.new(%{"model" => nil, "temperature" => 4.4, "mirostat_eta" => 4.4})
+
+      refute changeset.valid?
+      assert {"can't be blank", _} = changeset.errors[:model]
+      assert {"must be less than or equal to %{number}", _} = changeset.errors[:temperature]
+      assert {"must be less than or equal to %{number}", _} = changeset.errors[:mirostat_eta]
+    end
+
+    test "supports overriding the API endpoint" do
+      override_url = "http://localhost:99999/api/chat"
+
+      model =
+        ChatOllamaAI.new!(%{
+          endpoint: override_url
+        })
+
+      assert model.endpoint == override_url
+    end
+  end
+
+  describe "for_api/3" do
+    setup do
+      {:ok, ollama_ai} =
+        ChatOllamaAI.new(%{
+          "model" => "llama2:latest",
+          "temperature" => 0.4,
+          "stream" => false,
+          "seed" => 0,
+          "num_ctx" => 2048,
+          "num_predict" => 128,
+          "repeat_last_n" => 64,
+          "repeat_penalty" => 1.1,
+          "mirostat" => 0,
+          "mirostat_eta" => 0.1,
+          "mirostat_tau" => 5.0,
+          "num_gqa" => 8,
+          "num_gpu" => 1,
+          "num_thread" => 0,
+          "receive_timeout" => 300_000,
+          "stop" => "",
+          "tfs_z" => 0.0,
+          "top_k" => 0,
+          "top_p" => 0.0
+        })
+
+      %{ollama_ai: ollama_ai}
+    end
+
+    test "generates a map for an API call with no messages", %{ollama_ai: ollama_ai} do
+      data = ChatOllamaAI.for_api(ollama_ai, [], [])
+      assert data.model == "llama2:latest"
+      assert data.temperature == 0.4
+      assert data.stream == false
+      assert data.messages == []
+      assert data.seed == 0
+      assert data.num_ctx == 2048
+      assert data.num_predict == 128
+      assert data.repeat_last_n == 64
+      assert data.repeat_penalty == 1.1
+      assert data.mirostat == 0
+      assert data.mirostat_eta == 0.1
+      assert data.mirostat_tau == 5.0
+      assert data.num_gqa == 8
+      assert data.num_gpu == 1
+      assert data.num_thread == 0
+      assert data.receive_timeout == 300_000
+      # TODO: figure out why this field is being cast to nil instead of an empty string
+      assert data.stop == nil
+      assert data.tfs_z == 0.0
+      assert data.top_k == 0
+      assert data.top_p == 0.0
+    end
+
+    test "generates a map for an API call with a single message", %{ollama_ai: ollama_ai} do
+      user_message = "What color is the sky?"
+
+      data = ChatOllamaAI.for_api(ollama_ai, [Message.new_user!(user_message)], [])
+      assert data.model == "llama2:latest"
+      assert data.temperature == 0.4
+
+      assert [%{"content" => "What color is the sky?", "role" => :user}] = data.messages
+    end
+
+    test "generates a map for an API call with user and system messages", %{ollama_ai: ollama_ai} do
+      user_message = "What color is the sky?"
+      system_message = "You are a weather man"
+
+      data =
+        ChatOllamaAI.for_api(
+          ollama_ai,
+          [Message.new_system!(system_message), Message.new_user!(user_message)],
+          []
+        )
+
+      assert data.model == "llama2:latest"
+      assert data.temperature == 0.4
+
+      assert [
+               %{"role" => :system} = system_msg,
+               %{"role" => :user} = user_msg
+             ] = data.messages
+
+      assert system_msg["content"] == "You are a weather man"
+      assert user_msg["content"] == "What color is the sky?"
+    end
+  end
+
+  describe "call/2" do
+    @tag :live_call_ollama_ai
+    test "basic content example with no streaming" do
+      {:ok, chat} =
+        ChatOllamaAI.new(%{
+          model: "llama2:latest",
+          temperature: 1,
+          seed: 0,
+          stream: false
+        })
+
+      {:ok, %Message{role: :assistant, content: response}} =
+        ChatOllamaAI.call(chat, [
+          Message.new_user!("Return the response 'Colorful Threads'.")
+        ])
+
+      assert response =~ "Colorful Threads"
+    end
+
+    @tag :live_call_ollama_ai
+    test "basic content example with streaming" do
+      {:ok, chat} =
+        ChatOllamaAI.new(%{
+          model: "llama2:latest",
+          temperature: 1,
+          seed: 0,
+          stream: true
+        })
+
+      result =
+        ChatOllamaAI.call(chat, [
+          Message.new_user!("Return the response 'Colorful Threads'.")
+        ])
+
+      assert {:ok, deltas} = result
+      assert length(deltas) > 0
+
+      deltas_except_last = Enum.slice(deltas, 0..-2)
+
+      for delta <- deltas_except_last do
+        assert delta.__struct__ == LangChain.MessageDelta
+        assert is_binary(delta.content)
+        assert delta.status == :incomplete
+        assert delta.role == :assistant
+      end
+
+      last_delta = Enum.at(deltas, -1)
+      assert last_delta.__struct__ == LangChain.Message
+      assert is_nil(last_delta.content)
+      assert last_delta.status == :complete
+      assert last_delta.role == :assistant
+    end
+
+    @tag :live_call_ollama_ai
+    test "returns an error when given an invalid model" do
+      invalid_model = "invalid"
+
+      {:error, reason} =
+        ChatOllamaAI.call(%ChatOllamaAI{model: invalid_model}, [
+          Message.new_user!("Return the response 'Colorful Threads'.")
+        ])
+
+      assert reason == "model '#{invalid_model}' not found, try pulling it first"
+    end
+  end
+
+  describe "do_process_response/1" do
+    test "handles receiving a non-streamed message result" do
+      response = %{
+        "model" => "llama2",
+        "created_at" => "2024-01-15T23:02:24.087444Z",
+        "message" => %{
+          "role" => "assistant",
+          "content" => "Greetings!"
+        },
+        "done" => true,
+        "total_duration" => 12_323_379_834,
+        "load_duration" => 6_889_264_834,
+        "prompt_eval_count" => 26,
+        "prompt_eval_duration" => 91_493_000,
+        "eval_count" => 362,
+        "eval_duration" => 5_336_241_000
+      }
+
+      assert %Message{} = struct = ChatOllamaAI.do_process_response(response)
+      assert struct.role == :assistant
+      assert struct.content == "Greetings!"
+      assert struct.index == nil
+    end
+
+    test "handles receiving a streamed message result" do
+      response = %{
+        "model" => "llama2",
+        "created_at" => "2024-01-15T23:02:24.087444Z",
+        "message" => %{
+          "role" => "assistant",
+          "content" => "Gre"
+        },
+        "done" => false
+      }
+
+      assert %MessageDelta{} = struct = ChatOllamaAI.do_process_response(response)
+      assert struct.role == :assistant
+      assert struct.content == "Gre"
+      assert struct.status == :incomplete
+    end
+  end
+end
diff --git a/test/test_helper.exs b/test/test_helper.exs
index 19cb0b5..de03f9d 100644
--- a/test/test_helper.exs
+++ b/test/test_helper.exs
@@ -1,5 +1,5 @@
 # Load the ENV key for running live OpenAI tests.
 Application.put_env(:langchain, :openai_key, System.fetch_env!("OPENAI_API_KEY"))
 
-ExUnit.configure(capture_log: true, exclude: [live_call: true])
+ExUnit.configure(capture_log: true, exclude: [live_call: true, live_call_ollama_ai: true])
 ExUnit.start()
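
Reviewer note (not part of the diff): a minimal usage sketch of the new module, based on the functions and tests added above. It assumes a local Ollama server running at the default endpoint with the "llama2:latest" model already pulled; since the Ollama chat API does not support functions yet, no function list or callback is passed.

    alias LangChain.ChatModels.ChatOllamaAI
    alias LangChain.Message

    # Build the chat model struct; new!/1 raises on invalid attributes.
    chat = ChatOllamaAI.new!(%{model: "llama2:latest", temperature: 0.4, stream: false})

    # Non-streaming call: returns the completed assistant message.
    {:ok, %Message{role: :assistant, content: content}} =
      ChatOllamaAI.call(chat, [
        Message.new_system!("You are a weather man"),
        Message.new_user!("What color is the sky?")
      ])

    IO.puts(content)

For most applications the module would be driven through `LangChain.Chains.LLMChain` rather than called directly, as the `call/4` docs note.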