We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent a4a2336 commit 6ac49b3Copy full SHA for 6ac49b3
train/src/ImageReward.py
@@ -146,7 +146,7 @@ def encode_pair(self, batch_data):
146
encoder_attention_mask = image_atts_better,
147
return_dict = True,
148
).last_hidden_state # [batch_size, seq_len, feature_dim]
149
- emb_better = emb_better[:, -1, :].float()
+ emb_better = emb_better[:, 0, :].float()
150
151
# encode worse emb
152
image_embeds_worse = self.blip.visual_encoder(img_worse)
@@ -157,7 +157,7 @@ def encode_pair(self, batch_data):
157
encoder_attention_mask = image_atts_worse,
158
159
).last_hidden_state
160
- emb_worse = emb_worse[:, -1, :].float()
+ emb_worse = emb_worse[:, 0, :].float()
161
162
# get batch data
163
batch_data = {
0 commit comments