Skip to content

Commit

Permalink
Merge pull request #171 from awaescher/feature/ToolGenerator
Browse files Browse the repository at this point in the history
Improve tool usage with source generators
  • Loading branch information
awaescher authored Feb 21, 2025
2 parents 39d42d6 + c6bebbf commit 5bf9efc
Show file tree
Hide file tree
Showing 60 changed files with 591 additions and 217 deletions.
8 changes: 7 additions & 1 deletion OllamaSharp.sln
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ Microsoft Visual Studio Solution File, Format Version 12.00
# Visual Studio Version 17
VisualStudioVersion = 17.5.2.0
MinimumVisualStudioVersion = 10.0.40219.1
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "OllamaSharp", "src\OllamaSharp.csproj", "{0194817F-1B9C-4D44-8960-1A0DC92D9D22}"
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "OllamaSharp", "src\OllamaSharp\OllamaSharp.csproj", "{0194817F-1B9C-4D44-8960-1A0DC92D9D22}"
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Tests", "test\Tests.csproj", "{1527F300-40C7-49EB-A6FD-D21B20BA5BC1}"
EndProject
Expand All @@ -14,6 +14,8 @@ Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Solution Items", "Solution
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "OllamaApiConsole", "demo\OllamaApiConsole.csproj", "{755670DB-33A4-441A-99C2-642A04D08953}"
EndProject
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "OllamaSharp.SourceGenerators", "src\SourceGenerators\OllamaSharp.SourceGenerators.csproj", "{5E82D0E3-0A62-4D84-8ED0-5F928885D8FB}"
EndProject
Global
GlobalSection(SolutionConfigurationPlatforms) = preSolution
Debug|Any CPU = Debug|Any CPU
Expand All @@ -32,6 +34,10 @@ Global
{755670DB-33A4-441A-99C2-642A04D08953}.Debug|Any CPU.Build.0 = Debug|Any CPU
{755670DB-33A4-441A-99C2-642A04D08953}.Release|Any CPU.ActiveCfg = Release|Any CPU
{755670DB-33A4-441A-99C2-642A04D08953}.Release|Any CPU.Build.0 = Release|Any CPU
{5E82D0E3-0A62-4D84-8ED0-5F928885D8FB}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{5E82D0E3-0A62-4D84-8ED0-5F928885D8FB}.Debug|Any CPU.Build.0 = Debug|Any CPU
{5E82D0E3-0A62-4D84-8ED0-5F928885D8FB}.Release|Any CPU.ActiveCfg = Release|Any CPU
{5E82D0E3-0A62-4D84-8ED0-5F928885D8FB}.Release|Any CPU.Build.0 = Release|Any CPU
EndGlobalSection
GlobalSection(SolutionProperties) = preSolution
HideSolutionNode = FALSE
Expand Down
6 changes: 2 additions & 4 deletions demo/Demos/ModelManagerConsole.cs
Original file line number Diff line number Diff line change
Expand Up @@ -78,10 +78,8 @@ private async Task CopyModel()
private async Task CreateModel()
{
var createName = ReadInput("Enter a name for your new model:");
var fromModel = ReadInput("Enter the name of the model to create from:",
$"[{HintTextColor}]See [/][{AccentTextColor}][link]https://ollama.ai/library[/][/][{HintTextColor}] for available models[/]");
var systemPrompt = ReadInput("Set a new system prompt word for the model:");
await foreach (var status in Ollama.CreateModelAsync(new CreateModelRequest { From = fromModel, System = systemPrompt, Model = createName }))
var modelFileContent = ReadInput("Insert the model file content:", $"[{HintTextColor}]See [/][{AccentTextColor}][link]https://github.com/ollama/ollama/blob/main/docs/modelfile.md[/][/][{HintTextColor}] for available models[/]");
await foreach (var status in Ollama.CreateModelAsync(new CreateModelRequest { ModelFileContent = modelFileContent, Model = createName }))
AnsiConsole.MarkupLineInterpolated($"{status?.Status ?? ""}");
}

Expand Down
150 changes: 42 additions & 108 deletions demo/Demos/ToolConsole.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
using System.Reflection;
using OllamaSharp;
using OllamaSharp.Models.Chat;
using OllamaSharp.Models.Exceptions;
using Spectre.Console;

Expand Down Expand Up @@ -49,6 +47,8 @@ public override async Task Run()
break;
}

var currentMessageCount = chat.Messages.Count;

