Skip to content

Commit e3baea9

Browse files
authored
Search Template Support for QueryPlanningTool (#4154)
Signed-off-by: Joshua Palis <[email protected]>
1 parent dc8403f commit e3baea9

File tree

3 files changed

+381
-18
lines changed

3 files changed

+381
-18
lines changed

ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/QueryPlanningPromptTemplate.java

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,7 @@ public class QueryPlanningPromptTemplate {
155155
+ "Query Fields: ${parameters.query_fields:-}\n\n"
156156
+ "==== OUTPUT ====\n"
157157
+ "GIVE THE OUTPUT PART ONLY IN YOUR RESPONSE (a single JSON object)\n"
158+
+ "Use this template provided by the user as reference to generate the query: ${parameters.template}\n\n"
158159
+ "Output:";
159160

160161
public static final String DEFAULT_USER_PROMPT = PROMPT_PREFIX
@@ -164,4 +165,105 @@ public class QueryPlanningPromptTemplate {
164165
+ EXAMPLES
165166
+ "\n\n"
166167
+ PROMPT_SUFFIX;
168+
169+
// Template selection prompt
170+
171+
public static final String TEMPLATE_SELECTION_SYSTEM_PROMPT = "==== PURPOSE ====\n"
172+
+ "You are an OpenSearch Search Template selector. Given a natural language question, a list of search template IDs and search template descriptions, choose the search template ID which is most related to the given question.\n\n";
173+
174+
public static final String TEMPLATE_SELECTION_GOAL = "Given:\n"
175+
+ "1) A natural-language question from the user.\n"
176+
+ "2) A catalog of OpenSearch templates, each with:\n"
177+
+ " - id (string, case-sensitive)\n"
178+
+ " - description (1–3 sentences)\n"
179+
+ "Return: the SINGLE id of the best-matching template.";
180+
181+
public static final String TEMPLATE_SELECTION_OUTPUT_RULES = "- Output ONLY the template id.\n"
182+
+ "- No quotes, no backticks, no punctuation, no prefix/suffix, no extra words.\n"
183+
+ "- No spaces or newlines before/after. Output must be exactly one of the provided ids.\n"
184+
+ "- Do not ask questions or explain.\n"
185+
+ "- Think internally; do NOT reveal your reasoning.";
186+
187+
public static final String TEMPLATE_SELECTION_CRITERIA = "(apply in order)\n"
188+
+ "1) INTENT MATCH: Identify the user’s primary intent (e.g., product search/browse, analytical reporting, trend/sales analysis, inventory, support lookup). Prefer templates whose descriptions explicitly support that intent.\n"
189+
+ "2) SIGNAL ALIGNMENT: Count strong lexical/semantic matches between the question and each template’s description/placeholders.\n"
190+
+ " - Attribute filters (brand, category, size, color, price, rating, etc.) → favor product/item search templates.\n"
191+
+ " - Metrics (sales value, revenue, units sold, conversion, time windows) → favor analytics/aggregation templates.\n"
192+
+ " - Temporal phrases (“last week”, “by month”, “trending”, “top sellers”) → favor templates with date/time and aggregations.\n"
193+
+ " - Opinion/quality words (“highly rated”, “best”, “top reviewed”) → favor templates with rating/review placeholders.\n"
194+
+ "3) SPECIFICITY: If multiple templates match, prefer the one whose description/placeholders are the most specific to the question’s entities and constraints.\n"
195+
+ "4) TIE-BREAK:\n"
196+
+ " - Prefer templates intended for the user’s domain (e.g., “products” vs “sales analytics”).\n"
197+
+ " - Prefer general-purpose search over analytics if the question asks to “find/search/browse” items; prefer analytics if it asks for “most sold/revenue/total/average”.";
198+
199+
public static final String TEMPLATE_SELECTION_VALIDATION =
200+
"- Your output MUST be exactly one of the provided template ids (regex: ^[A-Za-z0-9_-]+$).\n"
201+
+ "- If no perfect match exists, pick the closest by the criteria above. Never output “none” or invent an id.";
202+
203+
public static final String TEMPLATE_SELECTION_INPUTS = "question: ${parameters.query_text}\n"
204+
+ "templates: ${parameters.search_templates}";
205+
206+
public static final String TEMPLATE_SELECTION_EXAMPLES = "Example A: \n"
207+
+ "question: 'what shoes are highly rated'\n"
208+
+ "templates:\n"
209+
+ "[\n"
210+
+ "{'id':'product-search-template','description':'Searches products in an e-commerce store.'},\n"
211+
+ "{'id':'sales-value-analysis-template','description':'Aggregates sales value for top-selling products.'}\n"
212+
+ "]\n"
213+
+ "Example output : 'product-search-template'";
214+
215+
public static final String TEMPLATE_SELECTION_USER_PROMPT = "==== GOAL ====\n"
216+
+ TEMPLATE_SELECTION_GOAL
217+
+ "\n"
218+
+ "==== OUTPUT RULES ====\n"
219+
+ TEMPLATE_SELECTION_OUTPUT_RULES
220+
+ "\n"
221+
+ "==== SELECTION CRITERIA ====\n"
222+
+ TEMPLATE_SELECTION_CRITERIA
223+
+ "\n"
224+
+ "==== VALIDATION ====\n"
225+
+ TEMPLATE_SELECTION_VALIDATION
226+
+ "\n"
227+
+ "==== EXAMPLES ====\n"
228+
+ TEMPLATE_SELECTION_EXAMPLES
229+
+ "\n"
230+
+ "==== INPUTS ====\n"
231+
+ TEMPLATE_SELECTION_INPUTS;
232+
233+
public static final String DEFAULT_SEARCH_TEMPLATE = "{"
234+
+ "\"from\": {{from}}{{^from}}0{{/from}},"
235+
+ "\"size\": {{size}}{{^size}}10{{/size}},"
236+
+ "\n"
237+
+ "\"query\": {"
238+
+ " \"bool\": {"
239+
+ " \"should\": ["
240+
+ " {"
241+
+ " \"multi_match\": {"
242+
+ " \"query\": \"{{lex_query}}\","
243+
+ " \"fields\": {{#lex_fields}}{{{lex_fields}}}{{/lex_fields}}{{^lex_fields}}[\"*^1.0\"]{{/lex_fields}},"
244+
+ " \"type\": \"{{#lex_type}}{{lex_type}}{{/lex_type}}{{^lex_type}}best_fields{{/lex_type}}\","
245+
+ " \"operator\": \"{{#lex_operator}}{{lex_operator}}{{/lex_operator}}{{^lex_operator}}or{{/lex_operator}}\","
246+
+ " \"boost\": {{#lex_boost}}{{lex_boost}}{{/lex_boost}}{{^lex_boost}}1.0{{/lex_boost}}"
247+
+ " }"
248+
+ " }{{#sem_enabled}},"
249+
+ " {"
250+
+ " \"neural\": {"
251+
+ " \"{{sem_field}}\": {"
252+
+ " \"query_text\": \"{{sem_query_text}}\","
253+
+ " \"model_id\": \"{{sem_model_id}}\","
254+
+ " \"k\": {{#sem_k}}{{sem_k}}{{/sem_k}}{{^sem_k}}150{{/sem_k}},"
255+
+ " \"boost\": {{#sem_boost}}{{sem_boost}}{{/sem_boost}}{{^sem_boost}}1.5{{/sem_boost}}"
256+
+ " }"
257+
+ " }"
258+
+ " }{{/sem_enabled}}"
259+
+ " ],"
260+
+ " \"filter\": {{#filters}}{{{filters}}}{{/filters}}{{^filters}}[]{{/filters}},"
261+
+ " \"minimum_should_match\": 1"
262+
+ " }"
263+
+ "},"
264+
+ "\n"
265+
+ "\"sort\": {{#sort}}{{{sort}}}{{/sort}}{{^sort}}[{ \"_score\": { \"order\": \"desc\" } }]{{/sort}},"
266+
+ "\n"
267+
+ "\"track_total_hits\": {{#track_total_hits}}{{track_total_hits}}{{/track_total_hits}}{{^track_total_hits}}false{{/track_total_hits}}"
268+
+ "}";
167269
}

ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/QueryPlanningTool.java

Lines changed: 81 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,18 @@
88
import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_AGENTIC_SEARCH_DISABLED_MESSAGE;
99
import static org.opensearch.ml.common.utils.StringUtils.gson;
1010
import static org.opensearch.ml.engine.tools.QueryPlanningPromptTemplate.DEFAULT_QUERY;
11+
import static org.opensearch.ml.engine.tools.QueryPlanningPromptTemplate.DEFAULT_SEARCH_TEMPLATE;
1112
import static org.opensearch.ml.engine.tools.QueryPlanningPromptTemplate.DEFAULT_USER_PROMPT;
13+
import static org.opensearch.ml.engine.tools.QueryPlanningPromptTemplate.TEMPLATE_SELECTION_SYSTEM_PROMPT;
14+
import static org.opensearch.ml.engine.tools.QueryPlanningPromptTemplate.TEMPLATE_SELECTION_USER_PROMPT;
1215

16+
import java.util.HashMap;
1317
import java.util.List;
1418
import java.util.Map;
1519

1620
import org.apache.commons.text.StringSubstitutor;
1721
import org.opensearch.OpenSearchException;
22+
import org.opensearch.action.admin.cluster.storedscripts.GetStoredScriptRequest;
1823
import org.opensearch.core.action.ActionListener;
1924
import org.opensearch.ml.common.settings.MLFeatureEnabledSetting;
2025
import org.opensearch.ml.common.spi.tools.ToolAnnotation;
@@ -43,10 +48,15 @@ public class QueryPlanningTool implements WithModelTool {
4348
public static final String QUERY_FIELDS_FIELD = "query_fields";
4449
private static final String GENERATION_TYPE_FIELD = "generation_type";
4550
private static final String LLM_GENERATED_TYPE_FIELD = "llmGenerated";
51+
private static final String USER_SEARCH_TEMPLATES_TYPE_FIELD = "user_templates";
52+
private static final String SEARCH_TEMPLATES_FIELD = "search_templates";
53+
public static final String TEMPLATE_FIELD = "template";
4654
private static final String DEFAULT_SYSTEM_PROMPT =
4755
"You are an OpenSearch Query DSL generation assistant, translating natural language questions to OpenSeach DSL Queries";
4856
@Getter
4957
private final String generationType;
58+
@Getter
59+
private final String searchTemplates;
5060
@Setter
5161
@Getter
5262
private String name = TYPE;
@@ -57,10 +67,17 @@ public class QueryPlanningTool implements WithModelTool {
5767
@Getter
5868
@Setter
5969
private String description = DEFAULT_DESCRIPTION;
70+
private final Client client;
71+
72+
public QueryPlanningTool(String generationType, MLModelTool queryGenerationTool, Client client) {
73+
this(generationType, queryGenerationTool, client, null);
74+
}
6075

61-
public QueryPlanningTool(String generationType, MLModelTool queryGenerationTool) {
76+
public QueryPlanningTool(String generationType, MLModelTool queryGenerationTool, Client client, String searchTemplates) {
6277
this.generationType = generationType;
6378
this.queryGenerationTool = queryGenerationTool;
79+
this.client = client;
80+
this.searchTemplates = searchTemplates;
6481
}
6582

6683
@Override
@@ -70,6 +87,48 @@ public <T> void run(Map<String, String> originalParameters, ActionListener<T> li
7087
listener.onFailure(new IllegalArgumentException("Empty parameters for QueryPlanningTool: " + parameters));
7188
return;
7289
}
90+
91+
if (!generationType.equals(USER_SEARCH_TEMPLATES_TYPE_FIELD)) {
92+
// Use default search template, skip template selection
93+
parameters.put(TEMPLATE_FIELD, DEFAULT_SEARCH_TEMPLATE);
94+
executeQueryPlanning(parameters, listener);
95+
return;
96+
}
97+
98+
// Template Selection, replace user and system prompts
99+
Map<String, String> templateSelectionParameters = new HashMap<>(parameters);
100+
templateSelectionParameters.put(SYSTEM_PROMPT_FIELD, TEMPLATE_SELECTION_SYSTEM_PROMPT);
101+
templateSelectionParameters.put(USER_PROMPT_FIELD, TEMPLATE_SELECTION_USER_PROMPT);
102+
templateSelectionParameters.put(SEARCH_TEMPLATES_FIELD, searchTemplates);
103+
104+
ActionListener<T> templateSelectionListener = ActionListener.wrap(r -> {
105+
try {
106+
String templateId = (String) r;
107+
if (templateId == null || templateId.isBlank() || templateId.equals("null")) {
108+
// Default search template if LLM does not choose
109+
parameters.put(TEMPLATE_FIELD, DEFAULT_SEARCH_TEMPLATE);
110+
executeQueryPlanning(parameters, listener);
111+
} else {
112+
// Retrieve search template by ID
113+
GetStoredScriptRequest getStoredScriptRequest = new GetStoredScriptRequest(templateId);
114+
client.admin().cluster().getStoredScript(getStoredScriptRequest, ActionListener.wrap(getStoredScriptResponse -> {
115+
parameters.put(TEMPLATE_FIELD, gson.toJson(getStoredScriptResponse.getSource().getSource()));
116+
executeQueryPlanning(parameters, listener);
117+
}, e -> { listener.onFailure(e); }));
118+
}
119+
} catch (Exception e) {
120+
IllegalArgumentException parsingException = new IllegalArgumentException(
121+
"Error processing search template: " + r + ". Try using response_filter in agent registration if needed.",
122+
e
123+
);
124+
listener.onFailure(parsingException);
125+
}
126+
}, listener::onFailure);
127+
queryGenerationTool.run(templateSelectionParameters, templateSelectionListener);
128+
}
129+
130+
private <T> void executeQueryPlanning(Map<String, String> parameters, ActionListener<T> listener) {
131+
// Execute Query Planning, replace System and User prompt fields
73132
if (!parameters.containsKey(SYSTEM_PROMPT_FIELD)) {
74133
parameters.put(SYSTEM_PROMPT_FIELD, DEFAULT_SYSTEM_PROMPT);
75134
}
@@ -154,16 +213,32 @@ public QueryPlanningTool create(Map<String, Object> map) {
154213
MLModelTool queryGenerationTool = MLModelTool.Factory.getInstance().create(map);
155214

156215
String type = (String) map.get(GENERATION_TYPE_FIELD);
216+
217+
// defaulted to llmGenerated
157218
if (type == null || type.isEmpty()) {
158219
type = LLM_GENERATED_TYPE_FIELD;
159220
}
160221

161-
// TODO to add in SYSTEM_SEARCH_TEMPLATES_TYPE_FIELD, USER_SEARCH_TEMPLATES_TYPE_FIELD when searchTemplatesTool is
162-
// implemented.
163-
if (!LLM_GENERATED_TYPE_FIELD.equals(type)) {
164-
throw new IllegalArgumentException("Invalid generation type: " + type + ". The current supported types are llmGenerated.");
222+
// type validation
223+
if (!(LLM_GENERATED_TYPE_FIELD.equals(type) || USER_SEARCH_TEMPLATES_TYPE_FIELD.equals(type))) {
224+
throw new IllegalArgumentException(
225+
"Invalid generation type: " + type + ". The current supported types are llmGenerated and user_templates."
226+
);
165227
}
166-
return new QueryPlanningTool(type, queryGenerationTool);
228+
229+
// Parse search templates if generation type is user_templates
230+
String searchTemplates = null;
231+
if (USER_SEARCH_TEMPLATES_TYPE_FIELD.equals(type)) {
232+
if (!map.containsKey(SEARCH_TEMPLATES_FIELD)) {
233+
throw new IllegalArgumentException("search_templates field is required when generation_type is 'user_templates'");
234+
} else {
235+
// array is parsed as a json string
236+
searchTemplates = gson.toJson((String) map.get(SEARCH_TEMPLATES_FIELD));
237+
238+
}
239+
}
240+
241+
return new QueryPlanningTool(type, queryGenerationTool, client, searchTemplates);
167242
}
168243

169244
@Override

0 commit comments

Comments
 (0)