diff --git a/lib/chat_models/chat_vertex_ai.ex b/lib/chat_models/chat_vertex_ai.ex
new file mode 100644
index 0000000..28799cf
--- /dev/null
+++ b/lib/chat_models/chat_vertex_ai.ex
@@ -0,0 +1,586 @@
+defmodule LangChain.ChatModels.ChatVertexAI do
+  @moduledoc """
+  Parses and validates inputs for making a request for the Google Vertex AI chat API.
+
+  Converts the response into more specialized `LangChain` data structures.
+  """
+  use Ecto.Schema
+  require Logger
+  import Ecto.Changeset
+  alias __MODULE__
+  alias LangChain.Config
+  alias LangChain.ChatModels.ChatModel
+  alias LangChain.ChatModels.ChatOpenAI
+  alias LangChain.Message
+  alias LangChain.MessageDelta
+  alias LangChain.Message.ContentPart
+  alias LangChain.Message.ToolCall
+  alias LangChain.Message.ToolResult
+  alias LangChain.Function
+  alias LangChain.LangChainError
+  alias LangChain.Utils
+
+  @behaviour ChatModel
+
+  # allow up to 1 minute for the response.
+  @receive_timeout 60_000
+
+  @primary_key false
+  embedded_schema do
+    field :endpoint, :string
+
+    field :model, :string, default: "gemini-pro"
+    field :api_key, :string
+
+    # What sampling temperature to use, between 0 and 2. Higher values like 0.8
+    # will make the output more random, while lower values like 0.2 will make it
+    # more focused and deterministic.
+    field :temperature, :float, default: 0.9
+
+    # The topP parameter changes how the model selects tokens for output. Tokens
+    # are selected from the most to least probable until the sum of their
+    # probabilities equals the topP value. For example, if tokens A, B, and C have
+    # probabilities of 0.3, 0.2, and 0.1 and the topP value is 0.5, then the model
+    # will select either A or B as the next token by using the temperature and
+    # exclude C as a candidate. Google documents the default topP as 0.95; this
+    # module defaults to 1.0.
+    field :top_p, :float, default: 1.0
+
+    # The topK parameter changes how the model selects tokens for output. A topK of
+    # 1 means the selected token is the most probable among all the tokens in the
+    # model's vocabulary (also called greedy decoding), while a topK of 3 means that
+    # the next token is selected from among the 3 most probable using the temperature.
+    # For each token selection step, the topK tokens with the highest probabilities
+    # are sampled. Tokens are then further filtered based on topP with the final token
+    # selected using temperature sampling.
+    field :top_k, :float, default: 1.0
+
+    # Duration in milliseconds for the response to be received. When streaming a
+    # very lengthy response, a longer time limit may be required. However, when it
+    # goes on too long by itself, it tends to hallucinate more.
+    field :receive_timeout, :integer, default: @receive_timeout
+
+    field :stream, :boolean, default: false
+    field :json_response, :boolean, default: false
+  end
+
+  @type t :: %ChatVertexAI{}
+
+  @create_fields [
+    :endpoint,
+    :model,
+    :api_key,
+    :temperature,
+    :top_p,
+    :top_k,
+    :receive_timeout,
+    :stream,
+    :json_response
+  ]
+  @required_fields [
+    :endpoint,
+    :model
+  ]
+
+  @spec get_api_key(t) :: String.t()
+  defp get_api_key(%ChatVertexAI{api_key: api_key}) do
+    # if no API key is set, default to `""` which will raise an API error
+    api_key || Config.resolve(:vertex_ai_key, "")
+  end
+
+  @doc """
+  Setup a ChatVertexAI client configuration.
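+
+  A minimal setup sketch (the endpoint shown is illustrative; substitute the
+  base URL for your own Vertex AI project and region):
+
+      {:ok, chat} =
+        ChatVertexAI.new(%{
+          "endpoint" => "https://<region>-aiplatform.googleapis.com/v1/projects/<project>/locations/<region>/publishers/google",
+          "model" => "gemini-pro"
+        })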
+ """ + @spec new(attrs :: map()) :: {:ok, t} | {:error, Ecto.Changeset.t()} + def new(%{} = attrs \\ %{}) do + %ChatVertexAI{} + |> cast(attrs, @create_fields) + |> common_validation() + |> apply_action(:insert) + end + + @doc """ + Setup a ChatVertexAI client configuration and return it or raise an error if invalid. + """ + @spec new!(attrs :: map()) :: t() | no_return() + def new!(attrs \\ %{}) do + case new(attrs) do + {:ok, chain} -> + chain + + {:error, changeset} -> + raise LangChainError, changeset + end + end + + defp common_validation(changeset) do + changeset + |> validate_required(@required_fields) + end + + def for_api(%ChatVertexAI{} = vertex_ai, messages, functions) do + messages_for_api = + messages + |> Enum.map(&for_api/1) + |> List.flatten() + |> List.wrap() + + req = %{ + "contents" => messages_for_api, + "generationConfig" => %{ + "temperature" => vertex_ai.temperature, + "topP" => vertex_ai.top_p, + "topK" => vertex_ai.top_k + } + } + + req = + if vertex_ai.json_response do + req + |> put_in(["generationConfig", "response_mime_type"], "application/json") + else + req + end + + if functions && not Enum.empty?(functions) do + req + |> Map.put("tools", [ + %{ + # Google AI functions use an OpenAI compatible format. + # See: https://ai.google.dev/docs/function_calling#how_it_works + "functionDeclarations" => Enum.map(functions, &ChatOpenAI.for_api/1) + } + ]) + else + req + end + end + + defp for_api(%Message{role: :assistant} = message) do + content_parts = get_message_contents(message) || [] + tool_calls = Enum.map(message.tool_calls || [], &for_api/1) + + %{ + "role" => map_role(:assistant), + "parts" => content_parts ++ tool_calls + } + end + + defp for_api(%Message{role: :tool} = message) do + %{ + "role" => map_role(:tool), + "parts" => Enum.map(message.tool_results, &for_api/1) + } + end + + defp for_api(%Message{role: :system} = message) do + # No system messages support means we need to fake a prompt and response + # to pretend like it worked. + [ + %{ + "role" => :user, + "parts" => [%{"text" => message.content}] + }, + %{ + "role" => :model, + "parts" => [%{"text" => ""}] + } + ] + end + + defp for_api(%Message{role: :user, content: content}) when is_list(content) do + %{ + "role" => "user", + "parts" => Enum.map(content, &for_api(&1)) + } + end + + defp for_api(%Message{} = message) do + %{ + "role" => map_role(message.role), + "parts" => [%{"text" => message.content}] + } + end + + defp for_api(%ContentPart{type: :text} = part) do + %{"text" => part.content} + end + + defp for_api(%ContentPart{type: :image} = part) do + %{ + "inlineData" => %{ + "mimeType" => Keyword.fetch!(part.options, :media), + "data" => part.content + } + } + end + + defp for_api(%ContentPart{type: :image_url} = part) do + %{ + "fileData" => %{ + "mimeType" => Keyword.fetch!(part.options, :media), + "data" => part.content + } + } + end + + defp for_api(%ToolCall{} = call) do + %{ + "functionCall" => %{ + "args" => call.arguments, + "name" => call.name + } + } + end + + defp for_api(%ToolResult{} = result) do + %{ + "functionResponse" => %{ + "name" => result.name, + "response" => Jason.decode!(result.content) + } + } + end + + @doc """ + Calls the Google AI API passing the ChatVertexAI struct with configuration, plus + either a simple message or the list of messages to act as the prompt. + + Optionally pass in a list of tools available to the LLM for requesting + execution in response. + + Optionally pass in a callback function that can be executed as data is + received from the API. 
+
+  defp for_api(%Message{role: :assistant} = message) do
+    content_parts = get_message_contents(message) || []
+    tool_calls = Enum.map(message.tool_calls || [], &for_api/1)
+
+    %{
+      "role" => map_role(:assistant),
+      "parts" => content_parts ++ tool_calls
+    }
+  end
+
+  defp for_api(%Message{role: :tool} = message) do
+    %{
+      "role" => map_role(:tool),
+      "parts" => Enum.map(message.tool_results, &for_api/1)
+    }
+  end
+
+  defp for_api(%Message{role: :system} = message) do
+    # No system message support means we need to fake a prompt and response
+    # to pretend like it worked.
+    [
+      %{
+        "role" => :user,
+        "parts" => [%{"text" => message.content}]
+      },
+      %{
+        "role" => :model,
+        "parts" => [%{"text" => ""}]
+      }
+    ]
+  end
+
+  defp for_api(%Message{role: :user, content: content}) when is_list(content) do
+    %{
+      "role" => :user,
+      "parts" => Enum.map(content, &for_api(&1))
+    }
+  end
+
+  defp for_api(%Message{} = message) do
+    %{
+      "role" => map_role(message.role),
+      "parts" => [%{"text" => message.content}]
+    }
+  end
+
+  defp for_api(%ContentPart{type: :text} = part) do
+    %{"text" => part.content}
+  end
+
+  defp for_api(%ContentPart{type: :image} = part) do
+    %{
+      "inlineData" => %{
+        "mimeType" => Keyword.fetch!(part.options, :media),
+        "data" => part.content
+      }
+    }
+  end
+
+  defp for_api(%ContentPart{type: :image_url} = part) do
+    %{
+      "fileData" => %{
+        "mimeType" => Keyword.fetch!(part.options, :media),
+        "data" => part.content
+      }
+    }
+  end
+
+  defp for_api(%ToolCall{} = call) do
+    %{
+      "functionCall" => %{
+        "args" => call.arguments,
+        "name" => call.name
+      }
+    }
+  end
+
+  defp for_api(%ToolResult{} = result) do
+    %{
+      "functionResponse" => %{
+        "name" => result.name,
+        "response" => Jason.decode!(result.content)
+      }
+    }
+  end
+
+  @doc """
+  Calls the Google Vertex AI API passing the ChatVertexAI struct with
+  configuration, plus either a simple message or the list of messages to act
+  as the prompt.
+
+  Optionally pass in a list of tools available to the LLM for requesting
+  execution in response.
+
+  Optionally pass in a callback function that can be executed as data is
+  received from the API.
+
+  **NOTE:** This function *can* be used directly, but the primary interface
+  should be through `LangChain.Chains.LLMChain`. The `ChatVertexAI` module is
+  more focused on translating the `LangChain` data structures to and from the
+  Vertex AI API.
+
+  Another benefit of using `LangChain.Chains.LLMChain` is that it combines the
+  storage of messages, adding tools, adding custom context that should be
+  passed to tools, and automatically applying `LangChain.MessageDelta`
+  structs as they are received, then converting those to the full
+  `LangChain.Message` once fully complete.
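+
+  A sketch of a direct call, assuming `endpoint` holds your Vertex AI base URL
+  and an API key is configured (the reply content is illustrative):
+
+      chat =
+        ChatVertexAI.new!(%{"endpoint" => endpoint, "model" => "gemini-pro"})
+
+      {:ok, [%Message{role: :assistant} = message]} =
+        ChatVertexAI.call(chat, "Hello!")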
#{inspect(other)}" + ) + + {:error, "Unexpected response"} + end + end + + @spec build_url(t()) :: String.t() + defp build_url(%ChatVertexAI{endpoint: endpoint, model: model} = vertex_ai) do + "#{endpoint}/models/#{model}:#{get_action(vertex_ai)}?key=#{get_api_key(vertex_ai)}" + |> use_sse(vertex_ai) + end + + @spec use_sse(String.t(), t()) :: String.t() + defp use_sse(url, %ChatVertexAI{stream: true}), do: url <> "&alt=sse" + defp use_sse(url, _model), do: url + + @spec get_action(t()) :: String.t() + defp get_action(%ChatVertexAI{stream: false}), do: "generateContent" + defp get_action(%ChatVertexAI{stream: true}), do: "streamGenerateContent" + + def complete_final_delta(data) when is_list(data) do + update_in(data, [Access.at(-1), Access.at(-1)], &%{&1 | status: :complete}) + end + + def do_process_response(response, message_type \\ Message) + + def do_process_response(%{"candidates" => candidates}, message_type) when is_list(candidates) do + candidates + |> Enum.map(&do_process_response(&1, message_type)) + end + + def do_process_response(%{"content" => %{"parts" => parts} = content_data} = data, Message) do + text_part = + parts + |> filter_parts_for_types(["text"]) + |> Enum.map(fn part -> + ContentPart.new!(%{type: :text, content: part["text"]}) + end) + + tool_calls_from_parts = + parts + |> filter_parts_for_types(["functionCall"]) + |> Enum.map(fn part -> + do_process_response(part, nil) + end) + + tool_result_from_parts = + parts + |> filter_parts_for_types(["functionResponse"]) + |> Enum.map(fn part -> + do_process_response(part, nil) + end) + + %{ + role: unmap_role(content_data["role"]), + content: text_part, + complete: false, + index: data["index"] + } + |> Utils.conditionally_add_to_map(:tool_calls, tool_calls_from_parts) + |> Utils.conditionally_add_to_map(:tool_results, tool_result_from_parts) + |> Message.new() + |> case do + {:ok, message} -> + message + + {:error, changeset} -> + {:error, Utils.changeset_error_to_string(changeset)} + end + end + + def do_process_response(%{"content" => %{"parts" => parts} = content_data} = data, MessageDelta) do + text_content = + case parts do + [%{"text" => text}] -> + text + + _other -> + nil + end + + parts + |> filter_parts_for_types(["text"]) + |> Enum.map(fn part -> + ContentPart.new!(%{type: :text, content: part["text"]}) + end) + + tool_calls_from_parts = + parts + |> filter_parts_for_types(["functionCall"]) + |> Enum.map(fn part -> + do_process_response(part, nil) + end) + + %{ + role: unmap_role(content_data["role"]), + content: text_content, + complete: true, + index: data["index"] + } + |> Utils.conditionally_add_to_map(:tool_calls, tool_calls_from_parts) + |> MessageDelta.new() + |> case do + {:ok, message} -> + message + + {:error, changeset} -> + {:error, Utils.changeset_error_to_string(changeset)} + end + end + + def do_process_response(%{"functionCall" => %{"args" => raw_args, "name" => name}} = data, _) do + %{ + call_id: "call-#{name}", + name: name, + arguments: raw_args, + complete: true, + index: data["index"] + } + |> ToolCall.new() + |> case do + {:ok, message} -> + message + + {:error, changeset} -> + {:error, Utils.changeset_error_to_string(changeset)} + end + end + + def do_process_response( + %{ + "finishReason" => finish, + "content" => %{"parts" => parts, "role" => role}, + "index" => index + }, + message_type + ) + when is_list(parts) do + status = + case message_type do + MessageDelta -> + :incomplete + + Message -> + case finish do + "STOP" -> + :complete + + "SAFETY" -> + :complete + + other -> + 
Logger.warning("Unsupported finishReason in response. Reason: #{inspect(other)}") + nil + end + end + + content = Enum.map_join(parts, & &1["text"]) + + case message_type.new(%{ + "content" => content, + "role" => unmap_role(role), + "status" => status, + "index" => index + }) do + {:ok, message} -> + message + + {:error, changeset} -> + {:error, Utils.changeset_error_to_string(changeset)} + end + end + + def do_process_response(%{"error" => %{"message" => reason}}, _) do + Logger.error("Received error from API: #{inspect(reason)}") + {:error, reason} + end + + def do_process_response({:error, %Jason.DecodeError{} = response}, _) do + error_message = "Received invalid JSON: #{inspect(response)}" + Logger.error(error_message) + {:error, error_message} + end + + def do_process_response(other, _) do + Logger.error("Trying to process an unexpected response. #{inspect(other)}") + {:error, "Unexpected response"} + end + + @doc false + def filter_parts_for_types(parts, types) when is_list(parts) and is_list(types) do + Enum.filter(parts, fn p -> + Enum.any?(types, &Map.has_key?(p, &1)) + end) + end + + @doc """ + Return the content parts for the message. + """ + @spec get_message_contents(MessageDelta.t() | Message.t()) :: [%{String.t() => any()}] + def get_message_contents(%{content: content} = _message) when is_binary(content) do + [%{"text" => content}] + end + + def get_message_contents(%{content: contents} = _message) when is_list(contents) do + Enum.map(contents, &for_api/1) + end + + def get_message_contents(%{content: nil} = _message) do + nil + end + + defp map_role(role) do + case role do + :assistant -> :model + :tool -> :function + # System prompts are not supported yet. Google recommends using user prompt. + :system -> :user + role -> role + end + end + + defp unmap_role("model"), do: "assistant" + defp unmap_role("function"), do: "tool" + defp unmap_role(role), do: role +end diff --git a/test/chat_models/chat_vertex_ai_test.exs b/test/chat_models/chat_vertex_ai_test.exs new file mode 100644 index 0000000..e5a5acb --- /dev/null +++ b/test/chat_models/chat_vertex_ai_test.exs @@ -0,0 +1,324 @@ +defmodule ChatModels.ChatVertexAITest do + alias LangChain.ChatModels.ChatVertexAI + use LangChain.BaseCase + + doctest LangChain.ChatModels.ChatVertexAI + alias LangChain.ChatModels.ChatVertexAI + alias LangChain.Message + alias LangChain.Message.ContentPart + alias LangChain.Message.ToolCall + alias LangChain.Message.ToolResult + alias LangChain.MessageDelta + alias LangChain.Function + + setup do + {:ok, hello_world} = + Function.new(%{ + name: "hello_world", + description: "Give a hello world greeting.", + function: fn _args, _context -> {:ok, "Hello world!"} end + }) + + %{hello_world: hello_world} + end + + describe "new/1" do + test "works with minimal attr" do + assert {:ok, %ChatVertexAI{} = vertex_ai} = + ChatVertexAI.new(%{"model" => "gemini-pro", "endpoint" => "http://localhost:1234/"}) + + assert vertex_ai.model == "gemini-pro" + end + + test "returns error when invalid" do + assert {:error, changeset} = ChatVertexAI.new(%{"model" => nil}) + refute changeset.valid? 
+ assert {"can't be blank", _} = changeset.errors[:model] + end + end + + describe "for_api/3" do + setup do + {:ok, vertex_ai} = + ChatVertexAI.new(%{ + "model" => "gemini-pro", + "endpoint" => "http://localhost:1234/", + "temperature" => 1.0, + "top_p" => 1.0, + "top_k" => 1.0 + }) + + %{vertex_ai: vertex_ai} + end + + test "generates a map for an API call", %{vertex_ai: vertex_ai} do + data = ChatVertexAI.for_api(vertex_ai, [], []) + assert %{"contents" => [], "generationConfig" => config} = data + assert %{"temperature" => 1.0, "topK" => 1.0, "topP" => 1.0} = config + end + + test "generates a map containing user and assistant messages", %{vertex_ai: vertex_ai} do + user_message = "Hello Assistant!" + assistant_message = "Hello User!" + + data = + ChatVertexAI.for_api( + vertex_ai, + [ + Message.new_user!(user_message), + Message.new_assistant!(assistant_message) + ], + [] + ) + + assert %{"contents" => [msg1, msg2]} = data + assert %{"role" => :user, "parts" => [%{"text" => ^user_message}]} = msg1 + assert %{"role" => :model, "parts" => [%{"text" => ^assistant_message}]} = msg2 + end + + test "generates a map containing function and function call messages", %{vertex_ai: vertex_ai} do + message = "Can you do an action for me?" + arguments = %{"args" => "data"} + function_result = %{"result" => "data"} + + data = + ChatVertexAI.for_api( + vertex_ai, + [ + Message.new_user!(message), + Message.new_assistant!(%{ + tool_calls: [ + ToolCall.new!(%{ + call_id: "call_123", + name: "userland_action", + arguments: Jason.encode!(arguments) + }) + ] + }), + Message.new_tool_result!(%{ + tool_results: [ + ToolResult.new!(%{ + tool_call_id: "call_123", + name: "userland_action", + content: Jason.encode!(function_result) + }) + ] + }) + ], + [] + ) + + assert %{"contents" => [msg1, msg2, msg3]} = data + assert %{"role" => :user, "parts" => [%{"text" => ^message}]} = msg1 + assert %{"role" => :model, "parts" => [tool_call]} = msg2 + assert %{"role" => :function, "parts" => [tool_result]} = msg3 + + assert %{ + "functionCall" => %{ + "args" => ^arguments, + "name" => "userland_action" + } + } = tool_call + + assert %{ + "functionResponse" => %{ + "name" => "userland_action", + "response" => ^function_result + } + } = tool_result + end + + test "expands system messages into two", %{vertex_ai: vertex_ai} do + message = "These are some instructions." 
+  """
+  @spec get_message_contents(MessageDelta.t() | Message.t()) ::
+          [%{String.t() => any()}] | nil
+  def get_message_contents(%{content: content} = _message) when is_binary(content) do
+    [%{"text" => content}]
+  end
+
+  def get_message_contents(%{content: contents} = _message) when is_list(contents) do
+    Enum.map(contents, &for_api/1)
+  end
+
+  def get_message_contents(%{content: nil} = _message) do
+    nil
+  end
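+
+  # Maps LangChain roles to the roles the API expects: `:assistant` becomes
+  # `:model` and `:tool` becomes `:function`. `unmap_role/1` reverses the
+  # mapping for responses.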
=> %{"args" => "data"}, + "name" => "userland_action" + } + } + ] + + assert [%{"text" => _}] = ChatVertexAI.filter_parts_for_types(parts, ["text"]) + + assert [%{"functionCall" => _}] = + ChatVertexAI.filter_parts_for_types(parts, ["functionCall"]) + end + + test "returns a set of types" do + parts = [ + %{"text" => "I think I'll call this function."}, + %{ + "functionCall" => %{ + "args" => %{"args" => "data"}, + "name" => "userland_action" + } + } + ] + + assert parts == ChatVertexAI.filter_parts_for_types(parts, ["text", "functionCall"]) + end + end + + describe "get_message_contents/1" do + test "returns basic text as a ContentPart" do + message = Message.new_user!("Howdy!") + + result = ChatVertexAI.get_message_contents(message) + + assert result == [%{"text" => "Howdy!"}] + end + + test "supports a list of ContentParts" do + message = + Message.new_user!([ + ContentPart.new!(%{type: :text, content: "Hello!"}), + ContentPart.new!(%{type: :text, content: "What's up?"}) + ]) + + result = ChatVertexAI.get_message_contents(message) + + assert result == [ + %{"text" => "Hello!"}, + %{"text" => "What's up?"} + ] + end + end +end