Skip to content

Commit c5229ca

Browse files
Fix update model config invalid error (#3994)
* Fix update model config
  Signed-off-by: Nathalie Jonathan <[email protected]>
* Add model config validation logic
  Signed-off-by: Nathalie Jonathan <[email protected]>
* Simplified validateModelConfig, add version compatibility tests
  Signed-off-by: Nathalie Jonathan <[email protected]>
---------
Signed-off-by: Nathalie Jonathan <[email protected]>
1 parent 93bc9a3 commit c5229ca

File tree

15 files changed

+419
-208
lines changed

15 files changed

+419
-208
lines changed

common/src/main/java/org/opensearch/ml/common/model/BaseModelConfig.java

Lines changed: 169 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import static org.opensearch.ml.common.CommonValue.VERSION_3_1_0;
1010

1111
import java.io.IOException;
12+
import java.util.Locale;
1213
import java.util.Map;
1314
import java.util.Set;
1415
import java.util.stream.Collectors;
@@ -41,20 +42,60 @@ public class BaseModelConfig extends MLModelConfig {
4142
it -> parse(it)
4243
);
4344

45+
// XContent field names: each one is matched as a case label in parse(...) and used as the key
// when toXContent(...) writes the corresponding value.
public static final String EMBEDDING_DIMENSION_FIELD = "embedding_dimension";
46+
public static final String FRAMEWORK_TYPE_FIELD = "framework_type";
47+
public static final String POOLING_MODE_FIELD = "pooling_mode";
48+
public static final String NORMALIZE_RESULT_FIELD = "normalize_result";
49+
public static final String MODEL_MAX_LENGTH_FIELD = "model_max_length";
50+
public static final String QUERY_PREFIX = "query_prefix";
51+
public static final String PASSAGE_PREFIX = "passage_prefix";
4452
public static final String ADDITIONAL_CONFIG_FIELD = "additional_config";
53+
54+
// Config values held by this object. parse(...) leaves the reference-typed ones null when the
// field is absent from the input and starts normalizeResult at false.
protected Integer embeddingDimension;
55+
protected FrameworkType frameworkType;
56+
protected PoolingMode poolingMode;
57+
protected boolean normalizeResult;
58+
protected Integer modelMaxLength;
59+
protected String queryPrefix;
60+
protected String passagePrefix;
4561
protected Map<String, Object> additionalConfig;
4662

4763
/**
 * Creates a config from explicit values. {@code modelType} and {@code allConfig} are forwarded
 * to the {@code MLModelConfig} constructor; every other argument is stored unchanged in the
 * field of the same name (no copying and no null checks are performed here).
 * Lombok's {@code @Builder} exposes this constructor through a builder whose static entry
 * point is named {@code baseModelConfigBuilder}.
 *
 * Note for readers of this diff view: the one-line signature marked {@code 48-} is the removed
 * pre-change version; the multi-line signature that follows is the current one.
 *
 * @param modelType          passed to the superclass constructor
 * @param allConfig          passed to the superclass constructor and to
 *                           {@code validateNoDuplicateKeys}
 * @param additionalConfig   stored as-is; also passed to {@code validateNoDuplicateKeys}
 * @param embeddingDimension stored as-is; may be null
 * @param frameworkType      stored as-is; may be null
 * @param poolingMode        stored as-is; may be null
 * @param normalizeResult    stored as-is
 * @param modelMaxLength     stored as-is; may be null
 * @param queryPrefix        stored as-is; may be null
 * @param passagePrefix      stored as-is; may be null
 */
@Builder(builderMethodName = "baseModelConfigBuilder")
48-
public BaseModelConfig(String modelType, String allConfig, Map<String, Object> additionalConfig) {
64+
public BaseModelConfig(
65+
String modelType,
66+
String allConfig,
67+
Map<String, Object> additionalConfig,
68+
Integer embeddingDimension,
69+
FrameworkType frameworkType,
70+
PoolingMode poolingMode,
71+
boolean normalizeResult,
72+
Integer modelMaxLength,
73+
String queryPrefix,
74+
String passagePrefix
75+
) {
4976
super(modelType, allConfig);
5077
this.additionalConfig = additionalConfig;
78+
this.embeddingDimension = embeddingDimension;
79+
this.frameworkType = frameworkType;
80+
this.poolingMode = poolingMode;
81+
this.normalizeResult = normalizeResult;
82+
this.modelMaxLength = modelMaxLength;
83+
this.queryPrefix = queryPrefix;
84+
this.passagePrefix = passagePrefix;
5185
// Runs last, after all fields are assigned. The helper returns immediately when either
// argument is null or additionalConfig is empty; presumably it rejects keys present in
// both inputs otherwise (body not fully visible here - confirm).
validateNoDuplicateKeys(allConfig, additionalConfig);
5286
}
5387

5488
public static BaseModelConfig parse(XContentParser parser) throws IOException {
5589
String modelType = null;
5690
String allConfig = null;
5791
Map<String, Object> additionalConfig = null;
92+
Integer embeddingDimension = null;
93+
FrameworkType frameworkType = null;
94+
PoolingMode poolingMode = null;
95+
boolean normalizeResult = false;
96+
Integer modelMaxLength = null;
97+
String queryPrefix = null;
98+
String passagePrefix = null;
5899

59100
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
60101
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
@@ -71,12 +112,44 @@ public static BaseModelConfig parse(XContentParser parser) throws IOException {
71112
case ADDITIONAL_CONFIG_FIELD:
72113
additionalConfig = parser.map();
73114
break;
115+
case EMBEDDING_DIMENSION_FIELD:
116+
embeddingDimension = parser.intValue();
117+
break;
118+
case FRAMEWORK_TYPE_FIELD:
119+
frameworkType = FrameworkType.from(parser.text().toUpperCase(Locale.ROOT));
120+
break;
121+
case POOLING_MODE_FIELD:
122+
poolingMode = PoolingMode.from(parser.text().toUpperCase(Locale.ROOT));
123+
break;
124+
case NORMALIZE_RESULT_FIELD:
125+
normalizeResult = parser.booleanValue();
126+
break;
127+
case MODEL_MAX_LENGTH_FIELD:
128+
modelMaxLength = parser.intValue();
129+
break;
130+
case QUERY_PREFIX:
131+
queryPrefix = parser.text();
132+
break;
133+
case PASSAGE_PREFIX:
134+
passagePrefix = parser.text();
135+
break;
74136
default:
75137
parser.skipChildren();
76138
break;
77139
}
78140
}
79-
return new BaseModelConfig(modelType, allConfig, additionalConfig);
141+
return new BaseModelConfig(
142+
modelType,
143+
allConfig,
144+
additionalConfig,
145+
embeddingDimension,
146+
frameworkType,
147+
poolingMode,
148+
normalizeResult,
149+
modelMaxLength,
150+
queryPrefix,
151+
passagePrefix
152+
);
80153
}
81154

82155
@Override
@@ -89,6 +162,21 @@ public BaseModelConfig(StreamInput in) throws IOException {
89162
if (in.getVersion().onOrAfter(VERSION_3_1_0)) {
90163
this.additionalConfig = in.readMap();
91164
}
165+
embeddingDimension = in.readOptionalInt();
166+
if (in.readBoolean()) {
167+
frameworkType = in.readEnum(FrameworkType.class);
168+
} else {
169+
frameworkType = null;
170+
}
171+
if (in.readBoolean()) {
172+
poolingMode = in.readEnum(PoolingMode.class);
173+
} else {
174+
poolingMode = null;
175+
}
176+
normalizeResult = in.readBoolean();
177+
modelMaxLength = in.readOptionalInt();
178+
queryPrefix = in.readOptionalString();
179+
passagePrefix = in.readOptionalString();
92180
}
93181

94182
@Override
@@ -97,6 +185,23 @@ public void writeTo(StreamOutput out) throws IOException {
97185
if (out.getVersion().onOrAfter(VERSION_3_1_0)) {
98186
out.writeMap(additionalConfig);
99187
}
188+
out.writeOptionalInt(embeddingDimension);
189+
if (frameworkType != null) {
190+
out.writeBoolean(true);
191+
out.writeEnum(frameworkType);
192+
} else {
193+
out.writeBoolean(false);
194+
}
195+
if (poolingMode != null) {
196+
out.writeBoolean(true);
197+
out.writeEnum(poolingMode);
198+
} else {
199+
out.writeBoolean(false);
200+
}
201+
out.writeBoolean(normalizeResult);
202+
out.writeOptionalInt(modelMaxLength);
203+
out.writeOptionalString(queryPrefix);
204+
out.writeOptionalString(passagePrefix);
100205
}
101206

102207
@Override
@@ -111,10 +216,72 @@ public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params par
111216
if (additionalConfig != null) {
112217
builder.field(ADDITIONAL_CONFIG_FIELD, additionalConfig);
113218
}
219+
if (embeddingDimension != null) {
220+
builder.field(EMBEDDING_DIMENSION_FIELD, embeddingDimension);
221+
}
222+
if (frameworkType != null) {
223+
builder.field(FRAMEWORK_TYPE_FIELD, frameworkType);
224+
}
225+
if (modelMaxLength != null) {
226+
builder.field(MODEL_MAX_LENGTH_FIELD, modelMaxLength);
227+
}
228+
if (poolingMode != null) {
229+
builder.field(POOLING_MODE_FIELD, poolingMode);
230+
}
231+
if (normalizeResult) {
232+
builder.field(NORMALIZE_RESULT_FIELD, normalizeResult);
233+
}
234+
if (queryPrefix != null) {
235+
builder.field(QUERY_PREFIX, queryPrefix);
236+
}
237+
if (passagePrefix != null) {
238+
builder.field(PASSAGE_PREFIX, passagePrefix);
239+
}
114240
builder.endObject();
115241
return builder;
116242
}
117243

