8
8
import static org .opensearch .ml .common .settings .MLCommonsSettings .ML_COMMONS_AGENTIC_SEARCH_DISABLED_MESSAGE ;
9
9
import static org .opensearch .ml .common .utils .StringUtils .gson ;
10
10
import static org .opensearch .ml .engine .tools .QueryPlanningPromptTemplate .DEFAULT_QUERY ;
11
+ import static org .opensearch .ml .engine .tools .QueryPlanningPromptTemplate .DEFAULT_SEARCH_TEMPLATE ;
11
12
import static org .opensearch .ml .engine .tools .QueryPlanningPromptTemplate .DEFAULT_USER_PROMPT ;
13
+ import static org .opensearch .ml .engine .tools .QueryPlanningPromptTemplate .TEMPLATE_SELECTION_SYSTEM_PROMPT ;
14
+ import static org .opensearch .ml .engine .tools .QueryPlanningPromptTemplate .TEMPLATE_SELECTION_USER_PROMPT ;
12
15
16
+ import java .util .HashMap ;
13
17
import java .util .List ;
14
18
import java .util .Map ;
15
19
16
20
import org .apache .commons .text .StringSubstitutor ;
17
21
import org .opensearch .OpenSearchException ;
22
+ import org .opensearch .action .admin .cluster .storedscripts .GetStoredScriptRequest ;
18
23
import org .opensearch .core .action .ActionListener ;
19
24
import org .opensearch .ml .common .settings .MLFeatureEnabledSetting ;
20
25
import org .opensearch .ml .common .spi .tools .ToolAnnotation ;
@@ -43,10 +48,15 @@ public class QueryPlanningTool implements WithModelTool {
43
48
public static final String QUERY_FIELDS_FIELD = "query_fields" ;
44
49
private static final String GENERATION_TYPE_FIELD = "generation_type" ;
45
50
private static final String LLM_GENERATED_TYPE_FIELD = "llmGenerated" ;
51
+ private static final String USER_SEARCH_TEMPLATES_TYPE_FIELD = "user_templates" ;
52
+ private static final String SEARCH_TEMPLATES_FIELD = "search_templates" ;
53
+ public static final String TEMPLATE_FIELD = "template" ;
46
54
private static final String DEFAULT_SYSTEM_PROMPT =
47
55
"You are an OpenSearch Query DSL generation assistant, translating natural language questions to OpenSeach DSL Queries" ;
48
56
@ Getter
49
57
private final String generationType ;
58
+ @ Getter
59
+ private final String searchTemplates ;
50
60
@ Setter
51
61
@ Getter
52
62
private String name = TYPE ;
@@ -57,10 +67,17 @@ public class QueryPlanningTool implements WithModelTool {
57
67
@ Getter
58
68
@ Setter
59
69
private String description = DEFAULT_DESCRIPTION ;
70
+ private final Client client ;
71
+
72
+ public QueryPlanningTool (String generationType , MLModelTool queryGenerationTool , Client client ) {
73
+ this (generationType , queryGenerationTool , client , null );
74
+ }
60
75
61
- public QueryPlanningTool (String generationType , MLModelTool queryGenerationTool ) {
76
+ public QueryPlanningTool (String generationType , MLModelTool queryGenerationTool , Client client , String searchTemplates ) {
62
77
this .generationType = generationType ;
63
78
this .queryGenerationTool = queryGenerationTool ;
79
+ this .client = client ;
80
+ this .searchTemplates = searchTemplates ;
64
81
}
65
82
66
83
@ Override
@@ -70,6 +87,48 @@ public <T> void run(Map<String, String> originalParameters, ActionListener<T> li
70
87
listener .onFailure (new IllegalArgumentException ("Empty parameters for QueryPlanningTool: " + parameters ));
71
88
return ;
72
89
}
90
+
91
+ if (!generationType .equals (USER_SEARCH_TEMPLATES_TYPE_FIELD )) {
92
+ // Use default search template, skip template selection
93
+ parameters .put (TEMPLATE_FIELD , DEFAULT_SEARCH_TEMPLATE );
94
+ executeQueryPlanning (parameters , listener );
95
+ return ;
96
+ }
97
+
98
+ // Template Selection, replace user and system prompts
99
+ Map <String , String > templateSelectionParameters = new HashMap <>(parameters );
100
+ templateSelectionParameters .put (SYSTEM_PROMPT_FIELD , TEMPLATE_SELECTION_SYSTEM_PROMPT );
101
+ templateSelectionParameters .put (USER_PROMPT_FIELD , TEMPLATE_SELECTION_USER_PROMPT );
102
+ templateSelectionParameters .put (SEARCH_TEMPLATES_FIELD , searchTemplates );
103
+
104
+ ActionListener <T > templateSelectionListener = ActionListener .wrap (r -> {
105
+ try {
106
+ String templateId = (String ) r ;
107
+ if (templateId == null || templateId .isBlank () || templateId .equals ("null" )) {
108
+ // Default search template if LLM does not choose
109
+ parameters .put (TEMPLATE_FIELD , DEFAULT_SEARCH_TEMPLATE );
110
+ executeQueryPlanning (parameters , listener );
111
+ } else {
112
+ // Retrieve search template by ID
113
+ GetStoredScriptRequest getStoredScriptRequest = new GetStoredScriptRequest (templateId );
114
+ client .admin ().cluster ().getStoredScript (getStoredScriptRequest , ActionListener .wrap (getStoredScriptResponse -> {
115
+ parameters .put (TEMPLATE_FIELD , gson .toJson (getStoredScriptResponse .getSource ().getSource ()));
116
+ executeQueryPlanning (parameters , listener );
117
+ }, e -> { listener .onFailure (e ); }));
118
+ }
119
+ } catch (Exception e ) {
120
+ IllegalArgumentException parsingException = new IllegalArgumentException (
121
+ "Error processing search template: " + r + ". Try using response_filter in agent registration if needed." ,
122
+ e
123
+ );
124
+ listener .onFailure (parsingException );
125
+ }
126
+ }, listener ::onFailure );
127
+ queryGenerationTool .run (templateSelectionParameters , templateSelectionListener );
128
+ }
129
+
130
+ private <T > void executeQueryPlanning (Map <String , String > parameters , ActionListener <T > listener ) {
131
+ // Execute Query Planning, replace System and User prompt fields
73
132
if (!parameters .containsKey (SYSTEM_PROMPT_FIELD )) {
74
133
parameters .put (SYSTEM_PROMPT_FIELD , DEFAULT_SYSTEM_PROMPT );
75
134
}
@@ -154,16 +213,32 @@ public QueryPlanningTool create(Map<String, Object> map) {
154
213
MLModelTool queryGenerationTool = MLModelTool .Factory .getInstance ().create (map );
155
214
156
215
String type = (String ) map .get (GENERATION_TYPE_FIELD );
216
+
217
+ // defaulted to llmGenerated
157
218
if (type == null || type .isEmpty ()) {
158
219
type = LLM_GENERATED_TYPE_FIELD ;
159
220
}
160
221
161
- // TODO to add in SYSTEM_SEARCH_TEMPLATES_TYPE_FIELD, USER_SEARCH_TEMPLATES_TYPE_FIELD when searchTemplatesTool is
162
- // implemented.
163
- if (!LLM_GENERATED_TYPE_FIELD .equals (type )) {
164
- throw new IllegalArgumentException ("Invalid generation type: " + type + ". The current supported types are llmGenerated." );
222
+ // type validation
223
+ if (!(LLM_GENERATED_TYPE_FIELD .equals (type ) || USER_SEARCH_TEMPLATES_TYPE_FIELD .equals (type ))) {
224
+ throw new IllegalArgumentException (
225
+ "Invalid generation type: " + type + ". The current supported types are llmGenerated and user_templates."
226
+ );
165
227
}
166
- return new QueryPlanningTool (type , queryGenerationTool );
228
+
229
+ // Parse search templates if generation type is user_templates
230
+ String searchTemplates = null ;
231
+ if (USER_SEARCH_TEMPLATES_TYPE_FIELD .equals (type )) {
232
+ if (!map .containsKey (SEARCH_TEMPLATES_FIELD )) {
233
+ throw new IllegalArgumentException ("search_templates field is required when generation_type is 'user_templates'" );
234
+ } else {
235
+ // array is parsed as a json string
236
+ searchTemplates = gson .toJson ((String ) map .get (SEARCH_TEMPLATES_FIELD ));
237
+
238
+ }
239
+ }
240
+
241
+ return new QueryPlanningTool (type , queryGenerationTool , client , searchTemplates );
167
242
}
168
243
169
244
@ Override
0 commit comments