Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[.NET] Add tools for Ollama #3373

Draft
wants to merge 3 commits into
base: 0.2
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 113 additions & 0 deletions dotnet/sample/AutoGen.Ollama.Sample/Create_Ollama_Agent_With_Tool.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Create_Ollama_Agent_With_Tool.cs

using AutoGen.Core;
using AutoGen.Ollama.Extension;
using FluentAssertions;

namespace AutoGen.Ollama.Sample;

#region WeatherFunction
public partial class WeatherFunction
{
/// <summary>
/// Gets the weather based on the location and the unit
/// </summary>
/// <param name="location"></param>
/// <param name="unit"></param>
/// <returns></returns>
[Function]
public async Task<string> GetWeather(string location, string unit)
{
// dummy implementation
return $"The weather in {location} is currently sunny with a tempature of {unit} (s)";
}
}
#endregion

public class Create_Ollama_Agent_With_Tool
{
public static async Task RunAsync()
{
#region define_tool
var tool = new Tool()
{
Function = new Function
{
Name = "get_current_weather",
Description = "Get the current weather for a location",
Parameters = new Parameters
{
Properties = new Dictionary<string, Properties>
{
{
"location",
new Properties
{
Type = "string", Description = "The location to get the weather for, e.g. San Francisco, CA"
}
},
{
"format", new Properties
{
Type = "string",
Description =
"The format to return the weather in, e.g. 'celsius' or 'fahrenheit'",
Enum = new List<string> {"celsius", "fahrenheit"}
}
}
},
Required = new List<string> { "location", "format" }
}
}
};

var weatherFunction = new WeatherFunction();
var functionMiddleware = new FunctionCallMiddleware(
functions: [
weatherFunction.GetWeatherFunctionContract,
],
functionMap: new Dictionary<string, Func<string, Task<string>>>
{
{ weatherFunction.GetWeatherFunctionContract.Name!, weatherFunction.GetWeatherWrapper },
});

#endregion

#region create_ollama_agent_llama3.1

var agent = new OllamaAgent(
new HttpClient { BaseAddress = new Uri("http://localhost:11434") },
"MyAgent",
"llama3.1",
tools: [tool]);
#endregion

// TODO cannot stream
#region register_middleware
var agentWithConnector = agent
.RegisterMessageConnector()
.RegisterPrintMessage()
.RegisterStreamingMiddleware(functionMiddleware);
#endregion register_middleware

#region single_turn
var question = new TextMessage(Role.Assistant,
"What is the weather like in San Francisco?",
from: "user");
var functionCallReply = await agentWithConnector.SendAsync(question);
#endregion

#region Single_turn_verify_reply
functionCallReply.Should().BeOfType<ToolCallAggregateMessage>();
#endregion Single_turn_verify_reply

#region Multi_turn
var finalReply = await agentWithConnector.SendAsync(chatHistory: [question, functionCallReply]);
#endregion Multi_turn

#region Multi_turn_verify_reply
finalReply.Should().BeOfType<TextMessage>();
#endregion Multi_turn_verify_reply
}
}
2 changes: 1 addition & 1 deletion dotnet/sample/AutoGen.Ollama.Sample/Program.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@

using AutoGen.Ollama.Sample;

await Chat_With_LLaVA.RunAsync();
await Create_Ollama_Agent_With_Tool.RunAsync();
12 changes: 10 additions & 2 deletions dotnet/src/AutoGen.Ollama/Agent/OllamaAgent.cs
Original file line number Diff line number Diff line change
Expand Up @@ -24,16 +24,23 @@ public class OllamaAgent : IStreamingAgent
private readonly string _modelName;
private readonly string _systemMessage;
private readonly OllamaReplyOptions? _replyOptions;
private readonly Tool[]? _tools;

public OllamaAgent(HttpClient httpClient, string name, string modelName,
string systemMessage = "You are a helpful AI assistant",
OllamaReplyOptions? replyOptions = null)
OllamaReplyOptions? replyOptions = null, Tool[]? tools = null)
{
Name = name;
_httpClient = httpClient;
_modelName = modelName;
_systemMessage = systemMessage;
_replyOptions = replyOptions;
_tools = tools;

if (_httpClient.BaseAddress == null)
{
throw new InvalidOperationException($"Please add the base address to httpClient");
}
}

