
[FEATURE] Support Tied-Augment #1828

Open
ekurtulus opened this issue May 25, 2023 · 3 comments
Labels
enhancement New feature or request

Comments

@ekurtulus
ekurtulus commented May 25, 2023

Recently, we introduced Tied-Augment, a simple framework that combines self-supervised learning and supervised learning by making forward passes on two augmented views of the data with tied (shared) weights. In addition to the classification loss, it adds a similarity term to enforce invariance between the features of the augmented views. We found that the framework improves the effectiveness of both simple flips-and-crops (Crop-Flip) and aggressive augmentations (RandAugment), even for few-epoch training. As the effect of data augmentation is amplified, the sample efficiency of the data increases.

I believe Tied-Augment would be a nice addition to the timm training script. It can significantly improve mixup/RandAugment (77.6% → 79.6%) with marginal extra cost. Here is my reference implementation.
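The objective described above can be sketched roughly as follows (a hypothetical PyTorch sketch, not the paper's reference code: `model` is assumed to return `(features, logits)` for a batch, and `sim_weight` and the choice of negative cosine similarity as the invariance term are illustrative assumptions):

```python
import torch.nn.functional as F

def tied_augment_loss(model, view1, view2, targets, sim_weight=1.0):
    # Two forward passes through the SAME model, i.e. tied (shared) weights,
    # one per augmented view of the batch.
    feat1, logits1 = model(view1)
    feat2, logits2 = model(view2)
    # Supervised classification loss on both augmented views.
    ce = F.cross_entropy(logits1, targets) + F.cross_entropy(logits2, targets)
    # Similarity term enforcing invariance between the two views' features;
    # negative cosine similarity is one possible choice, weighted by a
    # hypothetical sim_weight hyperparameter.
    sim = -F.cosine_similarity(feat1, feat2, dim=-1).mean()
    return ce + sim_weight * sim
```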

@ekurtulus ekurtulus added the enhancement New feature or request label May 25, 2023
@pdedeler

👍🏻 It would be great if you could implement Tied-Augment.

@rwightman
Collaborator

@ekurtulus that sounds interesting, can it be implemented similarly to augmix + jsd loss, where most of the detail w.r.t. the splits of the data, etc. is in the dataset wrapper and loss?

@ekurtulus
Author

ekurtulus commented May 29, 2023

@ekurtulus that sounds interesting, can it be implemented similarly to augmix + jsd loss, where most of the detail w.r.t. the splits of the data, etc. is in the dataset wrapper and loss?

@rwightman Yes; the only difference is that Tied-Augment requires the features of the augmented views, so an additional wrapper has to be put around the model as well.

Example (for a Timm model with num_classes=0)

import torch.nn as nn

class TimmWrapper(nn.Module):
    """Wraps a headless timm model (num_classes=0) so that training-mode
    forward passes return both features and logits."""

    def __init__(self, model, num_classes):
        super().__init__()
        self.model = model
        # Classification head on top of the backbone's feature dimension.
        self.fc = nn.Linear(model.num_features, num_classes)

    def forward(self, x, return_features=False):
        if self.training or return_features:
            features = self.model(x)
            logits = self.fc(features)
            return features, logits
        # Inference: plain logits, as usual.
        return self.fc(self.model(x))
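On the data side, following the augmix + jsd analogy above, a dataset wrapper could hand back two independently augmented views of each sample for the tied forward passes (a sketch with hypothetical names, not timm's actual wrapper):

```python
from torch.utils.data import Dataset

class TwoViewDataset(Dataset):
    """Wraps a dataset yielding (sample, target) pairs so each item
    returns two independently augmented views of the same sample.
    Hypothetical helper, not part of timm."""

    def __init__(self, dataset, transform):
        self.dataset = dataset      # yields (raw_sample, target)
        self.transform = transform  # stochastic augmentation

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        x, y = self.dataset[idx]
        # Two independent draws from the same stochastic transform.
        return self.transform(x), self.transform(x), y
```

The loss side would then feed the two views through the tied model and combine the classification and similarity terms, analogous to how `JsdCrossEntropy` pairs with the AugMix dataset wrapper in timm.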