244+
/**
 * Pooling options accepted for the {@code pooling_mode} setting.
 * Each constant carries a lowercase label returned by {@link #getName()}. That label is not
 * consulted by {@link #from(String)}, which matches on the constant identifier only.
 */
public enum PoolingMode {
    MEAN("mean"),
    MEAN_SQRT_LEN("mean_sqrt_len"),
    MAX("max"),
    WEIGHTED_MEAN("weightedmean"),
    CLS("cls"),
    LAST_TOKEN("lasttoken");

    // Assigned once by the constructor and never reassigned, hence final.
    private final String name;

    PoolingMode(String name) {
        this.name = name;
    }

    /**
     * @return the lowercase label attached to this constant, e.g. {@code "weightedmean"}
     */
    public String getName() {
        return name;
    }

    /**
     * Resolves a constant from its identifier, ignoring case (upper-cased with
     * {@link Locale#ROOT} before lookup).
     *
     * @param value constant identifier such as {@code "mean"} or {@code "LAST_TOKEN"}
     * @return the matching constant
     * @throws IllegalArgumentException if {@code value} is null or matches no constant; for a
     *         non-matching value the lookup failure is attached as the cause
     */
    public static PoolingMode from(String value) {
        // Explicit guard instead of relying on a caught NullPointerException.
        if (value == null) {
            throw new IllegalArgumentException("Wrong pooling method");
        }
        try {
            return PoolingMode.valueOf(value.toUpperCase(Locale.ROOT));
        } catch (IllegalArgumentException e) {
            // Same message as before; keep the original failure for diagnostics.
            throw new IllegalArgumentException("Wrong pooling method", e);
        }
    }
}
270+
271+
/**
 * Framework identifiers accepted for the {@code framework_type} setting.
 */
