
enable_checkpointing = False results in MisconfigurationException #3203

Open
mraapshockwavemedical opened this issue Jul 11, 2024 · 0 comments
Labels
bug Something isn't working


mraapshockwavemedical commented Jul 11, 2024

Description

I would like to train a TemporalFusionTransformerEstimator with trainer_kwargs["enable_checkpointing"] = False. However, on line 196 of gluonts/torch/model/estimator.py a ModelCheckpoint is created regardless and appended to the callbacks list on line 204, which results in the following error:

lightning.fabric.utilities.exceptions.MisconfigurationException: Trainer was configured with enable_checkpointing=False but found ModelCheckpoint in callbacks list.
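
For reference, the check lives in Lightning itself and can be triggered without GluonTS at all (a minimal sketch, assuming the lightning 2.x import path):

import lightning.pytorch as pl

# Lightning rejects this combination at Trainer construction time:
pl.Trainer(
    enable_checkpointing=False,
    callbacks=[pl.callbacks.ModelCheckpoint()],
)  # raises MisconfigurationException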

I need to disable checkpointing because I would like to run this on Snowflake, which does not allow writing to the filesystem. Workaround ideas would be highly appreciated.

Edit
Snowflake would allow me to write to /tmp/checkpoints, but it seems impossible to set the dirpath of the ModelCheckpoint created in estimator.py on lines 195-198:

        monitor = "train_loss" if validation_data is None else "val_loss"
        checkpoint = pl.callbacks.ModelCheckpoint(
            monitor=monitor, mode="min", verbose=True
        )
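
A workaround I have not tested yet: since trainer_kwargs appears to be forwarded straight to pl.Trainer (see the traceback below), and Lightning resolves the directory of a ModelCheckpoint without dirpath from the trainer's default_root_dir (or the logger's save_dir), setting default_root_dir might be enough to push the files into /tmp:

from gluonts.torch import TemporalFusionTransformerEstimator

# Unverified sketch: redirect Lightning's output directory instead of
# disabling checkpointing. With no dirpath on the ModelCheckpoint, the
# checkpoint directory falls back to the trainer's default_root_dir, so
# the files may end up under /tmp/checkpoints, which Snowflake allows.
trainer_kwargs = {"default_root_dir": "/tmp/checkpoints"}
estimator = TemporalFusionTransformerEstimator(
    freq="M", prediction_length=1, trainer_kwargs=trainer_kwargs
)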

To Reproduce

import pandas as pd
from gluonts.torch import TemporalFusionTransformerEstimator
from gluonts.dataset.pandas import PandasDataset

data = {
    "item_id": [1, 1, 1],
    "ts": ["2024-01-01", "2024-02-01", "2024-03-01"],
    "target": [1, 2, 3],
}
ds = PandasDataset.from_long_dataframe(
    pd.DataFrame(data), target="target", item_id="item_id", timestamp="ts"
)
# Passing enable_checkpointing=False is what triggers the exception below.
trainer_kwargs = {"enable_checkpointing": False}
estimator = TemporalFusionTransformerEstimator(
    freq="M", prediction_length=1, trainer_kwargs=trainer_kwargs
)
predictor = estimator.train(training_data=ds)

Error message or code output

Traceback (most recent call last):
  File "C:\projects\test\venv\test.py", line 15, in <module>
    predictor = estimator.train(training_data=ds)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\projects\test\venv\Lib\site-packages\gluonts\torch\model\estimator.py", line 246, in train
    return self.train_model(
           ^^^^^^^^^^^^^^^^^
  File "C:\projects\test\venv\Lib\site-packages\gluonts\torch\model\estimator.py", line 201, in train_model
    trainer = pl.Trainer(
              ^^^^^^^^^^^
  File "C:\projects\test\venv\Lib\site-packages\lightning\pytorch\utilities\argparse.py", line 70, in insert_env_defaults
    return fn(self, **kwargs)
           ^^^^^^^^^^^^^^^^^^
  File "C:\projects\test\venv\Lib\site-packages\lightning\pytorch\trainer\trainer.py", line 431, in __init__
    self._callback_connector.on_trainer_init(
  File "C:\projects\test\venv\Lib\site-packages\lightning\pytorch\trainer\connectors\callback_connector.py", line 66, in on_trainer_init
    self._configure_checkpoint_callbacks(enable_checkpointing)
  File "C:\projects\test\venv\Lib\site-packages\lightning\pytorch\trainer\connectors\callback_connector.py", line 88, in _configure_checkpoint_callbacks
    raise MisconfigurationException(
lightning.fabric.utilities.exceptions.MisconfigurationException: Trainer was configured with `enable_checkpointing=False` but found `ModelCheckpoint` in callbacks list.

Environment

  • Operating system: Windows
  • Python version: 3.12
  • GluonTS version: 0.15.1
mraapshockwavemedical added the bug label on Jul 11, 2024