From 545500b32967d9806b15899d3ed3d125a5545f30 Mon Sep 17 00:00:00 2001 From: Sam Date: Sat, 5 Oct 2024 08:46:57 +0900 Subject: [PATCH] FEATURE: allows forced LLM tool use (#818) * FEATURE: allows forced LLM tool use Sometimes we need to force LLMs to use tools, for example in RAG like use cases we may want to force an unconditional search. The new framework allows you backend to force tool usage. Front end commit to follow * UI for forcing tools now works, but it does not react right * fix bugs * fix tests, this is now ready for review --- .../admin/ai_personas_controller.rb | 9 +- app/models/ai_persona.rb | 18 ++-- .../discourse/admin/models/ai-persona.js | 20 +++-- .../components/ai-persona-editor.gjs | 51 ++++++++++- .../discourse/components/ai-tool-selector.js | 2 +- config/locales/client.en.yml | 1 + lib/ai_bot/bot.rb | 14 +++ lib/ai_bot/personas/persona.rb | 4 + lib/completions/dialects/dialect.rb | 4 + lib/completions/endpoints/open_ai.rb | 8 +- lib/completions/llm.rb | 2 +- lib/completions/prompt.rb | 6 +- .../lib/completions/endpoints/open_ai_spec.rb | 89 +++++++++++++++++++ spec/lib/modules/ai_bot/playground_spec.rb | 34 +++++-- .../admin/ai_personas_controller_spec.rb | 6 +- spec/system/admin_ai_persona_spec.rb | 2 +- .../unit/models/ai-persona-test.js | 4 +- 17 files changed, 236 insertions(+), 38 deletions(-) diff --git a/app/controllers/discourse_ai/admin/ai_personas_controller.rb b/app/controllers/discourse_ai/admin/ai_personas_controller.rb index ee653d5bc..2f3563ac8 100644 --- a/app/controllers/discourse_ai/admin/ai_personas_controller.rb +++ b/app/controllers/discourse_ai/admin/ai_personas_controller.rb @@ -120,15 +120,12 @@ def ai_persona_params def permit_tools(tools) return [] if !tools.is_a?(Array) - tools.filter_map do |tool, options| + tools.filter_map do |tool, options, force_tool| break nil if !tool.is_a?(String) options&.permit! if options && options.is_a?(ActionController::Parameters) - if options - [tool, options] - else - tool - end + # this is simpler from a storage perspective, 1 way to store tools + [tool, options, !!force_tool] end end end diff --git a/app/models/ai_persona.rb b/app/models/ai_persona.rb index 29d31e54a..54eb63102 100644 --- a/app/models/ai_persona.rb +++ b/app/models/ai_persona.rb @@ -136,17 +136,23 @@ def class_instance end options = {} + force_tool_use = [] + tools = self.tools.filter_map do |element| klass = nil - if element.is_a?(String) && element.start_with?("custom-") - custom_tool_id = element.split("-", 2).last.to_i + element = [element] if element.is_a?(String) + + inner_name, current_options, should_force_tool_use = + element.is_a?(Array) ? element : [element, nil] + + if inner_name.start_with?("custom-") + custom_tool_id = inner_name.split("-", 2).last.to_i if AiTool.exists?(id: custom_tool_id, enabled: true) klass = DiscourseAi::AiBot::Tools::Custom.class_instance(custom_tool_id) end else - inner_name, current_options = element.is_a?(Array) ? element : [element, nil] inner_name = inner_name.gsub("Tool", "") inner_name = "List#{inner_name}" if %w[Categories Tags].include?(inner_name) @@ -155,9 +161,10 @@ def class_instance options[klass] = current_options if current_options rescue StandardError end - - klass end + + force_tool_use << klass if should_force_tool_use + klass end ai_persona_id = self.id @@ -177,6 +184,7 @@ def class_instance end define_method(:tools) { tools } + define_method(:force_tool_use) { force_tool_use } define_method(:options) { options } define_method(:temperature) { @ai_persona&.temperature } define_method(:top_p) { @ai_persona&.top_p } diff --git a/assets/javascripts/discourse/admin/models/ai-persona.js b/assets/javascripts/discourse/admin/models/ai-persona.js index c12c461cd..9344f5794 100644 --- a/assets/javascripts/discourse/admin/models/ai-persona.js +++ b/assets/javascripts/discourse/admin/models/ai-persona.js @@ -59,21 +59,25 @@ class ToolOption { export default class AiPersona extends RestModel { // this code is here to convert the wire schema to easier to work with object // on the wire we pass in/out tools as an Array. - // [[ToolName, {option1: value, option2: value}], ToolName2, ToolName3] + // [[ToolName, {option1: value, option2: value}, force], ToolName2, ToolName3] // So we rework this into a "tools" property and nested toolOptions init(properties) { + this.forcedTools = []; if (properties.tools) { properties.tools = properties.tools.map((tool) => { if (typeof tool === "string") { return tool; } else { - let [toolId, options] = tool; + let [toolId, options, force] = tool; for (let optionId in options) { if (!options.hasOwnProperty(optionId)) { continue; } this.getToolOption(toolId, optionId).value = options[optionId]; } + if (force) { + this.forcedTools.push(toolId); + } return toolId; } }); @@ -109,6 +113,8 @@ export default class AiPersona extends RestModel { if (typeof toolId !== "string") { toolId = toolId[0]; } + + let force = this.forcedTools.includes(toolId); if (this.toolOptions && this.toolOptions[toolId]) { let options = this.toolOptions[toolId]; let optionsWithValues = {}; @@ -119,9 +125,9 @@ export default class AiPersona extends RestModel { let option = options[optionId]; optionsWithValues[optionId] = option.value; } - toolsWithOptions.push([toolId, optionsWithValues]); + toolsWithOptions.push([toolId, optionsWithValues, force]); } else { - toolsWithOptions.push(toolId); + toolsWithOptions.push([toolId, {}, force]); } }); attrs.tools = toolsWithOptions; @@ -133,7 +139,6 @@ export default class AiPersona extends RestModel { : this.getProperties(CREATE_ATTRIBUTES); attrs.id = this.id; this.populateToolOptions(attrs); - return attrs; } @@ -146,6 +151,9 @@ export default class AiPersona extends RestModel { workingCopy() { let attrs = this.getProperties(CREATE_ATTRIBUTES); this.populateToolOptions(attrs); - return AiPersona.create(attrs); + + const persona = AiPersona.create(attrs); + persona.forcedTools = (this.forcedTools || []).slice(); + return persona; } } diff --git a/assets/javascripts/discourse/components/ai-persona-editor.gjs b/assets/javascripts/discourse/components/ai-persona-editor.gjs index 139768c14..574911eca 100644 --- a/assets/javascripts/discourse/components/ai-persona-editor.gjs +++ b/assets/javascripts/discourse/components/ai-persona-editor.gjs @@ -40,10 +40,39 @@ export default class PersonaEditor extends Component { @tracked maxPixelsValue = null; @tracked ragIndexingStatuses = null; + @tracked selectedTools = []; + @tracked selectedToolNames = []; + @tracked forcedToolNames = []; + get chatPluginEnabled() { return this.siteSettings.chat_enabled; } + get allowForceTools() { + return !this.editingModel?.system && this.editingModel?.tools?.length > 0; + } + + @action + forcedToolsChanged(tools) { + this.forcedToolNames = tools; + this.editingModel.forcedTools = this.forcedToolNames; + } + + @action + toolsChanged(tools) { + this.selectedTools = this.args.personas.resultSetMeta.tools.filter((tool) => + tools.includes(tool.id) + ); + this.selectedToolNames = tools.slice(); + + this.forcedToolNames = this.forcedToolNames.filter( + (tool) => this.editingModel.tools.indexOf(tool) !== -1 + ); + + this.editingModel.tools = this.selectedToolNames; + this.editingModel.forcedTools = this.forcedToolNames; + } + @action updateModel() { this.editingModel = this.args.model.workingCopy(); @@ -51,6 +80,12 @@ export default class PersonaEditor extends Component { this.maxPixelsValue = this.findClosestPixelValue( this.editingModel.vision_max_pixels ); + + this.selectedToolNames = this.editingModel.tools || []; + this.selectedTools = this.args.personas.resultSetMeta.tools.filter((tool) => + this.selectedToolNames.includes(tool.id) + ); + this.forcedToolNames = this.editingModel.forcedTools || []; } findClosestPixelValue(pixels) { @@ -336,15 +371,27 @@ export default class PersonaEditor extends Component { + {{#if this.allowForceTools}} +
+ + +
+ {{/if}} {{#unless this.editingModel.system}} {{/unless}} diff --git a/assets/javascripts/discourse/components/ai-tool-selector.js b/assets/javascripts/discourse/components/ai-tool-selector.js index 0060e06f0..c3959eff8 100644 --- a/assets/javascripts/discourse/components/ai-tool-selector.js +++ b/assets/javascripts/discourse/components/ai-tool-selector.js @@ -6,7 +6,7 @@ export default MultiSelectComponent.extend({ this.selectKit.options.set("disabled", this.get("attrs.disabled.value")); }), - content: computed(function () { + content: computed("tools", function () { return this.tools; }), diff --git a/config/locales/client.en.yml b/config/locales/client.en.yml index 7b6d7f359..f4719967d 100644 --- a/config/locales/client.en.yml +++ b/config/locales/client.en.yml @@ -148,6 +148,7 @@ en: saved: AI Persona Saved enabled: "Enabled?" tools: Enabled Tools + forced_tools: Forced Tools allowed_groups: Allowed Groups confirm_delete: Are you sure you want to delete this persona? new: "New Persona" diff --git a/lib/ai_bot/bot.rb b/lib/ai_bot/bot.rb index e3941e0e2..391ee16b4 100644 --- a/lib/ai_bot/bot.rb +++ b/lib/ai_bot/bot.rb @@ -67,6 +67,19 @@ def get_updated_title(conversation_context, post) .last end + def force_tool_if_needed(prompt, context) + context[:chosen_tools] ||= [] + forced_tools = persona.force_tool_use.map { |tool| tool.name } + force_tool = forced_tools.find { |name| !context[:chosen_tools].include?(name) } + + if force_tool + context[:chosen_tools] << force_tool + prompt.tool_choice = force_tool + else + prompt.tool_choice = nil + end + end + def reply(context, &update_blk) llm = DiscourseAi::Completions::Llm.proxy(model) prompt = persona.craft_prompt(context, llm: llm) @@ -85,6 +98,7 @@ def reply(context, &update_blk) while total_completions <= MAX_COMPLETIONS && ongoing_chain tool_found = false + force_tool_if_needed(prompt, context) result = llm.generate(prompt, feature_name: "bot", **llm_kwargs) do |partial, cancel| diff --git a/lib/ai_bot/personas/persona.rb b/lib/ai_bot/personas/persona.rb index b46abebe3..0a31598c7 100644 --- a/lib/ai_bot/personas/persona.rb +++ b/lib/ai_bot/personas/persona.rb @@ -113,6 +113,10 @@ def tools [] end + def force_tool_use + [] + end + def required_tools [] end diff --git a/lib/completions/dialects/dialect.rb b/lib/completions/dialects/dialect.rb index 5420e6438..fa3a9ca4f 100644 --- a/lib/completions/dialects/dialect.rb +++ b/lib/completions/dialects/dialect.rb @@ -60,6 +60,10 @@ def tools @tools ||= tools_dialect.translated_tools end + def tool_choice + prompt.tool_choice + end + def translate messages = prompt.messages diff --git a/lib/completions/endpoints/open_ai.rb b/lib/completions/endpoints/open_ai.rb index 43f190532..35b3e724c 100644 --- a/lib/completions/endpoints/open_ai.rb +++ b/lib/completions/endpoints/open_ai.rb @@ -54,8 +54,12 @@ def prepare_payload(prompt, model_params, dialect) # We'll fallback to guess this using the tokenizer. payload[:stream_options] = { include_usage: true } if llm_model.provider == "open_ai" end - - payload[:tools] = dialect.tools if dialect.tools.present? + if dialect.tools.present? + payload[:tools] = dialect.tools + if dialect.tool_choice.present? + payload[:tool_choice] = { type: "function", function: { name: dialect.tool_choice } } + end + end payload end diff --git a/lib/completions/llm.rb b/lib/completions/llm.rb index 527ec87fc..445bfc199 100644 --- a/lib/completions/llm.rb +++ b/lib/completions/llm.rb @@ -123,7 +123,7 @@ def with_prepared_responses(responses, llm: nil) end def record_prompt(prompt) - @prompts << prompt if @prompts + @prompts << prompt.dup if @prompts end def proxy(model) diff --git a/lib/completions/prompt.rb b/lib/completions/prompt.rb index 051818d2d..9a6d4d617 100644 --- a/lib/completions/prompt.rb +++ b/lib/completions/prompt.rb @@ -6,7 +6,7 @@ class Prompt INVALID_TURN = Class.new(StandardError) attr_reader :messages - attr_accessor :tools, :topic_id, :post_id, :max_pixels + attr_accessor :tools, :topic_id, :post_id, :max_pixels, :tool_choice def initialize( system_message_text = nil, @@ -14,7 +14,8 @@ def initialize( tools: [], topic_id: nil, post_id: nil, - max_pixels: nil + max_pixels: nil, + tool_choice: nil ) raise ArgumentError, "messages must be an array" if !messages.is_a?(Array) raise ArgumentError, "tools must be an array" if !tools.is_a?(Array) @@ -37,6 +38,7 @@ def initialize( @messages.each_cons(2) { |last_turn, new_turn| validate_turn(last_turn, new_turn) } @tools = tools + @tool_choice = tool_choice end def push(type:, content:, id: nil, name: nil, upload_ids: nil) diff --git a/spec/lib/completions/endpoints/open_ai_spec.rb b/spec/lib/completions/endpoints/open_ai_spec.rb index 8c6656391..44ff136cb 100644 --- a/spec/lib/completions/endpoints/open_ai_spec.rb +++ b/spec/lib/completions/endpoints/open_ai_spec.rb @@ -255,6 +255,95 @@ def request_body(prompt, stream: false, tool_call: false) end end + describe "forced tool use" do + it "can properly force tool use" do + llm = DiscourseAi::Completions::Llm.proxy("custom:#{model.id}") + + tools = [ + { + name: "echo", + description: "echo something", + parameters: [ + { name: "text", type: "string", description: "text to echo", required: true }, + ], + }, + ] + + prompt = + DiscourseAi::Completions::Prompt.new( + "You are a bot", + messages: [type: :user, id: "user1", content: "echo hello"], + tools: tools, + tool_choice: "echo", + ) + + response = { + id: "chatcmpl-9JxkAzzaeO4DSV3omWvok9TKhCjBH", + object: "chat.completion", + created: 1_714_544_914, + model: "gpt-4-turbo-2024-04-09", + choices: [ + { + index: 0, + message: { + role: "assistant", + content: nil, + tool_calls: [ + { + id: "call_I8LKnoijVuhKOM85nnEQgWwd", + type: "function", + function: { + name: "echo", + arguments: "{\"text\":\"hello\"}", + }, + }, + ], + }, + logprobs: nil, + finish_reason: "tool_calls", + }, + ], + usage: { + prompt_tokens: 55, + completion_tokens: 13, + total_tokens: 68, + }, + system_fingerprint: "fp_ea6eb70039", + }.to_json + + body_json = nil + stub_request(:post, "https://api.openai.com/v1/chat/completions").with( + body: proc { |body| body_json = JSON.parse(body, symbolize_names: true) }, + ).to_return(body: response) + + result = llm.generate(prompt, user: user) + + expect(body_json[:tool_choice]).to eq({ type: "function", function: { name: "echo" } }) + + expected = (<<~TXT).strip + + + echo + + hello + + call_I8LKnoijVuhKOM85nnEQgWwd + + + TXT + + expect(result.strip).to eq(expected) + + stub_request(:post, "https://api.openai.com/v1/chat/completions").to_return( + body: { choices: [message: { content: "OK" }] }.to_json, + ) + + result = llm.generate(prompt, user: user) + + expect(result).to eq("OK") + end + end + describe "image support" do it "can handle images" do model = Fabricate(:llm_model, vision_enabled: true) diff --git a/spec/lib/modules/ai_bot/playground_spec.rb b/spec/lib/modules/ai_bot/playground_spec.rb index f087414b0..c3f054df2 100644 --- a/spec/lib/modules/ai_bot/playground_spec.rb +++ b/spec/lib/modules/ai_bot/playground_spec.rb @@ -78,13 +78,7 @@ end let!(:ai_persona) { Fabricate(:ai_persona, tools: ["custom-#{custom_tool.id}"]) } - - it "uses custom tool in conversation" do - persona_klass = AiPersona.all_personas.find { |p| p.name == ai_persona.name } - bot = DiscourseAi::AiBot::Bot.as(bot_user, persona: persona_klass.new) - playground = DiscourseAi::AiBot::Playground.new(bot) - - function_call = (<<~XML).strip + let(:function_call) { (<<~XML).strip } search @@ -96,6 +90,32 @@ ", XML + let(:bot) { DiscourseAi::AiBot::Bot.as(bot_user, persona: ai_persona.class_instance.new) } + + let(:playground) { DiscourseAi::AiBot::Playground.new(bot) } + + it "can force usage of a tool" do + tool_name = "custom-#{custom_tool.id}" + ai_persona.update!(tools: [[tool_name, nil, "force"]]) + responses = [function_call, "custom tool did stuff (maybe)"] + + prompt = nil + DiscourseAi::Completions::Llm.with_prepared_responses(responses) do |_, _, _prompt| + new_post = Fabricate(:post, raw: "Can you use the custom tool?") + _reply_post = playground.reply_to(new_post) + prompt = _prompt + end + + expect(prompt.length).to eq(2) + expect(prompt[0].tool_choice).to eq("search") + expect(prompt[1].tool_choice).to eq(nil) + end + + it "uses custom tool in conversation" do + persona_klass = AiPersona.all_personas.find { |p| p.name == ai_persona.name } + bot = DiscourseAi::AiBot::Bot.as(bot_user, persona: persona_klass.new) + playground = DiscourseAi::AiBot::Playground.new(bot) + responses = [function_call, "custom tool did stuff (maybe)"] reply_post = nil diff --git a/spec/requests/admin/ai_personas_controller_spec.rb b/spec/requests/admin/ai_personas_controller_spec.rb index 203abe33a..1b1da5fc6 100644 --- a/spec/requests/admin/ai_personas_controller_spec.rb +++ b/spec/requests/admin/ai_personas_controller_spec.rb @@ -160,7 +160,7 @@ name: "superbot", description: "Assists with tasks", system_prompt: "you are a helpful bot", - tools: [["search", { "base_query" => "test" }]], + tools: [["search", { "base_query" => "test" }, true]], top_p: 0.1, temperature: 0.5, mentionable: true, @@ -186,7 +186,7 @@ persona = AiPersona.find(persona_json["id"]) - expect(persona.tools).to eq([["search", { "base_query" => "test" }]]) + expect(persona.tools).to eq([["search", { "base_query" => "test" }, true]]) expect(persona.top_p).to eq(0.1) expect(persona.temperature).to eq(0.5) }.to change(AiPersona, :count).by(1) @@ -296,7 +296,7 @@ ai_persona.reload expect(ai_persona.name).to eq("SuperBot") expect(ai_persona.enabled).to eq(false) - expect(ai_persona.tools).to eq(["search"]) + expect(ai_persona.tools).to eq([["search", nil, false]]) end end diff --git a/spec/system/admin_ai_persona_spec.rb b/spec/system/admin_ai_persona_spec.rb index 7ecd57ddf..171cfcc85 100644 --- a/spec/system/admin_ai_persona_spec.rb +++ b/spec/system/admin_ai_persona_spec.rb @@ -30,7 +30,7 @@ expect(persona.name).to eq("Test Persona") expect(persona.description).to eq("I am a test persona") expect(persona.system_prompt).to eq("You are a helpful bot") - expect(persona.tools).to eq([["Read", { "read_private" => nil }]]) + expect(persona.tools).to eq([["Read", { "read_private" => nil }, false]]) end it "will not allow deletion or editing of system personas" do diff --git a/test/javascripts/unit/models/ai-persona-test.js b/test/javascripts/unit/models/ai-persona-test.js index f785ffbad..c1f25aeb4 100644 --- a/test/javascripts/unit/models/ai-persona-test.js +++ b/test/javascripts/unit/models/ai-persona-test.js @@ -60,7 +60,7 @@ module("Discourse AI | Unit | Model | ai-persona", function () { const updatedProperties = aiPersona.updateProperties(); // perform remapping for save - properties.tools = [["ToolName", { option1: "value1" }]]; + properties.tools = [["ToolName", { option1: "value1" }, false]]; assert.deepEqual(updatedProperties, properties); }); @@ -100,7 +100,7 @@ module("Discourse AI | Unit | Model | ai-persona", function () { const createdProperties = aiPersona.createProperties(); - properties.tools = [["ToolName", { option1: "value1" }]]; + properties.tools = [["ToolName", { option1: "value1" }, false]]; assert.deepEqual(createdProperties, properties); });