try
{
await foreach (var answerToken in chat.SendAsync(message, GetTools()))
Expand All @@ -59,32 +59,30 @@ public override async Task Run()
AnsiConsole.MarkupLineInterpolated($"[{ErrorTextColor}]{ex.Message}[/]");
}

var toolCalls = chat.Messages.LastOrDefault()?.ToolCalls?.ToArray() ?? [];
if (toolCalls.Any())
{
AnsiConsole.MarkupLine("\n[purple]Tools used:[/]");
// find the latest message from the assistant and possible tools
var newMessages = chat.Messages.Skip(currentMessageCount);

foreach (var function in toolCalls.Where(t => t.Function != null).Select(t => t.Function))
foreach (var newMessage in newMessages)
{
if (newMessage.ToolCalls?.Any() ?? false)
{
AnsiConsole.MarkupLineInterpolated($" - [purple]{function!.Name}[/]");

AnsiConsole.MarkupLineInterpolated($" - [purple]parameters[/]");

if (function?.Arguments is not null)
{
foreach (var argument in function.Arguments)
AnsiConsole.MarkupLineInterpolated($" - [purple]{argument.Key}[/]: [purple]{argument.Value}[/]");
}
AnsiConsole.MarkupLine("\n[purple]Tools used:[/]");

if (function is not null)
foreach (var function in newMessage.ToolCalls.Where(t => t.Function != null).Select(t => t.Function))
{
var result = FunctionHelper.ExecuteFunction(function);
AnsiConsole.MarkupLineInterpolated($" - [purple]return value[/]: [purple]\"{result}\"[/]");

await foreach (var answerToken in chat.SendAsAsync(ChatRole.Tool, result, GetTools()))
AnsiConsole.MarkupInterpolated($"[{AiTextColor}]{answerToken}[/]");
AnsiConsole.MarkupLineInterpolated($" - [purple]{function!.Name}[/]");
AnsiConsole.MarkupLineInterpolated($" - [purple]parameters[/]");

if (function?.Arguments is not null)
{
foreach (var argument in function.Arguments)
AnsiConsole.MarkupLineInterpolated($" - [purple]{argument.Key}[/]: [purple]{argument.Value}[/]");
}
}
}

if (newMessage.Role.GetValueOrDefault() == OllamaSharp.Models.Chat.ChatRole.Tool)
AnsiConsole.MarkupLineInterpolated($" [blue]-> \"{newMessage.Content}\"[/]");
}

AnsiConsole.WriteLine();
Expand All @@ -93,96 +91,32 @@ public override async Task Run()
}
}

private static Tool[] GetTools() => [new WeatherTool(), new NewsTool()];
private static object[] GetTools() => [new GetWeatherTool(), new GetLatLonAsyncTool()];

private sealed class WeatherTool : Tool
public enum Unit
{
public WeatherTool()
{
Function = new Function
{
Description = "Get the current weather for a location",
Name = "get_current_weather",
Parameters = new Parameters
{
Properties = new Dictionary<string, Property>
{
["location"] = new() { Type = "string", Description = "The location to get the weather for, e.g. San Francisco, CA" },
["format"] = new() { Type = "string", Description = "The format to return the weather in, e.g. 'celsius' or 'fahrenheit'", Enum = ["celsius", "fahrenheit"] },
},
Required = ["location", "format"],
}
};
Type = "function";
}
Celsius,
Fahrenheit
}

private sealed class NewsTool : Tool
/// <summary>
/// Gets the current weather for a given location.
/// </summary>
/// <param name="location">The location or city to get the weather for</param>
/// <param name="unit">The unit to measure the temperature in</param>
/// <returns>The weather for the given location</returns>
[OllamaTool]
public static string GetWeather(string location, Unit unit) => $"It's cold at only 6° {unit} in {location}.";

/// <summary>
/// Gets the latitude and longitude for a given location.
/// </summary>
/// <param name="location">The location to get the latitude and longitude for</param>
/// <returns>The weather for the given location</returns>
[OllamaTool]
public async static Task<string> GetLatLonAsync(string location)
{
public NewsTool()
{
Function = new Function
{
Description = "Get the current news for a location",
Name = "get_current_news",
Parameters = new Parameters
{
Properties = new Dictionary<string, Property>
{
["location"] = new() { Type = "string", Description = "The location to get the news for, e.g. San Francisco, CA" },
["category"] = new() { Type = "string", Description = "The optional category to filter the news, can be left empty to return all.", Enum = ["politics", "economy", "sports", "entertainment", "health", "technology", "science"] },
},
Required = ["location"],
}
};
Type = "function";
}
}

