Commit a624f0c

fix MatryoshkaLoss bug: sort sampled dimension indices to maintain descending dimension order (#3203)
When randomly sampling dimension indices for MatryoshkaLoss, ensure they are processed in ascending index order (that is, descending dimension). Previously, when a smaller dimension was processed first, the ForwardDecorator would cache the output tensor and then shrink it. Since the cached tensor and the shrunk output reference the same underlying memory, the cache would contain truncated embeddings, making them unusable for subsequent larger dimensions.
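The failure mode described above can be illustrated with a small, library-free sketch. `CachingForward` and `run` are hypothetical names invented for this illustration and are not the actual ForwardDecorator; plain Python lists stand in for tensors, and the shared memory is modeled as a shared dict whose entry is overwritten after caching.

```python
class CachingForward:
    """Hypothetical stand-in for a caching forward wrapper (not the library's code)."""

    def __init__(self, fn):
        self.fn = fn
        self.cache = None  # output of the single real forward pass

    def __call__(self, dim):
        if self.cache is None:
            self.cache = self.fn()
        output = self.cache  # same dict object as the cache, not a copy
        emb = output["sentence_embedding"]
        if dim > len(emb):
            raise ValueError(f"cannot grow embedding of size {len(emb)} to {dim}")
        # Overwriting this key also changes what the cache holds.
        output["sentence_embedding"] = emb[:dim]
        return output


def run(dims):
    """Process the given dimensions in order and return the embedding sizes seen."""
    forward = CachingForward(lambda: {"sentence_embedding": list(range(8))})
    return [len(forward(dim)["sentence_embedding"]) for dim in dims]
```

In this sketch, `run([8, 4, 2])` succeeds because each step only shrinks further, while `run([2, 4, 8])` raises a `ValueError` on the second step because the cached embedding was already cut down to 2 entries.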
1 parent e2a0098 commit a624f0c

1 file changed: +1 -0 lines changed

sentence_transformers/losses/MatryoshkaLoss.py

Lines changed: 1 addition & 0 deletions
@@ -218,6 +218,7 @@ def forward(self, sentence_features: Iterable[dict[str, Tensor]], labels: Tensor
         dim_indices = range(len(self.matryoshka_dims))
         if self.n_dims_per_step > 0 and self.n_dims_per_step < len(dim_indices):
             dim_indices = random.sample(dim_indices, self.n_dims_per_step)
+            dim_indices.sort()
 
         loss = 0.0
         for idx in dim_indices:
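
The effect of the added line can be tried in isolation. The following is a minimal sketch using only the standard library; the values of `matryoshka_dims` and `n_dims_per_step` are assumed example values, and the list is assumed to be stored largest first, as the commit message implies.

```python
import random

matryoshka_dims = [768, 512, 256, 128, 64]  # assumed example values, largest first
n_dims_per_step = 3  # assumed example value

dim_indices = range(len(matryoshka_dims))
if 0 < n_dims_per_step < len(dim_indices):
    # random.sample returns a new list in random order
    dim_indices = random.sample(dim_indices, n_dims_per_step)
    # restore ascending index order, i.e. descending dimension
    dim_indices.sort()

sampled_dims = [matryoshka_dims[idx] for idx in dim_indices]
```

Note that `sort()` is only valid inside the `if` branch: `random.sample` returns a list, whereas the initial `range` object has no `sort` method and is already in ascending order.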