Skip to content

Commit

Permalink
FEATURE: allows forced LLM tool use (#818)
Browse files Browse the repository at this point in the history
* FEATURE: allows forced LLM tool use

Sometimes we need to force LLMs to use tools, for example in RAG
like use cases we may want to force an unconditional search.

The new framework allows you backend to force tool usage.

Front end commit to follow

* UI for forcing tools now works, but it does not react right

* fix bugs

* fix tests, this is now ready for review
  • Loading branch information
SamSaffron authored Oct 4, 2024
1 parent c294b6d commit 545500b
Show file tree
Hide file tree
Showing 17 changed files with 236 additions and 38 deletions.
9 changes: 3 additions & 6 deletions app/controllers/discourse_ai/admin/ai_personas_controller.rb
Original file line number Diff line number Diff line change
Expand Up @@ -120,15 +120,12 @@ def ai_persona_params
def permit_tools(tools)
return [] if !tools.is_a?(Array)

tools.filter_map do |tool, options|
tools.filter_map do |tool, options, force_tool|
break nil if !tool.is_a?(String)
options&.permit! if options && options.is_a?(ActionController::Parameters)

if options
[tool, options]
else
tool
end
# this is simpler from a storage perspective, 1 way to store tools
[tool, options, !!force_tool]
end
end
end
Expand Down
18 changes: 13 additions & 5 deletions app/models/ai_persona.rb
Original file line number Diff line number Diff line change
Expand Up @@ -136,17 +136,23 @@ def class_instance
end

options = {}
force_tool_use = []

tools =
self.tools.filter_map do |element|
klass = nil

if element.is_a?(String) && element.start_with?("custom-")
custom_tool_id = element.split("-", 2).last.to_i
element = [element] if element.is_a?(String)

inner_name, current_options, should_force_tool_use =
element.is_a?(Array) ? element : [element, nil]

if inner_name.start_with?("custom-")
custom_tool_id = inner_name.split("-", 2).last.to_i
if AiTool.exists?(id: custom_tool_id, enabled: true)
klass = DiscourseAi::AiBot::Tools::Custom.class_instance(custom_tool_id)
end
else
inner_name, current_options = element.is_a?(Array) ? element : [element, nil]
inner_name = inner_name.gsub("Tool", "")
inner_name = "List#{inner_name}" if %w[Categories Tags].include?(inner_name)

Expand All @@ -155,9 +161,10 @@ def class_instance
options[klass] = current_options if current_options
rescue StandardError
end

klass
end

force_tool_use << klass if should_force_tool_use
klass
end

ai_persona_id = self.id
Expand All @@ -177,6 +184,7 @@ def class_instance
end

define_method(:tools) { tools }
define_method(:force_tool_use) { force_tool_use }
define_method(:options) { options }
define_method(:temperature) { @ai_persona&.temperature }
define_method(:top_p) { @ai_persona&.top_p }
Expand Down
20 changes: 14 additions & 6 deletions assets/javascripts/discourse/admin/models/ai-persona.js
Original file line number Diff line number Diff line change
Expand Up @@ -59,21 +59,25 @@ class ToolOption {
export default class AiPersona extends RestModel {
// this code is here to convert the wire schema to easier to work with object
// on the wire we pass in/out tools as an Array.
// [[ToolName, {option1: value, option2: value}], ToolName2, ToolName3]
// [[ToolName, {option1: value, option2: value}, force], ToolName2, ToolName3]
// So we rework this into a "tools" property and nested toolOptions
init(properties) {
this.forcedTools = [];
if (properties.tools) {
properties.tools = properties.tools.map((tool) => {
if (typeof tool === "string") {
return tool;
} else {
let [toolId, options] = tool;
let [toolId, options, force] = tool;
for (let optionId in options) {
if (!options.hasOwnProperty(optionId)) {
continue;
}
this.getToolOption(toolId, optionId).value = options[optionId];
}
if (force) {
this.forcedTools.push(toolId);
}
return toolId;
}
});
Expand Down Expand Up @@ -109,6 +113,8 @@ export default class AiPersona extends RestModel {
if (typeof toolId !== "string") {
toolId = toolId[0];
}

let force = this.forcedTools.includes(toolId);
if (this.toolOptions && this.toolOptions[toolId]) {
let options = this.toolOptions[toolId];
let optionsWithValues = {};
Expand All @@ -119,9 +125,9 @@ export default class AiPersona extends RestModel {
let option = options[optionId];
optionsWithValues[optionId] = option.value;
}
toolsWithOptions.push([toolId, optionsWithValues]);
toolsWithOptions.push([toolId, optionsWithValues, force]);
} else {
toolsWithOptions.push(toolId);
toolsWithOptions.push([toolId, {}, force]);
}
});
attrs.tools = toolsWithOptions;
Expand All @@ -133,7 +139,6 @@ export default class AiPersona extends RestModel {
: this.getProperties(CREATE_ATTRIBUTES);
attrs.id = this.id;
this.populateToolOptions(attrs);

return attrs;
}

Expand All @@ -146,6 +151,9 @@ export default class AiPersona extends RestModel {
workingCopy() {
let attrs = this.getProperties(CREATE_ATTRIBUTES);
this.populateToolOptions(attrs);
return AiPersona.create(attrs);

const persona = AiPersona.create(attrs);
persona.forcedTools = (this.forcedTools || []).slice();
return persona;
}
}
51 changes: 49 additions & 2 deletions assets/javascripts/discourse/components/ai-persona-editor.gjs
Original file line number Diff line number Diff line change
Expand Up @@ -40,17 +40,52 @@ export default class PersonaEditor extends Component {
@tracked maxPixelsValue = null;
@tracked ragIndexingStatuses = null;

@tracked selectedTools = [];
@tracked selectedToolNames = [];
@tracked forcedToolNames = [];

get chatPluginEnabled() {
return this.siteSettings.chat_enabled;
}

get allowForceTools() {
return !this.editingModel?.system && this.editingModel?.tools?.length > 0;
}

@action
forcedToolsChanged(tools) {
this.forcedToolNames = tools;
this.editingModel.forcedTools = this.forcedToolNames;
}

@action
toolsChanged(tools) {
this.selectedTools = this.args.personas.resultSetMeta.tools.filter((tool) =>
tools.includes(tool.id)
);
this.selectedToolNames = tools.slice();

this.forcedToolNames = this.forcedToolNames.filter(
(tool) => this.editingModel.tools.indexOf(tool) !== -1
);

this.editingModel.tools = this.selectedToolNames;
this.editingModel.forcedTools = this.forcedToolNames;
}

@action
updateModel() {
this.editingModel = this.args.model.workingCopy();
this.showDelete = !this.args.model.isNew && !this.args.model.system;
this.maxPixelsValue = this.findClosestPixelValue(
this.editingModel.vision_max_pixels
);

this.selectedToolNames = this.editingModel.tools || [];
this.selectedTools = this.args.personas.resultSetMeta.tools.filter((tool) =>
this.selectedToolNames.includes(tool.id)
);
this.forcedToolNames = this.editingModel.forcedTools || [];
}

findClosestPixelValue(pixels) {
Expand Down Expand Up @@ -336,15 +371,27 @@ export default class PersonaEditor extends Component {
<label>{{I18n.t "discourse_ai.ai_persona.tools"}}</label>
<AiToolSelector
class="ai-persona-editor__tools"
@value={{this.editingModel.tools}}
@value={{this.selectedToolNames}}
@disabled={{this.editingModel.system}}
@tools={{@personas.resultSetMeta.tools}}
@onChange={{this.toolsChanged}}
/>
</div>
{{#if this.allowForceTools}}
<div class="control-group">
<label>{{I18n.t "discourse_ai.ai_persona.forced_tools"}}</label>
<AiToolSelector
class="ai-persona-editor__tools"
@value={{this.forcedToolNames}}
@tools={{this.selectedTools}}
@onChange={{this.forcedToolsChanged}}
/>
</div>
{{/if}}
{{#unless this.editingModel.system}}
<AiPersonaToolOptions
@persona={{this.editingModel}}
@tools={{this.editingModel.tools}}
@tools={{this.selectedToolNames}}
@allTools={{@personas.resultSetMeta.tools}}
/>
{{/unless}}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ export default MultiSelectComponent.extend({
this.selectKit.options.set("disabled", this.get("attrs.disabled.value"));
}),

content: computed(function () {
content: computed("tools", function () {
return this.tools;
}),

Expand Down
1 change: 1 addition & 0 deletions config/locales/client.en.yml
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,7 @@ en:
saved: AI Persona Saved
enabled: "Enabled?"
tools: Enabled Tools
forced_tools: Forced Tools
allowed_groups: Allowed Groups
confirm_delete: Are you sure you want to delete this persona?
new: "New Persona"
Expand Down
14 changes: 14 additions & 0 deletions lib/ai_bot/bot.rb
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,19 @@ def get_updated_title(conversation_context, post)
.last
end

def force_tool_if_needed(prompt, context)
context[:chosen_tools] ||= []
forced_tools = persona.force_tool_use.map { |tool| tool.name }
force_tool = forced_tools.find { |name| !context[:chosen_tools].include?(name) }

if force_tool
context[:chosen_tools] << force_tool
prompt.tool_choice = force_tool
else
prompt.tool_choice = nil
end
end

def reply(context, &update_blk)
llm = DiscourseAi::Completions::Llm.proxy(model)
prompt = persona.craft_prompt(context, llm: llm)
Expand All @@ -85,6 +98,7 @@ def reply(context, &update_blk)

while total_completions <= MAX_COMPLETIONS && ongoing_chain
tool_found = false
force_tool_if_needed(prompt, context)

result =
llm.generate(prompt, feature_name: "bot", **llm_kwargs) do |partial, cancel|
Expand Down
4 changes: 4 additions & 0 deletions lib/ai_bot/personas/persona.rb
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,10 @@ def tools
[]
end

def force_tool_use
[]
end

def required_tools
[]
end
Expand Down
4 changes: 4 additions & 0 deletions lib/completions/dialects/dialect.rb
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,10 @@ def tools
@tools ||= tools_dialect.translated_tools
end

def tool_choice
prompt.tool_choice
end

def translate
messages = prompt.messages

Expand Down
8 changes: 6 additions & 2 deletions lib/completions/endpoints/open_ai.rb
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,12 @@ def prepare_payload(prompt, model_params, dialect)
# We'll fallback to guess this using the tokenizer.
payload[:stream_options] = { include_usage: true } if llm_model.provider == "open_ai"
end

payload[:tools] = dialect.tools if dialect.tools.present?
if dialect.tools.present?
payload[:tools] = dialect.tools
if dialect.tool_choice.present?
payload[:tool_choice] = { type: "function", function: { name: dialect.tool_choice } }
end
end
payload
end

Expand Down
2 changes: 1 addition & 1 deletion lib/completions/llm.rb
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def with_prepared_responses(responses, llm: nil)
end

def record_prompt(prompt)
@prompts << prompt if @prompts
@prompts << prompt.dup if @prompts
end

def proxy(model)
Expand Down
6 changes: 4 additions & 2 deletions lib/completions/prompt.rb
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,16 @@ class Prompt
INVALID_TURN = Class.new(StandardError)

attr_reader :messages
attr_accessor :tools, :topic_id, :post_id, :max_pixels
attr_accessor :tools, :topic_id, :post_id, :max_pixels, :tool_choice

def initialize(
system_message_text = nil,
messages: [],
tools: [],
topic_id: nil,
post_id: nil,
max_pixels: nil
max_pixels: nil,
tool_choice: nil
)
raise ArgumentError, "messages must be an array" if !messages.is_a?(Array)
raise ArgumentError, "tools must be an array" if !tools.is_a?(Array)
Expand All @@ -37,6 +38,7 @@ def initialize(
@messages.each_cons(2) { |last_turn, new_turn| validate_turn(last_turn, new_turn) }

@tools = tools
@tool_choice = tool_choice
end

def push(type:, content:, id: nil, name: nil, upload_ids: nil)
Expand Down
Loading

0 comments on commit 545500b

Please sign in to comment.