9
9
import static org .opensearch .ml .common .CommonValue .VERSION_3_1_0 ;
10
10
11
11
import java .io .IOException ;
12
+ import java .util .Locale ;
12
13
import java .util .Map ;
13
14
import java .util .Set ;
14
15
import java .util .stream .Collectors ;
@@ -41,20 +42,60 @@ public class BaseModelConfig extends MLModelConfig {
41
42
it -> parse (it )
42
43
);
43
44
45
+ public static final String EMBEDDING_DIMENSION_FIELD = "embedding_dimension" ;
46
+ public static final String FRAMEWORK_TYPE_FIELD = "framework_type" ;
47
+ public static final String POOLING_MODE_FIELD = "pooling_mode" ;
48
+ public static final String NORMALIZE_RESULT_FIELD = "normalize_result" ;
49
+ public static final String MODEL_MAX_LENGTH_FIELD = "model_max_length" ;
50
+ public static final String QUERY_PREFIX = "query_prefix" ;
51
+ public static final String PASSAGE_PREFIX = "passage_prefix" ;
44
52
public static final String ADDITIONAL_CONFIG_FIELD = "additional_config" ;
53
+
54
+ protected Integer embeddingDimension ;
55
+ protected FrameworkType frameworkType ;
56
+ protected PoolingMode poolingMode ;
57
+ protected boolean normalizeResult ;
58
+ protected Integer modelMaxLength ;
59
+ protected String queryPrefix ;
60
+ protected String passagePrefix ;
45
61
protected Map <String , Object > additionalConfig ;
46
62
47
63
@ Builder (builderMethodName = "baseModelConfigBuilder" )
48
- public BaseModelConfig (String modelType , String allConfig , Map <String , Object > additionalConfig ) {
64
+ public BaseModelConfig (
65
+ String modelType ,
66
+ String allConfig ,
67
+ Map <String , Object > additionalConfig ,
68
+ Integer embeddingDimension ,
69
+ FrameworkType frameworkType ,
70
+ PoolingMode poolingMode ,
71
+ boolean normalizeResult ,
72
+ Integer modelMaxLength ,
73
+ String queryPrefix ,
74
+ String passagePrefix
75
+ ) {
49
76
super (modelType , allConfig );
50
77
this .additionalConfig = additionalConfig ;
78
+ this .embeddingDimension = embeddingDimension ;
79
+ this .frameworkType = frameworkType ;
80
+ this .poolingMode = poolingMode ;
81
+ this .normalizeResult = normalizeResult ;
82
+ this .modelMaxLength = modelMaxLength ;
83
+ this .queryPrefix = queryPrefix ;
84
+ this .passagePrefix = passagePrefix ;
51
85
validateNoDuplicateKeys (allConfig , additionalConfig );
52
86
}
53
87
54
88
public static BaseModelConfig parse (XContentParser parser ) throws IOException {
55
89
String modelType = null ;
56
90
String allConfig = null ;
57
91
Map <String , Object > additionalConfig = null ;
92
+ Integer embeddingDimension = null ;
93
+ FrameworkType frameworkType = null ;
94
+ PoolingMode poolingMode = null ;
95
+ boolean normalizeResult = false ;
96
+ Integer modelMaxLength = null ;
97
+ String queryPrefix = null ;
98
+ String passagePrefix = null ;
58
99
59
100
ensureExpectedToken (XContentParser .Token .START_OBJECT , parser .currentToken (), parser );
60
101
while (parser .nextToken () != XContentParser .Token .END_OBJECT ) {
@@ -71,12 +112,44 @@ public static BaseModelConfig parse(XContentParser parser) throws IOException {
71
112
case ADDITIONAL_CONFIG_FIELD :
72
113
additionalConfig = parser .map ();
73
114
break ;
115
+ case EMBEDDING_DIMENSION_FIELD :
116
+ embeddingDimension = parser .intValue ();
117
+ break ;
118
+ case FRAMEWORK_TYPE_FIELD :
119
+ frameworkType = FrameworkType .from (parser .text ().toUpperCase (Locale .ROOT ));
120
+ break ;
121
+ case POOLING_MODE_FIELD :
122
+ poolingMode = PoolingMode .from (parser .text ().toUpperCase (Locale .ROOT ));
123
+ break ;
124
+ case NORMALIZE_RESULT_FIELD :
125
+ normalizeResult = parser .booleanValue ();
126
+ break ;
127
+ case MODEL_MAX_LENGTH_FIELD :
128
+ modelMaxLength = parser .intValue ();
129
+ break ;
130
+ case QUERY_PREFIX :
131
+ queryPrefix = parser .text ();
132
+ break ;
133
+ case PASSAGE_PREFIX :
134
+ passagePrefix = parser .text ();
135
+ break ;
74
136
default :
75
137
parser .skipChildren ();
76
138
break ;
77
139
}
78
140
}
79
- return new BaseModelConfig (modelType , allConfig , additionalConfig );
141
+ return new BaseModelConfig (
142
+ modelType ,
143
+ allConfig ,
144
+ additionalConfig ,
145
+ embeddingDimension ,
146
+ frameworkType ,
147
+ poolingMode ,
148
+ normalizeResult ,
149
+ modelMaxLength ,
150
+ queryPrefix ,
151
+ passagePrefix
152
+ );
80
153
}
81
154
82
155
@ Override
@@ -89,6 +162,21 @@ public BaseModelConfig(StreamInput in) throws IOException {
89
162
if (in .getVersion ().onOrAfter (VERSION_3_1_0 )) {
90
163
this .additionalConfig = in .readMap ();
91
164
}
165
+ embeddingDimension = in .readOptionalInt ();
166
+ if (in .readBoolean ()) {
167
+ frameworkType = in .readEnum (FrameworkType .class );
168
+ } else {
169
+ frameworkType = null ;
170
+ }
171
+ if (in .readBoolean ()) {
172
+ poolingMode = in .readEnum (PoolingMode .class );
173
+ } else {
174
+ poolingMode = null ;
175
+ }
176
+ normalizeResult = in .readBoolean ();
177
+ modelMaxLength = in .readOptionalInt ();
178
+ queryPrefix = in .readOptionalString ();
179
+ passagePrefix = in .readOptionalString ();
92
180
}
93
181
94
182
@ Override
@@ -97,6 +185,23 @@ public void writeTo(StreamOutput out) throws IOException {
97
185
if (out .getVersion ().onOrAfter (VERSION_3_1_0 )) {
98
186
out .writeMap (additionalConfig );
99
187
}
188
+ out .writeOptionalInt (embeddingDimension );
189
+ if (frameworkType != null ) {
190
+ out .writeBoolean (true );
191
+ out .writeEnum (frameworkType );
192
+ } else {
193
+ out .writeBoolean (false );
194
+ }
195
+ if (poolingMode != null ) {
196
+ out .writeBoolean (true );
197
+ out .writeEnum (poolingMode );
198
+ } else {
199
+ out .writeBoolean (false );
200
+ }
201
+ out .writeBoolean (normalizeResult );
202
+ out .writeOptionalInt (modelMaxLength );
203
+ out .writeOptionalString (queryPrefix );
204
+ out .writeOptionalString (passagePrefix );
100
205
}
101
206
102
207
@ Override
@@ -111,10 +216,72 @@ public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params par
111
216
if (additionalConfig != null ) {
112
217
builder .field (ADDITIONAL_CONFIG_FIELD , additionalConfig );
113
218
}
219
+ if (embeddingDimension != null ) {
220
+ builder .field (EMBEDDING_DIMENSION_FIELD , embeddingDimension );
221
+ }
222
+ if (frameworkType != null ) {
223
+ builder .field (FRAMEWORK_TYPE_FIELD , frameworkType );
224
+ }
225
+ if (modelMaxLength != null ) {
226
+ builder .field (MODEL_MAX_LENGTH_FIELD , modelMaxLength );
227
+ }
228
+ if (poolingMode != null ) {
229
+ builder .field (POOLING_MODE_FIELD , poolingMode );
230
+ }
231
+ if (normalizeResult ) {
232
+ builder .field (NORMALIZE_RESULT_FIELD , normalizeResult );
233
+ }
234
+ if (queryPrefix != null ) {
235
+ builder .field (QUERY_PREFIX , queryPrefix );
236
+ }
237
+ if (passagePrefix != null ) {
238
+ builder .field (PASSAGE_PREFIX , passagePrefix );
239
+ }
114
240
builder .endObject ();
115
241
return builder ;
116
242
}
117
243
244
+ public enum PoolingMode {
245
+ MEAN ("mean" ),
246
+ MEAN_SQRT_LEN ("mean_sqrt_len" ),
247
+ MAX ("max" ),
248
+ WEIGHTED_MEAN ("weightedmean" ),
249
+ CLS ("cls" ),
250
+ LAST_TOKEN ("lasttoken" );
251
+
252
+ private String name ;
253
+
254
+ public String getName () {
255
+ return name ;
256
+ }
257
+
258
+ PoolingMode (String name ) {
259
+ this .name = name ;
260
+ }
261
+
262
+ public static PoolingMode from (String value ) {
263
+ try {
264
+ return PoolingMode .valueOf (value .toUpperCase (Locale .ROOT ));
265
+ } catch (Exception e ) {
266
+ throw new IllegalArgumentException ("Wrong pooling method" );
267
+ }
268
+ }
269
+ }
270
+
271
+ public enum FrameworkType {
272
+ HUGGINGFACE_TRANSFORMERS ,
273
+ SENTENCE_TRANSFORMERS ,
274
+ HUGGINGFACE_TRANSFORMERS_NEURON ;
275
+
276
+ public static FrameworkType from (String value ) {
277
+ try {
278
+ return FrameworkType .valueOf (value .toUpperCase (Locale .ROOT ));
279
+ } catch (Exception e ) {
280
+ throw new IllegalArgumentException ("Wrong framework type" );
281
+ }
282
+ }
283
+ }
284
+
118
285
protected void validateNoDuplicateKeys (String allConfig , Map <String , Object > additionalConfig ) {
119
286
if (allConfig == null || additionalConfig == null || additionalConfig .isEmpty ()) {
120
287
return ;
0 commit comments