private static class FunctionHelper
{
public static string ExecuteFunction(Message.Function function)
{
var toolFunction = _availableFunctions[function.Name!];
var parameters = MapParameters(toolFunction.Method, function.Arguments!);
return toolFunction.DynamicInvoke(parameters)?.ToString()!;
}

private static readonly Dictionary<string, Func<string, string?, string>> _availableFunctions = new()
{
["get_current_weather"] = (location, format) =>
{
var (temperature, unit) = format switch
{
"fahrenheit" => (Random.Shared.Next(23, 104), "°F"),
_ => (Random.Shared.Next(-5, 40), "°C"),
};

return $"{temperature} {unit} in {location}";
},
["get_current_news"] = (location, category) =>
{
category = string.IsNullOrEmpty(category) ? "all" : category;
return $"Could not find news for {location} (category: {category}).";
}
};

private static object[] MapParameters(MethodBase method, IDictionary<string, object> namedParameters)
{
var paramNames = method.GetParameters().Select(p => p.Name).ToArray();
var parameters = new object[paramNames.Length];

for (var i = 0; i < parameters.Length; ++i)
parameters[i] = Type.Missing;

foreach (var (paramName, value) in namedParameters)
{
var paramIndex = Array.IndexOf(paramNames, paramName);
if (paramIndex >= 0)
parameters[paramIndex] = value?.ToString() ?? "";
}

return parameters;
}
await Task.Delay(1000).ConfigureAwait(false);
return $"{new Random().Next(20, 50)}.4711, {new Random().Next(3, 15)}.0815";
}
}
4 changes: 3 additions & 1 deletion demo/OllamaApiConsole.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@
</ItemGroup>

<ItemGroup>
<ProjectReference Include="..\src\OllamaSharp.csproj" />
<ProjectReference Include="..\src\OllamaSharp\OllamaSharp.csproj" />
<ProjectReference Include="..\src\SourceGenerators\OllamaSharp.SourceGenerators.csproj" OutputItemType="Analyzer" ReferenceOutputAssembly="false" />
</ItemGroup>

</Project>

File renamed without changes.
File renamed without changes.
37 changes: 26 additions & 11 deletions src/Chat.cs → src/OllamaSharp/Chat.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using System.Runtime.CompilerServices;
using OllamaSharp.Models;
using OllamaSharp.Models.Chat;
using OllamaSharp.Tools;

namespace OllamaSharp;

Expand Down Expand Up @@ -55,6 +56,11 @@ public class Chat
/// </summary>
public RequestOptions? Options { get; set; }

/// <summary>
/// Gets or sets the class instance that invokes provided tools requested by the AI model
/// </summary>
public IToolInvoker ToolInvoker { get; set; }

/// <summary>
/// Initializes a new instance of the <see cref="Chat"/> class.
/// This basic constructor sets up the chat without a predefined system prompt.
Expand All @@ -79,6 +85,9 @@ public Chat(IOllamaApiClient client)
{
Client = client ?? throw new ArgumentNullException(nameof(client));
Model = Client.SelectedModel;

// continues the conversation with automatically sending messages (role=tool) from the results of the tools into the chat
ToolInvoker = new ChatContinuingToolInvoker(this);
}

