Skip to content

Commit bada276

Browse files
committed
fix: serialize vectorizer sourceProperties to properties
1 parent 699c8a6 commit bada276

30 files changed

+1269
-243
lines changed

pom.xml

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -56,8 +56,8 @@
5656
<gson.version>2.13.2</gson.version>
5757
<httpclient.version>5.5.1</httpclient.version>
5858
<lang3.version>3.20.0</lang3.version>
59-
<junit.version>5.13.4</junit.version>
60-
<testcontainers.version>1.21.3</testcontainers.version>
59+
<junit.version>4.13.2</junit.version>
60+
<testcontainers.version>2.0.2</testcontainers.version>
6161
<assertj-core.version>3.27.6</assertj-core.version>
6262
<jparams.version>1.0.4</jparams.version>
6363
<mockito.version>5.20.0</mockito.version>
@@ -134,13 +134,13 @@
134134
</dependency>
135135
<dependency>
136136
<groupId>org.testcontainers</groupId>
137-
<artifactId>weaviate</artifactId>
137+
<artifactId>testcontainers-weaviate</artifactId>
138138
<version>${testcontainers.version}</version>
139139
<scope>test</scope>
140140
</dependency>
141141
<dependency>
142142
<groupId>org.testcontainers</groupId>
143-
<artifactId>minio</artifactId>
143+
<artifactId>testcontainers-minio</artifactId>
144144
<version>${testcontainers.version}</version>
145145
<scope>test</scope>
146146
</dependency>
@@ -150,6 +150,12 @@
150150
<version>${assertj-core.version}</version>
151151
<scope>test</scope>
152152
</dependency>
153+
<dependency>
154+
<groupId>junit</groupId>
155+
<artifactId>junit</artifactId>
156+
<version>${junit.version}</version>
157+
<scope>test</scope>
158+
</dependency>
153159
<dependency>
154160
<groupId>com.jparams</groupId>
155161
<artifactId>jparams-junit4</artifactId>
Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
package io.weaviate.integration;
2+
3+
import java.io.IOException;
4+
import java.util.Map;
5+
6+
import org.junit.ClassRule;
7+
import org.junit.Test;
8+
import org.junit.rules.TestRule;
9+
10+
import io.weaviate.ConcurrentTest;
11+
import io.weaviate.client6.v1.api.WeaviateClient;
12+
import io.weaviate.client6.v1.api.collections.Property;
13+
import io.weaviate.client6.v1.api.collections.VectorConfig;
14+
import io.weaviate.client6.v1.api.collections.WeaviateObject;
15+
import io.weaviate.containers.Container;
16+
import io.weaviate.containers.Model2Vec;
17+
import io.weaviate.containers.Weaviate;
18+
19+
import static org.assertj.core.api.Assertions.assertThat;
20+
21+
public class VectorizersITest extends ConcurrentTest {
22+
private static final Container.ContainerGroup compose = Container.compose(
23+
Weaviate.custom()
24+
.withModel2VecUrl(Model2Vec.URL)
25+
.build(),
26+
Container.MODEL2VEC);
27+
@ClassRule // Bind containers to the lifetime of the test
28+
public static final TestRule _rule = compose.asTestRule();
29+
private static final WeaviateClient client = compose.getClient();
30+
31+
@Test
32+
public void testVectorizerModel2VecPropeties() throws IOException {
33+
var collectionName = ns("Model2Vec2NamedVectors");
34+
client.collections.create(collectionName,
35+
col -> col
36+
.properties(Property.text("name"), Property.text("author"))
37+
.vectorConfig(
38+
VectorConfig.text2vecModel2Vec("name", v -> v.sourceProperties("name")),
39+
VectorConfig.text2vecModel2Vec("author", v -> v.sourceProperties("author"))
40+
)
41+
);
42+
43+
var model2vec = client.collections.use(collectionName);
44+
assertThat(model2vec).isNotNull();
45+
46+
String uuid1 = "00000000-0000-0000-0000-000000000001";
47+
WeaviateObject<Map<String, Object>> obj1 = WeaviateObject.of(o ->
48+
o.properties(Map.of("name", "Dune", "author", "Frank Herbert")).uuid(uuid1)
49+
);
50+
String uuid2 = "00000000-0000-0000-0000-000000000002";
51+
WeaviateObject<Map<String, Object>> obj2 = WeaviateObject.of(o ->
52+
o.properties(Map.of("name", "same content", "author", "same content")).uuid(uuid2)
53+
);
54+
55+
var resp = model2vec.data.insertMany(obj1, obj2);
56+
assertThat(resp).isNotNull().satisfies(s -> {
57+
assertThat(s.errors()).isEmpty();
58+
});
59+
60+
var o1 = model2vec.query.fetchObjectById(uuid1, FetchObjectById.Builder::includeVector);
61+
// Assert that for object1 we have generated 2 different vectors
62+
assertThat(o1).get()
63+
.extracting(WeaviateObject::vectors)
64+
.satisfies(v -> {
65+
assertThat(v.getSingle("name")).isNotEmpty();
66+
assertThat(v.getSingle("author")).isNotEmpty();
67+
assertThat(v.getSingle("name")).isNotEqualTo(v.getSingle("author"));
68+
});
69+
70+
var o2 = model2vec.query.fetchObjectById(uuid2, FetchObjectById.Builder::includeVector);
71+
// Assert that for object2 we have generated the same vectors
72+
assertThat(o2).get()
73+
.extracting(WeaviateObject::vectors)
74+
.satisfies(v -> {
75+
assertThat(v.getSingle("name")).isNotEmpty();
76+
assertThat(v.getSingle("author")).isNotEmpty();
77+
assertThat(v.getSingle("name")).isEqualTo(v.getSingle("author"));
78+
});
79+
}
80+
}

