
AnglE loss #2471

Merged
merged 10 commits into UKPLab:master on Feb 14, 2024

Conversation

johneckberg
Contributor

PR Overview: this PR adds an implementation of the AnglE loss.

Details:

  • @ir2718 pairwise_angle_sim is implemented in utils and has the same method signature as pairwise_cos_sim.
  • I'm not happy with the readability of utils.pairwise_angle_sim, as the math requires a handful of temporary variables; the sketch after this list shows the underlying math. If anyone has suggestions on this, please leave a comment!
  • @SeanLee97 if you have the time to review this, please comment if you see any issues!
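
For context, here is a minimal sketch of the complex-division math from the AnglE paper that pairwise angle similarity builds on. This is an illustration (pairwise_angle_sim_sketch is a hypothetical name), not necessarily the exact code in this PR:

import torch

def pairwise_angle_sim_sketch(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # Treat each embedding as a complex vector: first half = real part,
    # second half = imaginary part (requires an even embedding dimension).
    a, b = torch.chunk(x, 2, dim=1)
    c, d = torch.chunk(y, 2, dim=1)
    # Complex division: (a + bi) / (c + di) = ((ac + bd) + (bc - ad)i) / (c^2 + d^2)
    denom = torch.sum(c**2 + d**2, dim=1, keepdim=True)
    re = (a * c + b * d) / denom
    im = (b * c - a * d) / denom
    # Rescale by the magnitude ratio so the result reflects only the angle difference.
    norm_x = torch.sum(a**2 + b**2, dim=1, keepdim=True).sqrt()
    norm_y = denom.sqrt()
    re = re * norm_y / norm_x
    im = im * norm_y / norm_x
    # Summarize the per-pair angle difference as the squared norm of the quotient.
    return torch.abs(torch.sum(torch.cat((re, im), dim=1) ** 2, dim=1))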

@johneckberg johneckberg marked this pull request as draft February 5, 2024 20:57
@johneckberg
Contributor Author

@tomaarsen Do you have an idea of why only the ubuntu unit tests are failing?

@SeanLee97

@johneckberg, many thanks for your implementation! Could it be combined with contrastive loss?

@tomaarsen
Collaborator

tomaarsen commented Feb 6, 2024

@johneckberg I'm not sure, no. The logs are very confusing too; it seems like the runners just die. I'll investigate further.
EDIT: They ran out of disk space - maybe an issue on the GitHub side; we can ignore it for now.

@SeanLee97 do you mean the regular cosine loss (the cosent one) and in-batch negatives? I think maybe it makes most sense to have AnglELoss as only the angle-optimized objective, and to allow users to mix and match losses themselves to reproduce your final loss function?

  • Tom Aarsen

@johneckberg johneckberg marked this pull request as ready for review February 6, 2024 14:18
@johneckberg
Contributor Author

@tomaarsen that's strange, thanks for the clarification!

@SeanLee97, there is an open issue #2440 regarding combining losses. Different loss functions in the library require different input formats, so combining contrastive loss/MNR Loss with AnglE Loss is not possible yet. Please leave a comment if you have any ideas for a solution to this!

@tomaarsen
Collaborator

Initial tests seem to indicate that AnglE on its own performs slightly worse than just CoSENT, but still notably better than just CosineSimilarityLoss. Combining AnglE + CoSENT + MNRL is possible, but seems to result in worse performance at small-ish batch sizes (128 or 256) than pure AnglE or CoSENT, seemingly because MNRL is not doing great in my STS experiment.

Note that I've only run this on a few (~4) scripts.

@johneckberg
Contributor Author

Hey @tomaarsen! I also noticed this performance gap between CoSENT and AnglE in informal tests. It is somewhat visible in the ablation study in the AnglE paper, which notes a small (0.13%) performance increase when using just CoSENT over just AnglE. How did you combine MNR with AnglE and CoSENT?

@tomaarsen
Collaborator

I'll have to dive back into the ablation study!

I used:

import torch.nn as nn
from sentence_transformers import losses

class FullAngleLoss(nn.Module):
    """Sums AnglE, CoSENT, and in-batch negatives (MNRL) losses over the same labeled pairs."""

    def __init__(self, model) -> None:
        super().__init__()
        self.angle_loss = losses.AnglELoss(model=model)
        self.cosent_loss = losses.CoSENTLoss(model=model)
        self.ibn = losses.MultipleNegativesRankingLoss(model=model)

    def forward(self, sentence_features, labels):
        # Each component loss embeds the inputs itself, so this runs three forward passes.
        return (
            self.angle_loss(sentence_features, labels)
            + self.cosent_loss(sentence_features, labels)
            + self.ibn(sentence_features, labels)
        )

train_loss = FullAngleLoss(model=model)

@tomaarsen
Collaborator

Just had another look at the ablation study - the findings mirror mine quite closely!

@SeanLee97

SeanLee97 commented Feb 7, 2024

Hi @tomaarsen @johneckberg, thanks for testing!

Here are some suggestions from UAE training to achieve good performance:

  • Combine the angle loss with MultipleNegativesRankingLoss (MNRL), without combining it with CoSENT.
  • When working with small batch sizes, consider increasing the weight of MNRL. In the case of UAE training, we set the batch size to 64, with a weight of 20 for MNRL and a weight of 1.0 for the angle loss.

As for NLI (MultiNLI + SNLI), we just used the entailment (label 1) and contradiction (label 0) data for training.
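
Translating those suggestions into sentence-transformers terms, a hedged sketch of the weighting might look like this (the class name and weight arguments are illustrative, not part of the library):

import torch.nn as nn
from sentence_transformers import losses

class WeightedAngleMNRLoss(nn.Module):
    """Illustrative: w_mnrl * MNRL + w_angle * AnglE, per the UAE recipe above
    (batch size 64, MNRL weight 20, angle weight 1.0). Note that MNRL still
    assumes each batch contains anchor-positive pairs."""

    def __init__(self, model, w_angle: float = 1.0, w_mnrl: float = 20.0) -> None:
        super().__init__()
        self.angle_loss = losses.AnglELoss(model=model)
        self.mnrl = losses.MultipleNegativesRankingLoss(model=model)
        self.w_angle = w_angle
        self.w_mnrl = w_mnrl

    def forward(self, sentence_features, labels):
        return (
            self.w_angle * self.angle_loss(sentence_features, labels)
            + self.w_mnrl * self.mnrl(sentence_features, labels)
        )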

@tomaarsen
Collaborator

Very useful information! I will try to run some extra experiments.

@johneckberg
Contributor Author

Thanks for the insight @SeanLee97!

@tomaarsen, my understanding is that the ST implementation of MNR loss treats every input pair as a positive pair; is it possible that part of the performance issue on STS comes from negative or neutral input pairs being treated as positives inside MNR loss?

@tomaarsen
Collaborator

@johneckberg Oh, you're super right. It makes no sense to apply MNRL on the training_stsbenchmark.py example as it doesn't exclusively have anchor-positive pairs or anchor-positive-negative triplets. That's an oversight on my part.

  • Tom Aarsen

@johneckberg
Contributor Author

@tomaarsen No worries, glad I could be a second set of eyes!

I have been thinking about that sort of data formatting problem in relation to issue #2440, and can't think of any solid ways around it. In the AnglE repo, @SeanLee97 combines losses by always conforming to the y_true, y_pred input convention: each input pair is simply the first two rows of y_pred, then the third and fourth, and so on. Then, when using MNR loss, a target matrix is created to filter the pairs by label. This is a good solution, but it wouldn't work for ST.
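
To illustrate the pairing convention described above (shapes and values are hypothetical):

import torch
import torch.nn.functional as F

# Hypothetical illustration: consecutive rows of y_pred form one pair, so row 2i
# holds the first text's embedding of pair i, and row 2i + 1 the second text's.
y_pred = torch.randn(8, 768)  # 4 pairs of sentence embeddings
first, second = y_pred[0::2], y_pred[1::2]
cos_per_pair = F.cosine_similarity(first, second, dim=1)  # one score per pair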

@tomaarsen
Collaborator

An interesting solution is to keep the losses separate. E.g. in the current ST codebase that would entail 2 dataloaders & 2 losses (one with AnglE + CoSENT and one with MNRL) that fire round-robin style. This might be less performant, though.
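
For what it's worth, the current fit() API already accepts multiple train objectives and alternates between them, so a sketch of that setup (the dataloader and loss variables are placeholders) could be:

# pair_dataloader yields labeled pairs for the AnglE + CoSENT loss, while
# triplet_dataloader yields anchor-positive(-negative) batches for MNRL.
model.fit(
    train_objectives=[
        (pair_dataloader, angle_cosent_loss),
        (triplet_dataloader, mnrl_loss),
    ],
    epochs=1,
    warmup_steps=100,
)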

@tomaarsen
Collaborator

In training_stsbenchmark.py I've changed up my custom loss function to:

import copy

import torch.nn as nn
from sentence_transformers import losses

class FullAngleLoss(nn.Module):
    """AnglE + CoSENT on all labeled pairs, plus weighted MNRL on the positive pairs only."""

    def __init__(self, model) -> None:
        super().__init__()
        self.angle_loss = losses.AnglELoss(model=model)
        self.cosent_loss = losses.CoSENTLoss(model=model)
        self.ibn = losses.MultipleNegativesRankingLoss(model=model)

    def forward(self, sentence_features, labels):
        # MNRL assumes anchor-positive pairs, so keep only pairs with a gold
        # similarity of at least 0.8; deepcopy avoids mutating the originals.
        positive_pairs = [
            {key: value[labels >= 0.8] for key, value in features.items()}
            for features in copy.deepcopy(sentence_features)
        ]
        loss = (
            self.angle_loss(sentence_features, labels)
            + self.cosent_loss(sentence_features, labels)
            + 3 * self.ibn(positive_pairs, labels)
        )
        return loss

train_loss = FullAngleLoss(model=model)

And this reaches a competitive 0.8486 Spearman correlation coefficient on the test set with a batch size of 64. For comparison:

Loss                     Spearman (test)
FullAngleLoss (above)    0.8486
CoSENT                   0.8425
CoSENT + AnglE           0.8419
AnglE                    0.8371
CosineSimilarityLoss     0.7918

The docstring changes make the docs slightly prettier
@tomaarsen
Collaborator

I believe this might be ready. Any last comments or suggestions before I move forward with this @johneckberg @SeanLee97?

  • Tom Aarsen

@johneckberg
Contributor Author

@tomaarsen I don't have any!

@SeanLee97

> I believe this might be ready. Any last comments or suggestions before I move forward with this @johneckberg @SeanLee97?
>
>   • Tom Aarsen

No more comments.

@tomaarsen
Collaborator

Many thanks to you both; this is a very exciting addition.

  • Tom Aarsen

@tomaarsen merged commit 914fd6a into UKPLab:master Feb 14, 2024
5 of 9 checks passed
Successfully merging this pull request may close these issues.

Implement AnglE loss