Integrate Ray tune with the base trainer class #86

Open
shrave opened this issue Apr 11, 2023 · 1 comment
Labels: feature request, question

Comments

@shrave

shrave commented Apr 11, 2023

Hi,

I was wondering whether I could include the Ray Tune (hyper-parameter search) library, either as a callback or in the base trainer class, to find the right hyper-parameters for a model and even stop early.

Could you please tell me how to integrate it so that training can be stopped midway when a particular hyper-parameter configuration does not perform well?

Even if you could just suggest a way to return the logger at every epoch when a training pipeline instance is called, that would be enough for me.

This library has been extremely useful in my research. Thank you very much!

@clementchadebec
Owner

clementchadebec commented Apr 13, 2023

Hello @shrave,

Thank you for the kind words. I am happy to hear that this repo is useful for your research.

As for the issue, I have never really used Ray, but from what I understand of the provided tutorials, Ray Tune can be integrated fairly straightforwardly using a callback, as you suggest.

  1. The callback can be created as follows. At the end of each epoch, it reads the metrics computed during training and reports the one to monitor to Ray Tune.
from ray import tune

from pythae.trainers import BaseTrainerConfig
from pythae.trainers.training_callbacks import TrainingCallback

class RayCallback(TrainingCallback):

    def __init__(self) -> None:
        super().__init__()

    def on_epoch_end(self, training_config: BaseTrainerConfig, **kwargs):
        metrics = kwargs.pop("metrics") # get the metrics computed during training (needs the BaseTrainer changes in #87)
        tune.report(eval_epoch_loss=metrics["eval_epoch_loss"]) # add the metric to monitor to the report
  2. You will need to wrap the training part of your script in a function that will then be called by the Ray Tuner. Its input config is one sample drawn from the search_space dictionary, i.e. a dict holding one value for each hyper-parameter being searched. Everything else is the same as a classic pythae training configuration and launch.
import torch
from torchvision import datasets

from pythae.data.datasets import BaseDataset
from pythae.models import VAE, VAEConfig
from pythae.trainers import BaseTrainer, BaseTrainerConfig

def train_ray(config):

    # Ray Tune may run each trial in its own working directory, so an absolute root path is safer here
    mnist_trainset = datasets.MNIST(root='../../data', train=True, download=True, transform=None)

    # Small train/eval subsets scaled to [0, 1] with shape (N, 1, 28, 28); the labels are dummies
    train_dataset = BaseDataset(mnist_trainset.data[:1000].reshape(-1, 1, 28, 28) / 255., torch.ones(1000))
    eval_dataset = BaseDataset(mnist_trainset.data[-1000:].reshape(-1, 1, 28, 28) / 255., torch.ones(1000))

    my_training_config = BaseTrainerConfig(
        output_dir='my_model',
        num_epochs=50,
        learning_rate=config["lr"], # pass the sampled lr for the hp search
        per_device_train_batch_size=200,
        per_device_eval_batch_size=200,
        steps_saving=None,
        optimizer_cls="AdamW",
        optimizer_params={"weight_decay": 0.05, "betas": (0.91, 0.995)},
        scheduler_cls="ReduceLROnPlateau",
        scheduler_params={"patience": 5, "factor": 0.5}
    )

    my_vae_config = VAEConfig(
        input_dim=(1, 28, 28),
        latent_dim=10
    )

    my_vae_model = VAE(
        model_config=my_vae_config
    )

    # Add the Ray callback to the callback list
    callbacks = [RayCallback()]

    trainer = BaseTrainer(
        my_vae_model,
        train_dataset,
        eval_dataset,
        my_training_config,
        callbacks=callbacks # pass the callbacks to the trainer
    )

    trainer.train() # launch the training
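  3. Finally, you can launch the Ray tuning:
import numpy as np
from ray import tune
from ray.tune.schedulers import ASHAScheduler

# Sample the learning rate log-uniformly in (1e-10, 1]
search_space = {
    "lr": tune.sample_from(lambda spec: 10 ** (-10 * np.random.rand())),
}

tuner = tune.Tuner(
    train_ray,
    tune_config=tune.TuneConfig(
        num_samples=20,
        # ASHA stops poorly performing trials early, based on the reported metric
        scheduler=ASHAScheduler(metric="eval_epoch_loss", mode="min"),
    ),
    param_space=search_space,
)

results = tuner.fit()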
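To illustrate how the scheduler stops bad configurations midway: ASHA builds on successive halving, where all trials start, and at each "rung" only the best fraction is allowed to keep training. The sketch below is a simplified, synchronous successive-halving loop in plain Python, not Ray's ASHA implementation (which is asynchronous). The parameter names grace_period and reduction_factor are borrowed from ASHAScheduler's arguments, and toy_loss is a hypothetical stand-in for actually training the VAE.

```python
import math
import random
from typing import Callable, Dict, List


def successive_halving(
    configs: List[dict],
    loss_at_epoch: Callable[[dict, int], float],
    max_epochs: int = 27,
    grace_period: int = 1,
    reduction_factor: int = 3,
) -> Dict[int, int]:
    """Simplified synchronous successive halving (the idea behind ASHA).

    Returns, for each config index, the number of epochs it was allowed to train.
    """
    alive = list(range(len(configs)))
    epochs_run = {i: 0 for i in alive}
    rung = min(grace_period, max_epochs)
    while True:
        # "Train" every surviving trial up to the current rung and read its loss there.
        losses = {i: loss_at_epoch(configs[i], rung) for i in alive}
        for i in alive:
            epochs_run[i] = rung
        if rung == max_epochs:
            break
        # Keep the best 1/reduction_factor of the trials (mode="min": lower loss wins);
        # the others are stopped early and never train past this rung.
        keep = max(1, len(alive) // reduction_factor)
        alive = sorted(alive, key=lambda i: losses[i])[:keep]
        rung = min(rung * reduction_factor, max_epochs)
    return epochs_run


def toy_loss(config: dict, epoch: int) -> float:
    """Hypothetical stand-in for training: best near lr=1e-3, improving with epochs."""
    return abs(math.log10(config["lr"]) + 3) + 1.0 / epoch


if __name__ == "__main__":
    # Same log-uniform rule as the search_space above: lr in (1e-10, 1]
    configs = [{"lr": 10 ** (-10 * random.random())} for _ in range(9)]
    budget = successive_halving(configs, toy_loss)
    for i, cfg in enumerate(configs):
        print(f"lr={cfg['lr']:.2e} trained for {budget[i]} epochs")
```

With 9 trials and a reduction factor of 3, only 3 trials get past the first rung and a single one trains for the full budget, which is where the compute savings over plain random search come from.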

I have opened #87 because a few minor changes to the current BaseTrainer implementation are needed for callbacks to read the metrics at the end of each epoch. Let me know if this is the behavior you are expecting :) In particular, you can look at this script example.
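For the narrower request of getting the logged values at every epoch, the same end-of-epoch hook can simply keep them in memory. Below is a minimal, self-contained sketch of that pattern, assuming (as the RayCallback above does) that the trainer passes a metrics keyword to on_epoch_end once the #87 changes are in. TrainingCallback here is a stand-in for pythae's class, and toy_train is a hypothetical loop with made-up losses, not pythae's BaseTrainer.

```python
from typing import Any, Dict, List, Optional


class TrainingCallback:
    """Minimal stand-in for pythae's TrainingCallback (hooks do nothing by default)."""

    def on_epoch_end(self, training_config: Any, **kwargs: Any) -> None:
        pass


class MetricsHistoryCallback(TrainingCallback):
    """Keeps a copy of the metrics dict received at the end of every epoch."""

    def __init__(self) -> None:
        super().__init__()
        self.history: List[Dict[str, float]] = []

    def on_epoch_end(self, training_config: Any, **kwargs: Any) -> None:
        # Assumes the trainer passes metrics=... (as proposed in #87); skip silently otherwise.
        metrics: Optional[Dict[str, float]] = kwargs.pop("metrics", None)
        if metrics is not None:
            # Copy so that later changes to the trainer's dict do not alter the stored history.
            self.history.append(dict(metrics))


def toy_train(num_epochs: int, callbacks: List[TrainingCallback]) -> None:
    """Hypothetical training loop with made-up losses, calling the end-of-epoch hook."""
    for epoch in range(1, num_epochs + 1):
        metrics = {"train_epoch_loss": 1.0 / epoch, "eval_epoch_loss": 2.0 / epoch}
        for callback in callbacks:
            callback.on_epoch_end(training_config=None, metrics=metrics)


if __name__ == "__main__":
    history_cb = MetricsHistoryCallback()
    toy_train(num_epochs=3, callbacks=[history_cb])
    for epoch, metrics in enumerate(history_cb.history, start=1):
        print(epoch, metrics)
```

With pythae, the idea would be to subclass the real TrainingCallback instead, pass an instance in the callbacks list of BaseTrainer, and read its history attribute once trainer.train() returns.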

Do not hesitate if you have any questions.

Best,

Clément

clementchadebec added the question and feature request labels on Apr 13, 2023