/// <summary>
Expand Down Expand Up @@ -226,7 +235,7 @@ public IAsyncEnumerable<string> SendAsync(string message, IEnumerable<string>? i
/// }
/// </code>
/// </example>
public IAsyncEnumerable<string> SendAsync(string message, IEnumerable<Tool>? tools,
public IAsyncEnumerable<string> SendAsync(string message, IEnumerable<object>? tools,
IEnumerable<string>? imagesAsBase64 = null, object? format = null,
CancellationToken cancellationToken = default)
=> SendAsAsync(ChatRole.User, message, tools: tools, imagesAsBase64: imagesAsBase64, format: format,
Expand Down Expand Up @@ -366,7 +375,7 @@ public IAsyncEnumerable<string> SendAsAsync(ChatRole role, string message, IEnum
/// Thrown if the <paramref name="format"/> argument is of type <see cref="CancellationToken"/> by mistake, or if any unsupported types are passed.
/// </exception>
/// <example>
/// Using the <see cref="SendAsAsync(ChatRole, string, IEnumerable{Tool},IEnumerable{string})"/> method to send a message and stream the model's response:
/// Using the <see cref="SendAsAsync(ChatRole, string, IEnumerable{object}, IEnumerable{string}, object, CancellationToken)"/> method to send a message and stream the model's response:
/// <code>
/// var chat = new Chat(client);
/// var role = new ChatRole("assistant");
Expand All @@ -378,7 +387,7 @@ public IAsyncEnumerable<string> SendAsAsync(ChatRole role, string message, IEnum
/// }
/// </code>
/// </example>
public async IAsyncEnumerable<string> SendAsAsync(ChatRole role, string message, IEnumerable<Tool>? tools,
public async IAsyncEnumerable<string> SendAsAsync(ChatRole role, string message, IEnumerable<object>? tools,
IEnumerable<string>? imagesAsBase64 = null, object? format = null,
[EnumeratorCancellation] CancellationToken cancellationToken = default)
{
Expand All @@ -388,13 +397,11 @@ public async IAsyncEnumerable<string> SendAsAsync(ChatRole role, string message,

Messages.Add(new Message(role, message, imagesAsBase64?.ToArray()));

var hasTools = tools?.Any() ?? false;

var request = new ChatRequest
{
Messages = Messages,
Model = Model,
Stream = !hasTools, // cannot stream if tools should be used
Stream = true,
Tools = tools,
Format = format,
Options = Options
Expand All @@ -403,14 +410,22 @@ public async IAsyncEnumerable<string> SendAsAsync(ChatRole role, string message,
var messageBuilder = new MessageBuilder();
await foreach (var answer in Client.ChatAsync(request, cancellationToken).ConfigureAwait(false))
{
if (answer is null) continue;

if (answer is null)
continue;

messageBuilder.Append(answer);

yield return answer.Message.Content ?? string.Empty;
}

if (messageBuilder.HasValue)
Messages.Add(messageBuilder.ToMessage());
{
var answerMessage = messageBuilder.ToMessage();
Messages.Add(answerMessage);

// call tools if available and requested by the AI model and yield the results
await foreach (var answer2 in ToolInvoker.InvokeAsync(answerMessage.ToolCalls ?? [], tools ?? [], cancellationToken).ConfigureAwait(false))
yield return answer2;
}
}
}
}
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -135,11 +135,9 @@ private static void TryAddOllamaOption<T>(ChatOptions? microsoftChatOptions, Oll
/// </summary>
/// <param name="tools">The tools to convert.</param>
/// <returns>An enumeration of <see cref="Tool"/> objects containing the converted data.</returns>
private static IEnumerable<Tool>? ToOllamaSharpTools(IEnumerable<AITool>? tools)
private static IEnumerable<object>? ToOllamaSharpTools(IEnumerable<AITool>? tools)
{
return tools?.Select(ToOllamaSharpTool)
.Where(t => t is not null)
.Cast<Tool>();
return tools?.Select(ToOllamaSharpTool).Where(t => t is not null);
}

/// <summary>
Expand All @@ -150,7 +148,7 @@ private static void TryAddOllamaOption<T>(ChatOptions? microsoftChatOptions, Oll
/// If parseable, a <see cref="Tool"/> object containing the converted data,
/// otherwise <see langword="null"/>.
/// </returns>
private static Tool? ToOllamaSharpTool(AITool tool)
private static object? ToOllamaSharpTool(AITool tool)
{
if (tool is AIFunction f)
return ToOllamaSharpTool(f);
Expand Down Expand Up @@ -241,7 +239,7 @@ private static IEnumerable<Message> ToOllamaSharpMessages(IList<ChatMessage> cha
}

/// <summary>
/// Converts a Microsoft.Extensions.AI.<see cref="ImageContent"/> to a base64 image string.
/// Converts a Microsoft.Extensions.AI.<see cref="DataContent"/> to a base64 image string.
/// </summary>
/// <param name="content">The data content to convert.</param>
/// <returns>A string containing the base64 image data.</returns>
Expand Down
File renamed without changes.
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ public class ChatRequest : OllamaRequest
/// </summary>
[JsonPropertyName("tools")]
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
public IEnumerable<Tool>? Tools { get; set; }
public IEnumerable<object>? Tools { get; set; }
}

/// <summary>
Expand Down Expand Up @@ -155,4 +155,4 @@ public class Property
/// </summary>
[JsonPropertyName("enum")]
public IEnumerable<string>? Enum { get; set; }
}
}
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
Loading

0 comments on commit 5bf9efc

Please sign in to comment.