public async Task<IMessage> GenerateReplyAsync(
Expand Down Expand Up @@ -97,7 +104,8 @@ private async Task<ChatRequest> BuildChatRequest(IEnumerable<IMessage> messages,
var request = new ChatRequest
{
Model = _modelName,
Messages = await BuildChatHistory(messages)
Messages = await BuildChatHistory(messages),
Tools = _tools
};

if (options is OllamaReplyOptions replyOptions)
Expand Down
7 changes: 7 additions & 0 deletions dotnet/src/AutoGen.Ollama/DTOs/ChatRequest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -50,4 +50,11 @@ public class ChatRequest
[JsonPropertyName("keep_alive")]
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
public string? KeepAlive { get; set; }

/// <summary>
/// Tools for the model to use. Not all models currently support tools.
/// Requires stream to be set to false
/// </summary>
[JsonPropertyName("tools")]
public IEnumerable<Tool>? Tools { get; set; }
}
27 changes: 25 additions & 2 deletions dotnet/src/AutoGen.Ollama/DTOs/Message.cs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ public Message()
{
}

public Message(string role, string value)
public Message(string role, string? value = null)
{
Role = role;
Value = value;
Expand All @@ -27,11 +27,34 @@ public Message(string role, string value)
/// the content of the message
/// </summary>
[JsonPropertyName("content")]
public string Value { get; set; } = string.Empty;
public string? Value { get; set; }

/// <summary>
/// (optional): a list of images to include in the message (for multimodal models such as llava)
/// </summary>
[JsonPropertyName("images")]
public IList<string>? Images { get; set; }

/// <summary>
/// A list of tools the model wants to use. Not all models currently support tools.
/// Tool call is not supported while streaming.
/// </summary>
[JsonPropertyName("tool_calls")]
public IEnumerable<ToolCall>? ToolCalls { get; set; }

public class ToolCall
{
[JsonPropertyName("function")]
public Function? Function { get; set; }
}

public class Function
{
[JsonPropertyName("name")]
public string? Name { get; set; }

[JsonPropertyName("arguments")]
public Dictionary<string, string>? Arguments { get; set; }
}
}

52 changes: 52 additions & 0 deletions dotnet/src/AutoGen.Ollama/DTOs/Tools.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Tools.cs

using System.Collections.Generic;
using System.Text.Json.Serialization;

namespace AutoGen.Ollama;

public class Tool
{
[JsonPropertyName("type")]
public string? Type { get; set; } = "function";

[JsonPropertyName("function")]
public Function? Function { get; set; }
}

public class Function
{
[JsonPropertyName("name")]
public string? Name { get; set; }

[JsonPropertyName("description")]
public string? Description { get; set; }

[JsonPropertyName("parameters")]
public Parameters? Parameters { get; set; }
}

public class Parameters
{
[JsonPropertyName("type")]
public string? Type { get; set; } = "object";

[JsonPropertyName("properties")]
public Dictionary<string, Properties>? Properties { get; set; }

[JsonPropertyName("required")]
public IEnumerable<string>? Required { get; set; }
}

public class Properties
{
[JsonPropertyName("type")]
public string? Type { get; set; }

[JsonPropertyName("description")]
public string? Description { get; set; }

[JsonPropertyName("enum")]
public IEnumerable<string>? Enum { get; set; }
}
83 changes: 73 additions & 10 deletions dotnet/src/AutoGen.Ollama/Middlewares/OllamaMessageConnector.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
using System.Linq;
using System.Net.Http;
using System.Runtime.CompilerServices;
using System.Text.Json;
using System.Threading;
using System.Threading.Tasks;
using AutoGen.Core;
Expand All @@ -24,6 +25,7 @@ public async Task<IMessage> InvokeAsync(MiddlewareContext context, IAgent agent,

return reply switch
{
IMessage<ChatResponse> { Content.Message.ToolCalls: not null } messageEnvelope when messageEnvelope.Content.Message.ToolCalls.Any() => ProcessToolCalls(messageEnvelope, agent),
IMessage<ChatResponse> messageEnvelope when messageEnvelope.Content.Message?.Value is string content => new TextMessage(Role.Assistant, content, messageEnvelope.From),
IMessage<ChatResponse> messageEnvelope when messageEnvelope.Content.Message?.Value is null => throw new InvalidOperationException("Message content is null"),
_ => reply
Expand Down Expand Up @@ -73,20 +75,21 @@ private IEnumerable<IMessage> ProcessMessage(IEnumerable<IMessage> messages, IAg
{
return messages.SelectMany(m =>
{
if (m is IMessage<Message> messageEnvelope)
if (m is IMessage<Message>)
{
return [m];
}
else

return m switch
{
return m switch
{
TextMessage textMessage => ProcessTextMessage(textMessage, agent),
ImageMessage imageMessage => ProcessImageMessage(imageMessage, agent),
MultiModalMessage multiModalMessage => ProcessMultiModalMessage(multiModalMessage, agent),
_ => [m],
};
}
TextMessage textMessage => ProcessTextMessage(textMessage, agent),
ImageMessage imageMessage => ProcessImageMessage(imageMessage, agent),
ToolCallMessage toolCallMessage => ProcessToolCallMessage(toolCallMessage, agent),
ToolCallResultMessage toolCallResultMessage => ProcessToolCallResultMessage(toolCallResultMessage),
AggregateMessage<ToolCallMessage, ToolCallResultMessage> toolCallAggregateMessage => ProcessToolCallAggregateMessage(toolCallAggregateMessage, agent),
MultiModalMessage multiModalMessage => ProcessMultiModalMessage(multiModalMessage, agent),
_ => [m],
};
});
}

Expand Down Expand Up @@ -183,4 +186,64 @@ private IEnumerable<IMessage> ProcessTextMessage(TextMessage textMessage, IAgent
return [MessageEnvelope.Create(message, agent.Name)];
}
}

private IMessage ProcessToolCalls(IMessage<ChatResponse> messageEnvelope, IAgent agent)
{
var toolCalls = new List<ToolCall>();
foreach (var messageToolCall in messageEnvelope.Content.Message?.ToolCalls!)
{
toolCalls.Add(new ToolCall(
messageToolCall.Function?.Name ?? string.Empty,
JsonSerializer.Serialize(messageToolCall.Function?.Arguments)));
}

return new ToolCallMessage(toolCalls, agent.Name) { Content = messageEnvelope.Content.Message.Value };
}

private IEnumerable<IMessage> ProcessToolCallMessage(ToolCallMessage toolCallMessage, IAgent agent)
{
var chatMessage = new Message(toolCallMessage.From ?? string.Empty, toolCallMessage.GetContent())
{
ToolCalls = toolCallMessage.ToolCalls.Select(t => new Message.ToolCall
{
Function = new Message.Function
{
Name = t.FunctionName,
Arguments = JsonSerializer.Deserialize<Dictionary<string, string>>(t.FunctionArguments),
},
}),
};

return [MessageEnvelope.Create(chatMessage, toolCallMessage.From)];
}

private IEnumerable<IMessage> ProcessToolCallResultMessage(ToolCallResultMessage toolCallResultMessage)
{
foreach (var toolCall in toolCallResultMessage.ToolCalls)
{
if (!string.IsNullOrEmpty(toolCall.Result))
{
return [MessageEnvelope.Create(new Message("tool", toolCall.Result), toolCallResultMessage.From)];
}
}

throw new InvalidOperationException("Expected to have at least one tool call result");
}

private IEnumerable<IMessage> ProcessToolCallAggregateMessage(AggregateMessage<ToolCallMessage, ToolCallResultMessage> toolCallAggregateMessage, IAgent agent)
{
if (toolCallAggregateMessage.From is { } from && from != agent.Name)
{
var contents = toolCallAggregateMessage.Message2.ToolCalls.Select(t => t.Result);
var messages =
contents.Select(c => new Message("assistant", c ?? throw new ArgumentNullException(nameof(c))));

return messages.Select(m => new MessageEnvelope<Message>(m, from: from));
}

var toolCallMessage = ProcessToolCallMessage(toolCallAggregateMessage.Message1, agent);
var toolCallResult = ProcessToolCallResultMessage(toolCallAggregateMessage.Message2);

return toolCallMessage.Concat(toolCallResult);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -170,16 +170,16 @@ public async Task AnthropicAgentFunctionCallMessageTest()
)
.RegisterMessageConnector();

var weatherFunctionArgumets = """
var weatherFunctionArguments = """
{
"city": "Philadelphia",
"date": "6/14/2024"
}
""";

var function = new AnthropicTestFunctionCalls();
var functionCallResult = await function.GetWeatherReportWrapper(weatherFunctionArgumets);
var toolCall = new ToolCall(function.WeatherReportFunctionContract.Name!, weatherFunctionArgumets)
var functionCallResult = await function.GetWeatherReportWrapper(weatherFunctionArguments);
var toolCall = new ToolCall(function.WeatherReportFunctionContract.Name!, weatherFunctionArguments)
{
ToolCallId = "get_weather",
Result = functionCallResult,
Expand Down
Loading
Loading