public enum FrameworkType {
    HUGGINGFACE_TRANSFORMERS,
    SENTENCE_TRANSFORMERS,
    HUGGINGFACE_TRANSFORMERS_NEURON;

    /**
     * Resolves a constant from its identifier, ignoring case (upper-cased with
     * {@link Locale#ROOT} before lookup).
     *
     * @param value constant identifier such as {@code "sentence_transformers"}
     * @return the matching constant
     * @throws IllegalArgumentException if {@code value} is null or matches no constant; for a
     *         non-matching value the lookup failure is attached as the cause
     */
    public static FrameworkType from(String value) {
        // Explicit guard instead of relying on a caught NullPointerException.
        if (value == null) {
            throw new IllegalArgumentException("Wrong framework type");
        }
        try {
            return FrameworkType.valueOf(value.toUpperCase(Locale.ROOT));
        } catch (IllegalArgumentException e) {
            // Same message as before; keep the original failure for diagnostics.
            throw new IllegalArgumentException("Wrong framework type", e);
        }
    }
}
284+
118285
protected void validateNoDuplicateKeys(String allConfig, Map<String, Object> additionalConfig) {
119286
if (allConfig == null || additionalConfig == null || additionalConfig.isEmpty()) {
120287
return;

0 commit comments

Comments
 (0)