src/main/java/io/weaviate/client6/v1/api/collections/VectorConfig.java

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -353,8 +353,8 @@ public static Map.Entry<String, VectorConfig> multi2vecCohere(String vectorName,
353353
*
354354
* @param location Geographic region the Google Cloud model runs in.
355355
*/
356-
public static Map.Entry<String, VectorConfig> multi2vecGoogle(String location) {
357-
return multi2vecGoogle(VectorIndex.DEFAULT_VECTOR_NAME, location);
356+
public static Map.Entry<String, VectorConfig> multi2vecGoogle(String projectId, String location) {
357+
return multi2vecGoogle(VectorIndex.DEFAULT_VECTOR_NAME, projectId, location);
358358
}
359359

360360
/**
@@ -364,9 +364,10 @@ public static Map.Entry<String, VectorConfig> multi2vecGoogle(String location) {
364364
* @param fn Lambda expression for optional parameters.
365365
*/
366366
public static Map.Entry<String, VectorConfig> multi2vecGoogle(
367+
String projectId,
367368
String location,
368369
Function<Multi2VecGoogleVectorizer.Builder, ObjectBuilder<Multi2VecGoogleVectorizer>> fn) {
369-
return multi2vecGoogle(VectorIndex.DEFAULT_VECTOR_NAME, location, fn);
370+
return multi2vecGoogle(VectorIndex.DEFAULT_VECTOR_NAME, projectId, location, fn);
370371
}
371372

372373
/**
@@ -375,8 +376,8 @@ public static Map.Entry<String, VectorConfig> multi2vecGoogle(
375376
* @param vectorName Vector name.
376377
* @param location Geographic region the Google Cloud model runs in.
377378
*/
378-
public static Map.Entry<String, VectorConfig> multi2vecGoogle(String vectorName, String location) {
379-
return Map.entry(vectorName, Multi2VecGoogleVectorizer.of(location));
379+
public static Map.Entry<String, VectorConfig> multi2vecGoogle(String vectorName, String projectId, String location) {
380+
return Map.entry(vectorName, Multi2VecGoogleVectorizer.of(projectId, location));
380381
}
381382

382383
/**
@@ -387,9 +388,9 @@ public static Map.Entry<String, VectorConfig> multi2vecGoogle(String vectorName,
387388
* @param fn Lambda expression for optional parameters.
388389
*/
389390
public static Map.Entry<String, VectorConfig> multi2vecGoogle(String vectorName,
390-
String location,
391+
String projectId, String location,
391392
Function<Multi2VecGoogleVectorizer.Builder, ObjectBuilder<Multi2VecGoogleVectorizer>> fn) {
392-
return Map.entry(vectorName, Multi2VecGoogleVectorizer.of(location, fn));
393+
return Map.entry(vectorName, Multi2VecGoogleVectorizer.of(projectId, location, fn));
393394
}
394395

395396
/** Create a vector index with a {@code multi2vec-jinaai} vectorizer. */

src/main/java/io/weaviate/client6/v1/api/collections/vectorizers/Img2VecNeuralVectorizer.java

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
package io.weaviate.client6.v1.api.collections.vectorizers;
22

3-
import java.util.ArrayList;
43
import java.util.Arrays;
54
import java.util.List;
65
import java.util.function.Function;
@@ -46,7 +45,7 @@ public static class Builder implements ObjectBuilder<Img2VecNeuralVectorizer> {
4645
private VectorIndex vectorIndex = VectorIndex.DEFAULT_VECTOR_INDEX;
4746
private Quantization quantization;
4847

49-
private List<String> imageFields = new ArrayList<>();
48+
private List<String> imageFields;
5049

5150
/** Add BLOB properties to include in the embedding. */
5251
public Builder imageFields(List<String> fields) {

src/main/java/io/weaviate/client6/v1/api/collections/vectorizers/Multi2MultiVecJinaAiVectorizer.java

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
package io.weaviate.client6.v1.api.collections.vectorizers;
22

3-
import java.util.ArrayList;
43
import java.util.Arrays;
54
import java.util.List;
65
import java.util.function.Function;
@@ -52,12 +51,12 @@ public static class Builder implements ObjectBuilder<Multi2MultiVecJinaAiVectori
5251
private VectorIndex vectorIndex = VectorIndex.DEFAULT_VECTOR_INDEX;
5352
private Quantization quantization;
5453

55-
private final List<String> imageFields = new ArrayList<>();
56-
private final List<String> textFields = new ArrayList<>();
54+
private List<String> imageFields;
55+
private List<String> textFields;
5756

5857
/** Add BLOB properties to include in the embedding. */
5958
public Builder imageFields(List<String> fields) {
60-
imageFields.addAll(fields);
59+
imageFields = fields;
6160
return this;
6261
}
6362

@@ -68,7 +67,7 @@ public Builder imageFields(String... fields) {
6867

6968
/** Add TEXT properties to include in the embedding. */
7069
public Builder textFields(List<String> fields) {
71-
textFields.addAll(fields);
70+
textFields = fields;
7271
return this;
7372
}
7473

src/main/java/io/weaviate/client6/v1/api/collections/vectorizers/Multi2VecAwsVectorizer.java

Lines changed: 33 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,8 @@
11
package io.weaviate.client6.v1.api.collections.vectorizers;
22

3+
import java.util.ArrayList;
34
import java.util.Arrays;
4-
import java.util.LinkedHashMap;
55
import java.util.List;
6-
import java.util.Map;
76
import java.util.function.Function;
87

98
import com.google.gson.annotations.SerializedName;
@@ -64,11 +63,9 @@ public Multi2VecAwsVectorizer(Builder builder) {
6463
builder.model,
6564
builder.dimensions,
6665
builder.region,
67-
builder.imageFields.keySet().stream().toList(),
68-
builder.textFields.keySet().stream().toList(),
69-
new Weights(
70-
builder.imageFields.values().stream().toList(),
71-
builder.textFields.values().stream().toList()),
66+
builder.imageFields,
67+
builder.textFields,
68+
builder.getWeights(),
7269
builder.vectorIndex,
7370
builder.quantization);
7471
}
@@ -77,8 +74,10 @@ public static class Builder implements ObjectBuilder<Multi2VecAwsVectorizer> {
7774
private VectorIndex vectorIndex = VectorIndex.DEFAULT_VECTOR_INDEX;
7875
private Quantization quantization;
7976

80-
private Map<String, Float> imageFields = new LinkedHashMap<>();
81-
private Map<String, Float> textFields = new LinkedHashMap<>();
77+
private List<String> imageFields;
78+
private List<Float> imageWeights;
79+
private List<String> textFields;
80+
private List<Float> textWeights;
8281

8382
private String model;
8483
private Integer dimensions;
@@ -101,7 +100,7 @@ public Builder region(String region) {
101100

102101
/** Add BLOB properties to include in the embedding. */
103102
public Builder imageFields(List<String> fields) {
104-
fields.forEach(field -> imageFields.put(field, null));
103+
this.imageFields = fields;
105104
return this;
106105
}
107106

@@ -117,13 +116,20 @@ public Builder imageFields(String... fields) {
117116
* @param weight Custom weight between 0.0 and 1.0.
118117
*/
119118
public Builder imageField(String field, float weight) {
120-
imageFields.put(field, weight);
119+
if (this.imageFields == null) {
120+
this.imageFields = new ArrayList<>();
121+
}
122+
if (this.imageWeights == null) {
123+
this.imageWeights = new ArrayList<>();
124+
}
125+
this.imageFields.add(field);
126+
this.imageWeights.add(weight);
121127
return this;
122128
}
123129

124130
/** Add TEXT properties to include in the embedding. */
125131
public Builder textFields(List<String> fields) {
126-
fields.forEach(field -> textFields.put(field, null));
132+
this.textFields = fields;
127133
return this;
128134
}
129135

@@ -139,10 +145,24 @@ public Builder textFields(String... fields) {
139145
* @param weight Custom weight between 0.0 and 1.0.
140146
*/
141147
public Builder textField(String field, float weight) {
142-
textFields.put(field, weight);
148+
if (this.textFields == null) {
149+
this.textFields = new ArrayList<>();
150+
}
151+
if (this.textWeights == null) {
152+
this.textWeights = new ArrayList<>();
153+
}
154+
this.textFields.add(field);
155+
this.textWeights.add(weight);
143156
return this;
144157
}
145158

159+
protected Weights getWeights() {
160+
if (this.textWeights != null || this.imageWeights != null) {
161+
return new Weights(this.imageWeights, this.textWeights);
162+
}
163+
return null;
164+
}
165+
146166
/**
147167
* Override default vector index configuration.
148168
*

0 commit comments

Comments
 (0)