-
Notifications
You must be signed in to change notification settings - Fork 67
/
chat_model.ex
63 lines (51 loc) · 1.92 KB
/
chat_model.ex
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
defmodule LangChain.ChatModels.ChatModel do
  @moduledoc """
  Behaviour implemented by chat model modules, plus shared helpers that
  operate on any chat model struct.

  Implementing modules provide `c:call/3`, `c:serialize_config/1` and
  `c:restore_from_map/1`. The helpers in this module dispatch to those
  callbacks using either the struct's own module (for serialization) or the
  `"module"` entry of a serialized map (for restoration).
  """
  require Logger
  alias LangChain.Message
  alias LangChain.MessageDelta
  alias LangChain.Function
  alias LangChain.Utils

  @typedoc "Result of calling a chat model: message(s), a list of deltas, or an error."
  @type call_response ::
          {:ok, Message.t() | [Message.t()] | [MessageDelta.t()]} | {:error, String.t()}

  @type tool :: Function.t()
  @type tools :: [tool()]
  @type t :: Ecto.Schema.t()

  @doc """
  Calls the chat model with a prompt (a string or a list of messages) and a
  list of tools, returning a `t:call_response/0`.
  """
  @callback call(t(), String.t() | [Message.t()], tools()) :: call_response()

  @doc "Returns a string-keyed map describing the model's configuration."
  @callback serialize_config(t()) :: %{String.t() => any()}

  @doc "Rebuilds a model struct from a serialized config map."
  @callback restore_from_map(%{String.t() => any()}) :: {:ok, struct()} | {:error, String.t()}

  @doc """
  Appends a `LangChain.ChatModels.LLMCallbacks` callback map to the ChatModel's
  `:callbacks` list.

  Only structs that have a `:callbacks` field are updated, and a `nil` value is
  treated as an empty list. Any other value is returned unchanged.
  """
  @spec add_callback(%{optional(:callbacks) => nil | [map()]}, map()) :: map() | struct()
  def add_callback(%_{callbacks: callbacks} = model, callback_map) do
    # Append rather than prepend so callbacks stay in the order they were added.
    %{model | callbacks: (callbacks || []) ++ [callback_map]}
  end

  def add_callback(model, _callback_map), do: model

  @doc """
  Creates a serializable map from a ChatModel's current configuration. The map
  can later be passed to `restore_from_map/1`.

  Takes the module from the struct and calls that module's
  `c:serialize_config/1` callback.
  """
  @spec serialize_config(struct()) :: %{String.t() => any()}
  def serialize_config(%chat_module{} = model) do
    chat_module.serialize_config(model)
  end

  @doc """
  Restores a ChatModel from a serialized config map.

  The map's `"module"` entry names the chat model module. The whole map is then
  passed to that module's `c:restore_from_map/1` callback.

  Returns `{:error, reason}` in three cases: `data` is `nil`, the map has no
  `"module"` key, or the module name cannot be resolved.
  """
  @spec restore_from_map(nil | %{String.t() => any()}) :: {:ok, struct()} | {:error, String.t()}
  def restore_from_map(nil), do: {:error, "No data to restore"}

  def restore_from_map(%{"module" => module_name} = data) do
    # NOTE(review): if `data` can come from untrusted storage, confirm that
    # `Utils.module_from_name/1` only resolves existing modules (no atom creation).
    with {:ok, module} <- Utils.module_from_name(module_name) do
      module.restore_from_map(data)
    end
  end

  # Previously a map without "module" raised FunctionClauseError; the spec
  # promises an error tuple instead.
  def restore_from_map(data) when is_map(data),
    do: {:error, "Missing \"module\" key in data to restore"}
end