diff --git a/avalanche/training/plugins/__init__.py b/avalanche/training/plugins/__init__.py
index 005b342c1..2dd57ac4f 100644
--- a/avalanche/training/plugins/__init__.py
+++ b/avalanche/training/plugins/__init__.py
@@ -9,4 +9,5 @@
     ExperienceBalancedStoragePolicy
 from .strategy_plugin import StrategyPlugin
 from .synaptic_intelligence import SynapticIntelligencePlugin
+from .siw import SIWPlugin
 from .cope import CoPEPlugin, PPPloss
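Note: the plugin added below freezes past-class weights by zeroing their gradient rows before the optimizer step. A minimal standalone sketch of that mechanism (toy layer sizes and plain SGD; not Avalanche code):

import torch
from torch.nn import Linear
from torch.optim import SGD

torch.manual_seed(0)
head = Linear(4, 6)                      # classification head, 6 classes seen
optimizer = SGD(head.parameters(), lr=0.1)
old_rows = head.weight.data[:3].clone()  # rows of the 3 past classes

loss = head(torch.randn(8, 4)).sum()
loss.backward()
head.weight.grad[:3, :] = 0              # zero past-class gradients ...
head.bias.grad[:3] = 0
optimizer.step()                         # ... so the step leaves them intact

assert torch.equal(head.weight.data[:3], old_rows)

With plain SGD (no momentum or weight decay) a zero gradient means no update, which is exactly the freezing effect `after_backward` relies on.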
diff --git a/avalanche/training/plugins/siw.py b/avalanche/training/plugins/siw.py
new file mode 100644
index 000000000..1d71a471a
--- /dev/null
+++ b/avalanche/training/plugins/siw.py
@@ -0,0 +1,143 @@
+import torch
+from typing import Optional
+from torch.nn import Linear
+from avalanche.training.plugins.strategy_plugin import StrategyPlugin
+from avalanche.training.utils import get_last_fc_layer, get_layer_by_name
+
+
+class SIWPlugin(StrategyPlugin):
+    """
+    Standardization of Initial Weights (SIW) plugin.
+    From https://arxiv.org/pdf/2008.13710.pdf
+
+    Performs past-class initial weights replay and state-level score
+    calibration. The callbacks `before_training_exp`, `after_backward`,
+    `after_training_exp`, `before_eval_exp`, and `after_eval_forward`
+    are implemented.
+
+    The `before_training_exp` callback keeps track of the classes in
+    each experience.
+
+    The `after_backward` callback freezes the past-class weights in the
+    last fully connected layer.
+
+    The `after_training_exp` callback extracts the scores of new class
+    images and computes the model's confidence at each incremental
+    state.
+
+    The `before_eval_exp` callback standardizes the initial weights
+    before inference.
+
+    The `after_eval_forward` callback applies state-level calibration
+    at inference time.
+
+    The `siw_layer_name` parameter is the name of the last fully
+    connected layer of the network.
+
+    The `batch_size` and `num_workers` parameters are used when
+    extracting the new class scores.
+    """
+
+    def __init__(self, model, siw_layer_name='fc', batch_size=32,
+                 num_workers=0):
+        super().__init__()
+        self.confidences = []
+        self.classes_per_experience = []
+        self.model = model
+        self.siw_layer_name = siw_layer_name
+        self.num_workers = num_workers
+        self.batch_size = batch_size
+
+    def get_siw_layer(self) -> Optional[Linear]:
+        result = None
+        if self.siw_layer_name is None:
+            last_fc = get_last_fc_layer(self.model)
+            if last_fc is not None:
+                result = last_fc[1]
+        else:
+            result = get_layer_by_name(self.model, self.siw_layer_name)
+        return result
+
+    def before_training_exp(self, strategy, **kwargs):
+        """
+        Keep track of the classes encountered in each experience.
+        """
+        self.classes_per_experience.append(
+            strategy.experience.classes_in_this_experience)
+
+    def after_backward(self, strategy, **kwargs):
+        """
+        After back-propagation and before the optimization step, zero
+        the gradients of the past-class weights and biases. This is
+        equivalent to freezing the past-class weights and biases, so
+        that only the feature extractor and the new-class weights and
+        biases evolve.
+        """
+        previous_classes = len(strategy.experience.previous_classes)
+        last_layer = self.get_siw_layer()
+        if last_layer is None:
+            raise RuntimeError('Unable to find the Linear layer '
+                               f'"{self.siw_layer_name}"')
+
+        last_layer.weight.grad[:previous_classes, :] = 0
+        last_layer.bias.grad[:previous_classes] = 0
+
+    @torch.no_grad()
+    def after_training_exp(self, strategy, **kwargs):
+        """
+        After training on an experience, extract the scores of the new
+        class images and compute the model's confidence at the current
+        incremental state.
+        """
+        # extract training scores
+        strategy.model.eval()
+
+        dataset = strategy.experience.dataset
+        loader = torch.utils.data.DataLoader(
+            dataset, batch_size=self.batch_size,
+            num_workers=self.num_workers)
+
+        # the confidence of a state is the mean top-1 score obtained
+        # on the data of the classes introduced in that state
+        max_top1_scores = []
+        for i, data in enumerate(loader):
+            inputs, targets, task_labels = data
+            inputs = inputs.to(strategy.device)
+            logits = strategy.model(inputs)
+            max_score = torch.max(logits, dim=1)[0].tolist()
+            max_top1_scores.extend(max_score)
+        self.confidences.append(sum(max_top1_scores) /
+                                len(max_top1_scores))
+
+    @torch.no_grad()
+    def before_eval_exp(self, strategy, **kwargs):
+        """
+        Standardize all class weights, by subtracting their mean and
+        dividing by their standard deviation.
+        """
+        # standardize last layer weights
+        last_layer = self.get_siw_layer()
+        if last_layer is None:
+            raise RuntimeError('Unable to find the Linear layer '
+                               f'"{self.siw_layer_name}"')
+
+        classes_seen_so_far = len(strategy.experience.classes_seen_so_far)
+
+        for i in range(classes_seen_so_far):
+            mu = torch.mean(last_layer.weight[i])
+            std = torch.std(last_layer.weight[i])
+
+            last_layer.weight.data[i] -= mu
+            last_layer.weight.data[i] /= std
+
+    def after_eval_forward(self, strategy, **kwargs):
+        """
+        Rectify past-class scores by multiplying them by the model's
+        confidence in the current state and dividing them by the
+        model's confidence in the initial state in which each past
+        class was encountered for the first time.
+        """
+        for exp in range(len(self.confidences)):
+            strategy.logits[:, self.classes_per_experience[exp]] *= \
+                self.confidences[strategy.experience.current_experience] \
+                / self.confidences[exp]
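For intuition about the two inference-time corrections above, here is a self-contained numerical sketch; the class ranges and confidence values are invented for illustration:

import torch

# weight standardization (before_eval_exp): zero mean, unit std per class row
w = torch.randn(5, 16)                    # 5 classes seen so far
w = (w - w.mean(dim=1, keepdim=True)) / w.std(dim=1, keepdim=True)

# state-level calibration (after_eval_forward): past-class logits are scaled
# by current-state confidence / first-state confidence
confidences = [3.2, 2.1]                  # mean top-1 scores of states 0 and 1
classes_per_experience = [[0, 1, 2], [3, 4]]
logits = torch.randn(8, 5)
current = 1
for exp, classes in enumerate(classes_per_experience):
    logits[:, classes] *= confidences[current] / confidences[exp]

Here the state-0 classes are scaled by 2.1 / 3.2, roughly 0.66, compensating for the higher raw scores the model produced when those classes were new.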
diff --git a/avalanche/training/strategies/strategy_wrappers.py b/avalanche/training/strategies/strategy_wrappers.py
index 2ef277f41..873512abc 100644
--- a/avalanche/training/strategies/strategy_wrappers.py
+++ b/avalanche/training/strategies/strategy_wrappers.py
@@ -16,7 +16,7 @@
 from avalanche.training import default_logger
 from avalanche.training.plugins import StrategyPlugin, CWRStarPlugin, \
     ReplayPlugin, GDumbPlugin, LwFPlugin, AGEMPlugin, GEMPlugin, EWCPlugin, \
-    EvaluationPlugin, SynapticIntelligencePlugin, CoPEPlugin
+    EvaluationPlugin, SynapticIntelligencePlugin, SIWPlugin, CoPEPlugin
 from avalanche.training.strategies.base_strategy import BaseStrategy
 
 
@@ -450,6 +450,53 @@ def __init__(self, model: Module, optimizer: Optimizer, criterion,
         )
 
 
+class SIW(BaseStrategy):
+    def __init__(self, model: Module, optimizer: Optimizer, criterion,
+                 siw_layer_name: str = 'fc',
+                 batch_size: int = 32, num_workers: int = 0,
+                 train_mb_size: int = 1, train_epochs: int = 1,
+                 eval_mb_size: int = None, device=None,
+                 plugins: Optional[List[StrategyPlugin]] = None,
+                 evaluator: EvaluationPlugin = default_logger, eval_every=-1):
+        """ Standardization of Initial Weights (SIW) strategy.
+        See the SIW plugin for details.
+        This strategy does not use task identities.
+
+        :param model: The model.
+        :param optimizer: The optimizer to use.
+        :param criterion: The loss criterion to use.
+        :param siw_layer_name: The name of the last fully connected layer.
+        :param batch_size: The batch size used to extract scores.
+        :param num_workers: The number of workers used to load batches.
+        :param train_mb_size: The train minibatch size. Defaults to 1.
+        :param train_epochs: The number of training epochs. Defaults to 1.
+        :param eval_mb_size: The eval minibatch size. Defaults to 1.
+        :param device: The device to use. Defaults to None (cpu).
+        :param plugins: Plugins to be added. Defaults to None.
+        :param evaluator: (optional) instance of EvaluationPlugin for logging
+            and metric computations.
+        :param eval_every: the frequency of the calls to `eval` inside the
+            training loop.
+            if -1: no evaluation during training.
+            if 0: calls `eval` after the final epoch of each training
+                experience.
+            if >0: calls `eval` every `eval_every` epochs and at the end
+                of all the epochs for a single experience.
+        """
+
+        siw = SIWPlugin(model, siw_layer_name, batch_size, num_workers)
+        if plugins is None:
+            plugins = [siw]
+        else:
+            plugins.append(siw)
+
+        super().__init__(
+            model, optimizer, criterion,
+            train_mb_size=train_mb_size, train_epochs=train_epochs,
+            eval_mb_size=eval_mb_size, device=device, plugins=plugins,
+            evaluator=evaluator, eval_every=eval_every)
+
+
 class CoPE(BaseStrategy):
     def __init__(self, model: Module, optimizer: Optimizer, criterion,
@@ -514,5 +561,6 @@ def __init__(self, model: Module, optimizer: Optimizer, criterion,
     'GEM',
     'EWC',
     'SynapticIntelligence',
+    'SIW',
     'CoPE'
 ]
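A minimal usage sketch for the new wrapper (SplitMNIST and SimpleMLP are just convenient stand-ins; any Avalanche benchmark and any model with a known final Linear layer should work):

from torch.nn import CrossEntropyLoss
from torch.optim import SGD
from avalanche.benchmarks.classic import SplitMNIST
from avalanche.models import SimpleMLP
from avalanche.training.strategies import SIW

model = SimpleMLP(num_classes=10)
strategy = SIW(model, SGD(model.parameters(), lr=0.01), CrossEntropyLoss(),
               siw_layer_name='classifier',  # SimpleMLP's last Linear layer
               train_mb_size=32, train_epochs=1, eval_mb_size=32)

benchmark = SplitMNIST(n_experiences=5)
for experience in benchmark.train_stream:
    strategy.train(experience)
    strategy.eval(benchmark.test_stream)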
diff --git a/examples/siw_cifar100.py b/examples/siw_cifar100.py
new file mode 100644
index 000000000..66cf49c1f
--- /dev/null
+++ b/examples/siw_cifar100.py
@@ -0,0 +1,159 @@
+import argparse
+import torch
+import torchvision
+import torchvision.transforms as transforms
+from torch.nn import CrossEntropyLoss
+from torch.optim import SGD, lr_scheduler
+from avalanche.benchmarks.classic import SplitCIFAR100
+from avalanche.evaluation.metrics import accuracy_metrics
+from avalanche.logging import InteractiveLogger
+from avalanche.training.plugins import SIWPlugin, \
+    EvaluationPlugin, StrategyPlugin
+from avalanche.training.strategies import Naive
+
+
+class LRSchedulerPlugin(StrategyPlugin):
+    def __init__(self, lr_scheduler):
+        super().__init__()
+        self.lr_scheduler = lr_scheduler
+
+    def after_training_epoch(self, strategy: 'BaseStrategy', **kwargs):
+        self.lr_scheduler.step(strategy.loss.item())
+        lr = strategy.optimizer.param_groups[0]['lr']
+        print(f"\nlr = {lr}")
+
+
+class SetIncrementalHyperParams(StrategyPlugin):
+    def __init__(self, scheduler_plugin, inc_exp_epochs, inc_exp_patience,
+                 first_exp_lr, lr_decay):
+        super().__init__()
+        self.scheduler_plugin = scheduler_plugin
+        self.inc_exp_epochs = inc_exp_epochs
+        self.inc_exp_patience = inc_exp_patience
+        self.first_exp_lr = first_exp_lr
+        self.lr_decay = lr_decay
+
+    def before_training_exp(self, strategy: 'BaseStrategy', **kwargs):
+        if strategy.experience.current_experience > 0:  # incremental update
+            strategy.train_epochs = self.inc_exp_epochs
+            strategy.optimizer.param_groups[0]['lr'] = \
+                self.first_exp_lr / strategy.experience.current_experience
+            # replace the scheduler held by LRSchedulerPlugin so that the
+            # incremental patience and decay actually take effect
+            self.scheduler_plugin.lr_scheduler = \
+                lr_scheduler.ReduceLROnPlateau(
+                    strategy.optimizer,
+                    patience=self.inc_exp_patience,
+                    factor=self.lr_decay)
+
+
+def main(args):
+    # check if the selected GPU is available, otherwise use the CPU
+    assert args.cuda == -1 or args.cuda >= 0, "cuda must be -1 or >= 0."
+    device = torch.device(f"cuda:{args.cuda}" if torch.cuda.is_available()
+                          and args.cuda >= 0 else "cpu")
+    print(f'Using device: {device}')
+
+    model = torchvision.models.resnet18(num_classes=100).to(device)
+
+    # print to stdout
+    interactive_logger = InteractiveLogger()
+
+    eval_plugin = EvaluationPlugin(
+        accuracy_metrics(minibatch=False, epoch=True, experience=True,
+                         stream=True),
+        loggers=[interactive_logger])
+
+    optimizer = SGD(model.parameters(), lr=args.first_exp_lr,
+                    momentum=args.momentum,
+                    weight_decay=args.weight_decay)
+    criterion = CrossEntropyLoss()
+    scheduler = LRSchedulerPlugin(
+        lr_scheduler.ReduceLROnPlateau(optimizer,
+                                       patience=args.first_exp_patience,
+                                       factor=args.lr_decay))
+    incremental_params = SetIncrementalHyperParams(scheduler,
+                                                   args.inc_exp_epochs,
+                                                   args.inc_exp_patience,
+                                                   args.first_exp_lr,
+                                                   args.lr_decay)
+
+    siw = SIWPlugin(model, siw_layer_name=args.siw_layer_name,
+                    batch_size=args.eval_batch_size,
+                    num_workers=args.num_workers)
+
+    strategy = Naive(model, optimizer, criterion,
+                     device=device, train_epochs=args.first_exp_epochs,
+                     evaluator=eval_plugin,
+                     plugins=[siw, scheduler, incremental_params],
+                     train_mb_size=args.train_batch_size,
+                     eval_mb_size=args.eval_batch_size)
+
+    normalize = transforms.Normalize(mean=[0.5071, 0.4866, 0.4409],
+                                     std=[0.2673, 0.2564, 0.2762])
+
+    train_transform = transforms.Compose([
+        transforms.RandomResizedCrop(224),
+        transforms.RandomHorizontalFlip(),
+        transforms.ToTensor(),
+        normalize])
+
+    test_transform = transforms.Compose([
+        transforms.Resize(256),
+        transforms.CenterCrop(224),
+        transforms.ToTensor(),
+        normalize])
+
+    # scenario
+    scenario = SplitCIFAR100(n_experiences=10, return_task_id=False,
+                             fixed_class_order=list(range(100)),
+                             train_transform=train_transform,
+                             eval_transform=test_transform)
+
+    # training loop
+    print('Starting experiment...')
+    results = []
+    for i, experience in enumerate(scenario.train_stream):
+        print("Start of experience: ", experience.current_experience)
+        strategy.train(experience, num_workers=args.num_workers)
+        print('Training completed')
+        print('Computing accuracy on the test set')
+        res = strategy.eval(scenario.test_stream[:i + 1],
+                            num_workers=args.num_workers)
+        results.append(res)
+
+    print('Results = ' + str(results))
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--first_exp_lr', type=float, default=0.1,
+                        help='Learning rate for the first experience.')
+    parser.add_argument('--momentum', type=float, default=0.9,
+                        help='Momentum.')
+    parser.add_argument('--weight_decay', type=float, default=0.0005,
+                        help='Weight decay.')
+    parser.add_argument('--lr_decay', type=float, default=0.1,
+                        help='LR decay factor.')
+    parser.add_argument('--first_exp_patience', type=int, default=60,
+                        help='Patience in the first experience.')
+    parser.add_argument('--inc_exp_patience', type=int, default=15,
+                        help='Patience in the incremental experiences.')
+    parser.add_argument('--first_exp_epochs', type=int, default=300,
+                        help='Number of epochs in the first experience.')
+    parser.add_argument('--inc_exp_epochs', type=int, default=70,
+                        help='Number of epochs in each incremental '
+                             'experience.')
+    parser.add_argument('--train_batch_size', type=int, default=128,
+                        help='Training batch size.')
+    parser.add_argument('--eval_batch_size', type=int, default=32,
+                        help='Evaluation batch size.')
+    parser.add_argument('--num_workers', type=int, default=8,
+                        help='Number of workers used to extract scores.')
+    parser.add_argument('--siw_layer_name', type=str, default='fc',
+                        help='Name of the last fully connected layer.')
+    parser.add_argument('--cuda', type=int, default=1,
+                        help='GPU id to use; use the CPU if -1.')
+    args = parser.parse_args()
+
+    main(args)
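The example drives `ReduceLROnPlateau` with the epoch training loss. As a quick illustration of the incremental settings above (patience 15, factor 0.1), a plateaued loss cuts the learning rate by 10x once it fails to improve for more than `patience` epochs:

import torch
from torch.optim import SGD, lr_scheduler

params = [torch.zeros(1, requires_grad=True)]
optimizer = SGD(params, lr=0.05)
scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, patience=15, factor=0.1)
for epoch in range(40):
    scheduler.step(1.0)                 # constant, i.e. plateaued, loss
print(optimizer.param_groups[0]['lr'])  # 0.05 -> 0.0005 after two cuts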
diff --git a/tests/test_strategies.py b/tests/test_strategies.py
index dac5e179b..71842f0b1 100644
--- a/tests/test_strategies.py
+++ b/tests/test_strategies.py
@@ -26,7 +26,7 @@
 from avalanche.training.plugins import EvaluationPlugin
 from avalanche.training.strategies import Naive, Replay, CWRStar, \
     GDumb, LwF, AGEM, GEM, EWC, \
-    SynapticIntelligence, JointTraining, CoPE
+    SynapticIntelligence, JointTraining, SIW, CoPE
 from avalanche.training.strategies.ar1 import AR1
 from avalanche.training.strategies.cumulative import Cumulative
 from avalanche.benchmarks import nc_benchmark
@@ -327,6 +327,36 @@ def test_ar1(self):
                          rm_sz=200)
         self.run_strategy(my_nc_benchmark, strategy)
 
+    def run_siw(self, scenario, cl_strategy):
+        print('Starting experiment...')
+        cl_strategy.evaluator.loggers = [TextLogger(sys.stdout)]
+        results = []
+        for i, train_batch_info in enumerate(scenario.train_stream):
+            print("Start of experience ", train_batch_info.current_experience)
+
+            cl_strategy.train(train_batch_info)
+            print('Training completed')
+
+            print('Computing accuracy on the current test set')
+            results.append(cl_strategy.eval(scenario.test_stream[:i + 1]))
+
+    def test_siw(self):
+        # SIT scenario
+        model, optimizer, criterion, my_nc_scenario = self.init_sit()
+        strategy = SIW(model, optimizer, criterion,
+                       siw_layer_name='classifier',
+                       batch_size=32, num_workers=8, train_mb_size=128,
+                       device=self.device, eval_mb_size=32, train_epochs=2)
+        self.run_siw(my_nc_scenario, strategy)
+
+        # a second SIT scenario, again without task labels
+        strategy = SIW(model, optimizer, criterion,
+                       siw_layer_name='classifier',
+                       batch_size=32, num_workers=8, train_mb_size=128,
+                       device=self.device, eval_mb_size=32, train_epochs=2)
+        scenario = self.load_scenario(use_task_labels=False)
+        self.run_siw(scenario, strategy)
+
     def load_ar1_scenario(self):
         """
         Returns a NC Scenario from a fake dataset of 10 classes, 5 experiences,
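A callback-level check could complement the end-to-end test above. This is a sketch, not part of the patch; the `SimpleNamespace` stand-in only mimics the strategy attributes that `after_eval_forward` touches, and the confidence values are illustrative:

import torch
from types import SimpleNamespace
from avalanche.training.plugins import SIWPlugin

plugin = SIWPlugin(model=torch.nn.Linear(4, 4))
plugin.confidences = [2.0, 1.0]            # confidences of states 0 and 1
plugin.classes_per_experience = [[0, 1], [2, 3]]

strategy = SimpleNamespace(
    logits=torch.ones(2, 4),
    experience=SimpleNamespace(current_experience=1))
plugin.after_eval_forward(strategy)

# past-state classes are scaled by 1.0 / 2.0; current-state ones are unchanged
assert torch.allclose(strategy.logits[:, [0, 1]], torch.full((2, 2), 0.5))
assert torch.allclose(strategy.logits[:, [2, 3]], torch.ones(2, 2))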