Skip to content

Commit 6ac49b3

Browse files
committed
fix typo of train script
1 parent a4a2336 commit 6ac49b3

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

train/src/ImageReward.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,7 @@ def encode_pair(self, batch_data):
146146
encoder_attention_mask = image_atts_better,
147147
return_dict = True,
148148
).last_hidden_state # [batch_size, seq_len, feature_dim]
149-
emb_better = emb_better[:, -1, :].float()
149+
emb_better = emb_better[:, 0, :].float()
150150

151151
# encode worse emb
152152
image_embeds_worse = self.blip.visual_encoder(img_worse)
@@ -157,7 +157,7 @@ def encode_pair(self, batch_data):
157157
encoder_attention_mask = image_atts_worse,
158158
return_dict = True,
159159
).last_hidden_state
160-
emb_worse = emb_worse[:, -1, :].float()
160+
emb_worse = emb_worse[:, 0, :].float()
161161

162162
# get batch data
163163
batch_data = {

0 commit comments

Comments
 (0)