-
Notifications
You must be signed in to change notification settings - Fork 306
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.
Already on GitHub? Sign in to your account
Hyperparameter Tuning of iTransformer using Optuna #1009
Comments
Please figure out the issue. |
Hi, can you provide a standalone piece of code that I can run to reproduce the issue? That said, the RuntimeError seems to provide a clue: I'd remove the `early_stop_patience_steps` parameter (the inline code was stripped here — presumably that parameter, given the follow-up comment). Some generic tips based on the above code:
|
Yeah, when I remove the `early_stop_patience_steps` parameter the code runs properly, but it's of no use, as for every trial the model will run for all epochs. Please find the code here: |
You have to set `refit_with_val=True` (inline code stripped here — inferred from the later confirmation). @elephaint, we may want to do this automatically when early stopping is configured. |
Yes that makes a lot of sense |
refit_with_val = True worked. |
def config_itransformer(trial):
return {
# "h" : 14,
"input_size" : trial.suggest_int("input_size", 60, 360, step = 60),
"val_check_steps": 10,
Cell In[56], line 2
1 fcst = NeuralForecast(models=[model], freq='B')
----> 2 fcst.fit(df=train_data, val_size=14)
File /opt/conda/lib/python3.10/site-packages/neuralforecast/core.py:462, in NeuralForecast.fit(self, df, static_df, val_size, sort_df, use_init_models, verbose, id_col, time_col, target_col, distributed_config)
459 self._reset_models()
461 for i, model in enumerate(self.models):
--> 462 self.models[i] = model.fit(
463 self.dataset, val_size=val_size, distributed_config=distributed_config
464 )
466 self._fitted = True
File /opt/conda/lib/python3.10/site-packages/neuralforecast/common/_base_auto.py:424, in BaseAuto.fit(self, dataset, val_size, test_size, random_seed, distributed_config)
412 results = self._optuna_tune_model(
413 cls_model=self.cls_model,
414 dataset=dataset,
(...)
421 distributed_config=distributed_config,
422 )
423 best_config = results.best_trial.user_attrs["ALL_PARAMS"]
--> 424 self.model = self._fit_model(
425 cls_model=self.cls_model,
426 config=best_config,
427 dataset=dataset,
428 val_size=val_size * self.refit_with_val,
429 test_size=test_size,
430 distributed_config=distributed_config,
431 )
432 self.results = results
434 # Added attributes for compatibility with NeuralForecast core
File /opt/conda/lib/python3.10/site-packages/neuralforecast/common/_base_auto.py:357, in BaseAuto._fit_model(self, cls_model, config, dataset, val_size, test_size, distributed_config)
353 def _fit_model(
354 self, cls_model, config, dataset, val_size, test_size, distributed_config=None
355 ):
356 model = cls_model(**config)
--> 357 model = model.fit(
358 dataset,
359 val_size=val_size,
360 test_size=test_size,
361 distributed_config=distributed_config,
362 )
363 return model
File /opt/conda/lib/python3.10/site-packages/neuralforecast/common/_base_multivariate.py:537, in BaseMultivariate.fit(self, dataset, val_size, test_size, random_seed, distributed_config)
533 if distributed_config is not None:
534 raise ValueError(
535 "multivariate models cannot be trained using distributed data parallel."
536 )
--> 537 return self._fit(
538 dataset=dataset,
539 batch_size=self.n_series,
540 valid_batch_size=self.n_series,
541 val_size=val_size,
542 test_size=test_size,
543 random_seed=random_seed,
544 shuffle_train=False,
545 distributed_config=None,
546 )
File /opt/conda/lib/python3.10/site-packages/neuralforecast/common/_base_model.py:219, in BaseModel._fit(self, dataset, batch_size, valid_batch_size, val_size, test_size, random_seed, shuffle_train, distributed_config)
217 model = self
218 trainer = pl.Trainer(**model.trainer_kwargs)
--> 219 trainer.fit(model, datamodule=datamodule)
220 model.metrics = trainer.callback_metrics
221 model.dict.pop("_trainer", None)
File /opt/conda/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:544, in Trainer.fit(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
542 self.state.status = TrainerStatus.RUNNING
543 self.training = True
--> 544 call._call_and_handle_interrupt(
545 self, self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path
546 )
File /opt/conda/lib/python3.10/site-packages/pytorch_lightning/trainer/call.py:44, in _call_and_handle_interrupt(trainer, trainer_fn, *args, **kwargs)
42 if trainer.strategy.launcher is not None:
43 return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
---> 44 return trainer_fn(*args, **kwargs)
46 except _TunerExitException:
47 _call_teardown_hook(trainer)
File /opt/conda/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:580, in Trainer._fit_impl(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
573 assert self.state.fn is not None
574 ckpt_path = self._checkpoint_connector._select_ckpt_path(
575 self.state.fn,
576 ckpt_path,
577 model_provided=True,
578 model_connected=self.lightning_module is not None,
579 )
--> 580 self._run(model, ckpt_path=ckpt_path)
582 assert self.state.stopped
583 self.training = False
File /opt/conda/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:987, in Trainer._run(self, model, ckpt_path)
982 self._signal_connector.register_signal_handlers()
984 # ----------------------------
985 # RUN THE TRAINER
986 # ----------------------------
--> 987 results = self._run_stage()
989 # ----------------------------
990 # POST-Training CLEAN UP
991 # ----------------------------
992 log.debug(f"{self.class.name}: trainer tearing down")
File /opt/conda/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:1033, in Trainer._run_stage(self)
1031 self._run_sanity_check()
1032 with torch.autograd.set_detect_anomaly(self._detect_anomaly):
-> 1033 self.fit_loop.run()
1034 return None
1035 raise RuntimeError(f"Unexpected state {self.state}")
File /opt/conda/lib/python3.10/site-packages/pytorch_lightning/loops/fit_loop.py:205, in _FitLoop.run(self)
203 try:
204 self.on_advance_start()
--> 205 self.advance()
206 self.on_advance_end()
207 self._restarting = False
File /opt/conda/lib/python3.10/site-packages/pytorch_lightning/loops/fit_loop.py:363, in _FitLoop.advance(self)
361 with self.trainer.profiler.profile("run_training_epoch"):
362 assert self._data_fetcher is not None
--> 363 self.epoch_loop.run(self._data_fetcher)
File /opt/conda/lib/python3.10/site-packages/pytorch_lightning/loops/training_epoch_loop.py:141, in _TrainingEpochLoop.run(self, data_fetcher)
139 try:
140 self.advance(data_fetcher)
--> 141 self.on_advance_end(data_fetcher)
142 self._restarting = False
143 except StopIteration:
File /opt/conda/lib/python3.10/site-packages/pytorch_lightning/loops/training_epoch_loop.py:295, in _TrainingEpochLoop.on_advance_end(self, data_fetcher)
291 if not self._should_accumulate():
292 # clear gradients to not leave any unused memory during validation
293 call._call_lightning_module_hook(self.trainer, "on_validation_model_zero_grad")
--> 295 self.val_loop.run()
296 self.trainer.training = True
297 self.trainer._logger_connector._first_loop_iter = first_loop_iter
File /opt/conda/lib/python3.10/site-packages/pytorch_lightning/loops/utilities.py:182, in _no_grad_context.._decorator(self, *args, **kwargs)
180 context_manager = torch.no_grad
181 with context_manager():
--> 182 return loop_run(self, *args, **kwargs)
File /opt/conda/lib/python3.10/site-packages/pytorch_lightning/loops/evaluation_loop.py:142, in _EvaluationLoop.run(self)
140 self._restarting = False
141 self._store_dataloader_outputs()
--> 142 return self.on_run_end()
File /opt/conda/lib/python3.10/site-packages/pytorch_lightning/loops/evaluation_loop.py:268, in _EvaluationLoop.on_run_end(self)
265 self.trainer._logger_connector.log_eval_end_metrics(all_logged_outputs)
267 # hook
--> 268 self._on_evaluation_end()
270 # enable train mode again
271 self._on_evaluation_model_train()
File /opt/conda/lib/python3.10/site-packages/pytorch_lightning/loops/evaluation_loop.py:313, in _EvaluationLoop._on_evaluation_end(self, *args, **kwargs)
311 trainer = self.trainer
312 hook_name = "on_test_end" if trainer.testing else "on_validation_end"
--> 313 call._call_callback_hooks(trainer, hook_name, *args, **kwargs)
314 call._call_lightning_module_hook(trainer, hook_name, *args, **kwargs)
315 call._call_strategy_hook(trainer, hook_name, *args, **kwargs)
File /opt/conda/lib/python3.10/site-packages/pytorch_lightning/trainer/call.py:208, in _call_callback_hooks(trainer, hook_name, monitoring_callbacks, *args, **kwargs)
206 if callable(fn):
207 with trainer.profiler.profile(f"[Callback]{callback.state_key}.{hook_name}"):
--> 208 fn(trainer, trainer.lightning_module, *args, **kwargs)
210 if pl_module:
211 # restore current_fx when nested context
212 pl_module._current_fx_name = prev_fx_name
File /opt/conda/lib/python3.10/site-packages/pytorch_lightning/callbacks/early_stopping.py:196, in EarlyStopping.on_validation_end(self, trainer, pl_module)
194 if self._check_on_train_epoch_end or self._should_skip_check(trainer):
195 return
--> 196 self._run_early_stopping_check(trainer)
File /opt/conda/lib/python3.10/site-packages/pytorch_lightning/callbacks/early_stopping.py:202, in EarlyStopping._run_early_stopping_check(self, trainer)
199 """Checks whether the early stopping condition is met and if so tells the trainer to stop the training."""
200 logs = trainer.callback_metrics
--> 202 if trainer.fast_dev_run or not self._validate_condition_metric( # disable early_stopping with fast_dev_run
203 logs
204 ): # short circuit if metric not present
205 return
207 current = logs[self.monitor].squeeze()
File /opt/conda/lib/python3.10/site-packages/pytorch_lightning/callbacks/early_stopping.py:153, in EarlyStopping._validate_condition_metric(self, logs)
151 if monitor_val is None:
152 if self.strict:
--> 153 raise RuntimeError(error_msg)
154 if self.verbose > 0:
155 rank_zero_warn(error_msg, category=RuntimeWarning)
RuntimeError: Early stopping conditioned on metric `ptl/val_loss` which is not available. Pass in or modify your `EarlyStopping` callback to use any of the following: `train_loss`, `train_loss_step`, `train_loss_epoch`
The text was updated successfully, but these errors were encountered: