
Checkpoint size is increasing every time. #134

Closed
IvoryTower800 opened this issue Mar 19, 2024 · 3 comments

@IvoryTower800

Describe the bug
Hi, when I'm fine-tuning Gemma, the checkpoint size was a fixed value at the beginning, but then it grew larger and larger. Finally, when it reached 5.99 GB, fine-tuning could still continue, but no new checkpoint could be saved and it raised ValueError: unicode string is too large.
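
For context, msgpack refuses to pack any single str or bytes object of 4 GiB or more, which matches saving starting to fail once one serialized entry approaches that size. A minimal sketch (illustrative only, not taken from this report; it allocates roughly 4 GiB of memory) that triggers the same ValueError:

```python
# Illustrative only: msgpack's hard per-object limit is 2**32 - 1 bytes.
# Building the oversized string below needs roughly 4 GiB of RAM.
import msgpack

packer = msgpack.Packer()
packer.pack(("layer/kernel", "x" * 1024))  # a small entry packs fine

try:
    packer.pack(("layer/kernel", "x" * (2 ** 32)))  # one entry past the limit
except ValueError as err:
    print(err)  # unicode string is too large
```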

To Reproduce
Steps to reproduce the behavior
[screenshot: notebook cell and error]

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[1], line 185
    177 dataset_train = load_from_disk('/kaggle/input/sadlfkjaslkgma8192')
    179 trainer = CausalLanguageModelTrainer(
    180     train_arguments,
    181     dataset_train,
    182     checkpoint_path='/root/' + ckpt_name
    183 )
--> 185 output = trainer.train()
    186 print(f"Hey ! , here's where your model saved {output.checkpoint_path}")
    188 api.upload_file(
    189     path_or_fileobj=output.checkpoint_path,
    190     path_in_repo=output.checkpoint_path.split('/')[-1],
    191     repo_id="ivt1993/writer_2b_gemma",
    192     repo_type="model"
    193 )

File /usr/local/lib/python3.10/site-packages/EasyDel/trainer/causal_language_model_trainer.py:703, in CausalLanguageModelTrainer.train(self, model_parameters, state)
    693     shard_fns, gather_fns = make_shard_and_gather_fns(
    694         match_partition_rules(
    695             self.config.get_partition_rules(
   (...)
    700         dtype_specs=self.dtype
    701     )  # You have to re-init the new shard and gather functions in order to be able to skip LoRA weight
    702     # crashing errors and saving errors
--> 703     filename = self._save_state(
    704         state=sharded_state,
    705         gather_fns=gather_fns
    706     )
    707     checkpoint_path = f"{str(self.arguments.get_path())}/{filename}"
    709 if self.arguments.do_eval:

File /usr/local/lib/python3.10/site-packages/EasyDel/trainer/causal_language_model_trainer.py:494, in CausalLanguageModelTrainer._save_state(self, state, gather_fns, milestone)
    492 filename += ".easy"
    493 termcolor.cprint(f"Saving Model {filename}.", color="cyan", force_color=True)
--> 494 state.save_state(
    495     filename=filename,
    496     checkpoint_dir=checkpoint_dir,
    497     gather_fns=gather_fns,
    498     float_dtype=self.dtype,
    499     verbose=self.arguments.verbose,
    500     save_optimizer=self.arguments.save_optimizer_state,
    501 )
    502 return filename

File /usr/local/lib/python3.10/site-packages/EasyDel/etils/easystate.py:372, in EasyDelState.save_state(self, filename, save_optimizer, checkpoint_dir, verbose, gather_fns, float_dtype)
    361     state = self.replace(
    362         opt_state=None
    363     )
    364 state = state.replace(
    365     module_config_args={
    366         k: v for k, v in state.module_config.__dict__.items() if
   (...)
    370     }
    371 )
--> 372 fjformer.CheckpointManager.save_state_to_file(
    373     state=state,
    374     path=os.path.join(checkpoint_dir, filename) if checkpoint_dir is not None else filename,
    375     verbose=verbose,
    376     gather_fns=gather_fns,
    377     float_dtype=float_dtype,
    378 )

File /usr/local/lib/python3.10/site-packages/fjformer/checkpoint/streamer.py:112, in CheckpointManager.save_state_to_file(state, path, gather_fns, float_dtype, verbose, mismatch_allowed)
    110 pbar.set_postfix(gather_functions_mismatch=gather_functions_mismatch)
    111 value = get_dtype(value, float_dtype)
--> 112 stream.write(packer.pack((key, to_bytes(value))))

File /usr/local/lib/python3.10/site-packages/msgpack/_packer.pyx:294, in msgpack._cmsgpack.Packer.pack()

File /usr/local/lib/python3.10/site-packages/msgpack/_packer.pyx:300, in msgpack._cmsgpack.Packer.pack()

File /usr/local/lib/python3.10/site-packages/msgpack/_packer.pyx:297, in msgpack._cmsgpack.Packer.pack()

File /usr/local/lib/python3.10/site-packages/msgpack/_packer.pyx:264, in msgpack._cmsgpack.Packer._pack()

File /usr/local/lib/python3.10/site-packages/msgpack/_packer.pyx:264, in msgpack._cmsgpack.Packer._pack()

File /usr/local/lib/python3.10/site-packages/msgpack/_packer.pyx:211, in msgpack._cmsgpack.Packer._pack()

ValueError: unicode string is too large
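
Given the write pattern in the traceback (each record is packed as a (key, to_bytes(value)) pair), one way to see which entries grow between saves is to stream the .easy file back with msgpack and report per-key sizes. A rough diagnostic sketch, assuming that record layout and a hypothetical checkpoint path:

```python
# Rough diagnostic sketch (assumes each record in the ".easy" file is a
# (key, serialized_bytes) pair, as written by save_state_to_file above).
import msgpack

def report_entry_sizes(path):
    sizes = {}
    with open(path, "rb") as f:
        # Raise the buffer limit so multi-GiB entries can still be streamed.
        unpacker = msgpack.Unpacker(f, raw=False, max_buffer_size=2 ** 33)
        for key, value in unpacker:
            sizes[key] = len(value)
    for key, size in sorted(sizes.items(), key=lambda kv: -kv[1]):
        print(f"{size / 2**20:10.2f} MiB  {key}")
    print(f"total: {sum(sizes.values()) / 2**30:.2f} GiB")

# report_entry_sizes("/root/checkpoint.easy")  # hypothetical path
```

Running it on two consecutive checkpoints and diffing the output should show whether a particular entry (for example the optimizer state) keeps growing.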
@erfanzar
Owner

Is it possible to share the weights and state with me, so I can debug that and fix the issue? Anyway, this is the first time I've seen an issue like that. I trained a Mixtral model for 8 days, saving a checkpoint every 10B tokens, and ended up with around 100 checkpoints, but I didn't get this error.

@IvoryTower800
Author

Sure, I sent them to you by email. Please check. Thank you so much!

@erfanzar
Owner

This issue might be fixed due to recent changes and bug fixes in fjformer over the past few days.
