```python
def __call__(self, q_reps, p_reps):
    if self.negatives_cross_device:
        # This gathers both negatives and positives.
        # It could likely be optimized by only gathering negatives.
        q_reps = self._dist_gather_tensor(q_reps)
        p_reps = self._dist_gather_tensor(p_reps)
    scores = self.compute_similarity(q_reps, p_reps) / self.temperature
    scores = scores.view(q_reps.size(0), -1)
```
My understanding is that each query corresponds to a group of passages. Dividing the number of passages by the number of queries gives the group size, and `arange` multiplied by that group size gives the index of each query's positive passage. Is that right?
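The mechanism described above can be sketched in plain Python (no PyTorch, so it stays self-contained). This is a minimal illustration, not the repository's code. It assumes passages are laid out in contiguous, equal-size groups with each query's positive first in its group, and the names `contrastive_targets` and `info_nce` are hypothetical.

```python
import math


def contrastive_targets(num_queries, num_passages):
    """Index of each query's positive in the flattened passage list.

    Assumption (hypothetical layout): passages come in contiguous groups of
    equal size, positive first, e.g. [p0+, p0-, p0-, p1+, p1-, p1-, ...].
    """
    group_size = num_passages // num_queries  # passages per query
    # Equivalent to torch.arange(num_queries) * group_size
    return [i * group_size for i in range(num_queries)]


def info_nce(scores, temperature=1.0):
    """Mean cross-entropy over rows of a (num_queries x num_passages) score matrix.

    Each row's target is its query's positive; every other passage in the row,
    including other queries' positives, acts as a negative.
    """
    num_queries = len(scores)
    num_passages = len(scores[0])
    targets = contrastive_targets(num_queries, num_passages)
    total = 0.0
    for row, target in zip(scores, targets):
        logits = [s / temperature for s in row]
        m = max(logits)  # subtract the max for numerical stability
        log_sum_exp = m + math.log(sum(math.exp(x - m) for x in logits))
        total += log_sum_exp - logits[target]  # -log softmax at the positive
    return total / num_queries
```

For example, with 2 queries and 6 passages (1 positive + 2 negatives each), `contrastive_targets(2, 6)` gives `[0, 3]`. Because every passage in a row other than the target is treated as a negative, gathering passages across devices enlarges the negative pool without changing how the targets are computed, since the passages-per-query ratio stays the same.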
In this code, is the contrastive loss the same one described in the paper?
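For reference, here is a sketch of what the snippet computes. It assumes that the remainder of `__call__`, which is not shown above, passes `scores` to a cross-entropy loss whose target for query $i$ is index $iG$, where $G$ is the number of passages per query. Under that assumption, the result is the standard in-batch InfoNCE objective. Here $s$ is `compute_similarity`, $\tau$ is `temperature`, and $B$ is the number of (possibly gathered) queries:

$$
\mathcal{L} = -\frac{1}{B}\sum_{i=0}^{B-1} \log \frac{\exp\big(s(q_i, p_{iG})/\tau\big)}{\sum_{j=0}^{BG-1} \exp\big(s(q_i, p_j)/\tau\big)}
$$

Comparing this term by term with the paper's loss equation would answer the question, for example by checking whether the paper uses the same temperature-scaled similarity and the same negative set.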