- First, please take a look at the examples.

- Choose a target setting from the tree (see the "Available Settings" section below).
- Create a new subclass of `Method`, with the chosen target setting. Your class should implement the following methods:

  - `fit(train_env, valid_env)`
  - `get_actions(observations, action_space) -> Actions`

  The following methods are optional, but can be very useful to help customize how your method is used at train/test time:

  - `configure(setting: Setting)`
  - `on_task_switch(task_id: Optional[int])`
  - `test(test_env)`

  For example:
  ```python
  class MyNewMethod(Method, target_setting=ClassIncrementalSetting):
      ...  # Your code here.

      def fit(self, train_env: DataLoader, valid_env: DataLoader):
          # Train your model however you want here.
          self.trainer.fit(
              self.model,
              train_dataloader=train_env,
              val_dataloaders=valid_env,
          )

      def get_actions(
          self, observations: Observations, action_space: gym.Space
      ) -> Actions:
          # Return an "Action" (prediction) for the given observations.
          # Each Setting has its own Observations, Actions and Rewards types,
          # which are based on those of their parents.
          return self.model.predict(observations.x)

      def on_task_switch(self, task_id: Optional[int]):
          # This method gets called if task boundaries are known in the current
          # setting. Furthermore, if task labels are available, task_id will be
          # the index of the new task. If not, task_id will be None.
          # For example, you could do something like this:
          self.model.current_output_head = self.model.output_heads[task_id]
  ```
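  As a hedged sketch of the optional hooks (the attributes read from `setting` below, like `nb_tasks`, are assumptions about the chosen Setting; adjust them to the setting you actually target):

  ```python
  class MyNewMethod(Method, target_setting=ClassIncrementalSetting):
      ...

      def configure(self, setting: ClassIncrementalSetting):
          # Called once, before training starts: inspect the setting and
          # set the model up accordingly.
          # (`nb_tasks` and the gym-style spaces are assumed attributes.)
          self.num_tasks = setting.nb_tasks
          self.observation_space = setting.observation_space
          self.action_space = setting.action_space

      def test(self, test_env):
          # Optional: take full control of the test loop yourself, instead of
          # letting the setting call `get_actions` for every batch.
          ...
  ```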
- Running / debugging your method (at the bottom of your script, for example):
  ```python
  if __name__ == "__main__":
      ## 1. Create the setting you want to apply your method on.
      # First option: create the Setting directly in code:
      setting = ClassIncrementalSetting(dataset="cifar10", nb_tasks=5)
      # Second option: create the Setting from the command-line:
      setting = ClassIncrementalSetting.from_args()

      ## 2. Create your Method, however you want.
      my_method = MyNewMethod()

      ## 3. Apply your method on the setting to obtain results.
      results = setting.apply(my_method)

      # Optionally, display the results.
      print(results.summary())
      results.make_plots()
  ```
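  With the `from_args()` option, the setting's constructor arguments become command-line flags, so a run might look like the following (the exact flag names are an assumption based on the constructor arguments above):

  ```console
  python my_new_method.py --dataset cifar10 --nb_tasks 5
  ```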
- (WIP) Adding your new method to the tree:
  - Place the script/package that defines your Method inside the `methods` folder.

  - Add the `@register_method` decorator to your Method definition, for example:

    ```python
    from sequoia.methods import register_method


    @register_method
    class MyNewMethod(Method, target_setting=ClassIncrementalSetting):
        name: ClassVar[str] = "my_new_method"
        ...
    ```
  - To launch an experiment using your method, run the following command:

    ```console
    python main.py --setting <some_setting_name> --method my_new_method
    ```

    To customize how your method gets created from the command-line, override the following two class methods:

    - `add_argparse_args(cls, parser: ArgumentParser)`
    - `from_argparse_args(cls, args: Namespace) -> Method`
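    A minimal sketch of those two overrides (the `learning_rate` hyper-parameter is a made-up example, not part of Sequoia's API):

    ```python
    from argparse import ArgumentParser, Namespace


    class MyNewMethod(Method, target_setting=ClassIncrementalSetting):
        def __init__(self, learning_rate: float = 1e-3):
            self.learning_rate = learning_rate

        @classmethod
        def add_argparse_args(cls, parser: ArgumentParser):
            # Add this method's hyper-parameters to the parser.
            parser.add_argument("--learning_rate", type=float, default=1e-3)

        @classmethod
        def from_argparse_args(cls, args: Namespace) -> "MyNewMethod":
            # Build the method from the parsed arguments.
            return cls(learning_rate=args.learning_rate)
    ```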
  - Create a `<your_method_script_name>_test.py` file next to your method script. In it, write unit tests for every module/component used in your Method. Make them easy to read, so that people can ideally understand how the components of your Method work simply by reading the tests.

  - (WIP) To run the unit tests locally, use the following command:

    ```console
    pytest methods/my_new_method_test.py
    ```
  - Then, write a functional test that demonstrates how your new method should behave and what kind of results it is expected to produce. The easiest way to do this is to implement a `validate_results(setting: Setting, results: Results)` method. See the sketch below this list for an illustration.

  - (WIP) To debug/run the "integration tests" locally, use the following command:

    ```console
    pytest -x methods/my_new_method_test.py --slow
    ```
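    For illustration, such a functional test might look like this sketch (the `objective` attribute on the results object and the module name are assumptions, not Sequoia's documented API):

    ```python
    # my_new_method_test.py -- illustrative sketch only.
    from sequoia.settings import ClassIncrementalSetting

    from my_new_method import MyNewMethod  # hypothetical module name


    def validate_results(setting: ClassIncrementalSetting, results):
        # Sanity-check the kind of results this method should produce.
        assert results is not None
        # e.g. require better-than-chance accuracy on CIFAR-10
        # (`objective` is an assumed attribute of the Results object):
        assert results.objective > 0.10


    def test_my_new_method_on_cifar10():
        setting = ClassIncrementalSetting(dataset="cifar10", nb_tasks=5)
        results = setting.apply(MyNewMethod())
        validate_results(setting, results)
    ```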
  - Create a Pull Request, and you're good to go!
## Available Methods
- Target setting: `Setting`

  Versatile baseline method which targets all settings. Uses pytorch-lightning's `Trainer` for training and a `LightningModule` as the model. Uses a `BaseModel`, which can be used for:

  - self-supervised training with modular auxiliary tasks;
  - semi-supervised training on partially labeled batches;
  - multi-head prediction (e.g. in the task-incremental scenario).
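  To make the multi-head idea concrete, here is a generic sketch (illustrative, not the actual `BaseModel`): one output layer ("head") per task, with the current task selecting the head.

  ```python
  import torch
  from torch import nn


  class MultiHeadClassifier(nn.Module):
      def __init__(self, n_tasks: int, features: int, classes_per_task: int):
          super().__init__()
          self.encoder = nn.Sequential(nn.Flatten(), nn.LazyLinear(features), nn.ReLU())
          # One output head per task:
          self.heads = nn.ModuleList(
              nn.Linear(features, classes_per_task) for _ in range(n_tasks)
          )

      def forward(self, x: torch.Tensor, task_id: int) -> torch.Tensor:
          # Route the shared features through the head for the current task.
          return self.heads[task_id](self.encoder(x))
  ```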
- Target setting: `Setting`

  Baseline method that gives random predictions for any given setting. This method doesn't have a model or any parameters; it just returns a random action for every observation.
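  For intuition, a method like this can be as small as the following sketch (simplified, with assumed import paths; the actual class differs):

  ```python
  import gym

  from sequoia.methods import Method
  from sequoia.settings import Setting


  class MyRandomMethod(Method, target_setting=Setting):
      def fit(self, train_env, valid_env):
          pass  # Nothing to learn.

      def get_actions(self, observations, action_space: gym.Space):
          # Ignore the observations entirely and sample a random action.
          return action_space.sample()
  ```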
- Target setting: `IncrementalAssumption`

  PNN (Progressive Neural Networks) method. Applicable to both RL and SL settings, as long as there are clear task boundaries during training (the `IncrementalAssumption`).
- Target setting: `ContinualSLSetting`

  Average Gradient Episodic Memory (AGEM) strategy from Avalanche. See the AGEM plugin for details. This strategy does not use task identities. See the parent class `AvalancheMethod` for the other hyper-parameters and methods.

- Target setting: `ContinualSLSetting`

  AR1 strategy from Avalanche. See the AR1 plugin for details. This strategy does not use task identities. See the parent class `AvalancheMethod` for the other hyper-parameters and methods.

- Target setting: `ContinualSLSetting`

  CWRStar strategy from Avalanche. See the CWRStar plugin for details. This strategy does not use task identities. See the parent class `AvalancheMethod` for the other hyper-parameters and methods.

- Target setting: `ContinualSLSetting`

  Elastic Weight Consolidation (EWC) strategy from Avalanche. See the EWC plugin for details. This strategy does not use task identities. See the parent class `AvalancheMethod` for the other hyper-parameters and methods.

- Target setting: `ContinualSLSetting`

  Gradient Episodic Memory (GEM) strategy from Avalanche. See the GEM plugin for details. This strategy does not use task identities. See the parent class `AvalancheMethod` for the other hyper-parameters and methods.

- Target setting: `ContinualSLSetting`

  GDumb strategy from Avalanche. See the GDumbPlugin for more details. This strategy does not use task identities. See the parent class `AvalancheMethod` for the other hyper-parameters and methods.

- Target setting: `ContinualSLSetting`

  Learning without Forgetting (LwF) strategy from Avalanche. See the LwF plugin for details. This strategy does not use task identities. See the parent class `AvalancheMethod` for the other hyper-parameters and methods.

- Target setting: `ContinualSLSetting`

  Replay strategy from Avalanche. See the Replay plugin for details. This strategy does not use task identities. See the parent class `AvalancheMethod` for the other hyper-parameters and methods.
- Target setting: `ContinualSLSetting`

  The Synaptic Intelligence strategy from Avalanche. This is the PyTorch implementation of the algorithm described in the paper "Continuous Learning in Single-Incremental-Task Scenarios" (https://arxiv.org/abs/1806.08568); the original algorithm was proposed in the paper "Continual Learning Through Synaptic Intelligence" (https://arxiv.org/abs/1703.04200). The Synaptic Intelligence regularization can also be used in a different strategy by applying the `SynapticIntelligencePlugin`. See the parent class `AvalancheMethod` for the other hyper-parameters and methods.
- Target setting: `ContinualRLSetting`

  Method that uses the A2C model from stable-baselines3.

- Target setting: `ContinualRLSetting`

  Method that uses the DQN model from stable-baselines3.

- Target setting: `ContinualRLSetting`

  Method that uses the DDPG model from stable-baselines3.

- Target setting: `ContinualRLSetting`

  Method that uses the TD3 model from stable-baselines3.

- Target setting: `ContinualRLSetting`

  Method that uses the SAC model from stable-baselines3.

- Target setting: `ContinualRLSetting`

  Method that uses the PPO model from stable-baselines3.
- Target setting: `IncrementalAssumption`

  Subclass of the BaseMethod which adds the `EWCTask` to the `BaseModel`. This method is applicable to any CL setting (RL or SL) where there are clear task boundaries, regardless of whether the task labels are given or not.
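  For reference, the core of any EWC-style regularizer looks roughly like this generic sketch (illustrative, not the actual `EWCTask`):

  ```python
  import torch
  from torch import nn


  def ewc_penalty(
      model: nn.Module, fisher: dict, old_params: dict, lam: float = 1.0
  ) -> torch.Tensor:
      # Penalize moving each parameter away from its value at the end of the
      # previous task, weighted by its (diagonal) Fisher information estimate.
      loss = torch.zeros(())
      for name, p in model.named_parameters():
          loss = loss + (fisher[name] * (p - old_params[name]) ** 2).sum()
      return lam / 2 * loss
  ```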
- Target setting: `IncrementalSLSetting`

  Simple method that uses a replay buffer to reduce forgetting.
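  A minimal sketch of the underlying idea (a reservoir-sampling buffer; not the actual implementation):

  ```python
  import random


  class ReservoirBuffer:
      def __init__(self, capacity: int):
          self.capacity = capacity
          self.items: list = []
          self.n_seen = 0

      def add(self, item) -> None:
          # Reservoir sampling: every example seen so far ends up in the
          # buffer with equal probability.
          self.n_seen += 1
          if len(self.items) < self.capacity:
              self.items.append(item)
          else:
              i = random.randrange(self.n_seen)
              if i < self.capacity:
                  self.items[i] = item

      def sample(self, k: int) -> list:
          # Mix these replayed examples into each new training batch.
          return random.sample(self.items, min(k, len(self.items)))
  ```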
- Target setting: `TaskIncrementalSLSetting`

  Hard Attention to the Task (HAT):

  ```bibtex
  @inproceedings{serra2018overcoming,
      title={Overcoming Catastrophic Forgetting with Hard Attention to the Task},
      author={Serra, Joan and Suris, Didac and Miron, Marius and Karatzoglou, Alexandros},
      booktitle={International Conference on Machine Learning},
      pages={4548--4557},
      year={2018}
  }
  ```
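  The core mechanism from the paper, as a generic sketch (illustrative, not the actual implementation): each task learns an embedding per layer, and a scaled sigmoid turns it into an almost-binary per-unit mask on that layer's activations.

  ```python
  import torch
  from torch import nn


  class HardAttentionGate(nn.Module):
      def __init__(self, n_tasks: int, n_units: int):
          super().__init__()
          # One learned embedding (mask logits) per task.
          self.embeddings = nn.Embedding(n_tasks, n_units)

      def forward(self, h: torch.Tensor, task_id: int, s: float = 400.0) -> torch.Tensor:
          # A large scale s pushes the sigmoid toward a hard 0/1 mask.
          mask = torch.sigmoid(s * self.embeddings.weight[task_id])
          return h * mask
  ```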