Training process fails with a Jax library related issue #255

Open
randheerDas opened this issue Apr 12, 2024 · 6 comments
Labels
bug Something isn't working

Comments

@randheerDas

Describe the bug

Training process fails with a Jax library related issue.

This is the code in the notebook cell that fails:

!python3 train_dreambooth.py \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --pretrained_vae_name_or_path="stabilityai/sd-vae-ft-mse" \
  --output_dir=$OUTPUT_DIR \
  --with_prior_preservation --prior_loss_weight=1.0 \
  --seed=1337 \
  --resolution=512 \
  --train_batch_size=1 \
  --train_text_encoder \
  --mixed_precision="fp16" \
  --use_8bit_adam \
  --gradient_accumulation_steps=1 \
  --learning_rate=1e-6 \
  --lr_scheduler="constant" \
  --lr_warmup_steps=0 \
  --num_class_images=50 \
  --sample_batch_size=4 \
  --max_train_steps=800 \
  --save_interval=10000 \
  --save_sample_prompt="photo of narrow gate" \
  --concepts_list="concepts_list.json"

Attached is the screenshot for the error:

[Screenshot: "Error"]

Reproduction

Run the training process by issuing the following command:

!python3 train_dreambooth.py \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --pretrained_vae_name_or_path="stabilityai/sd-vae-ft-mse" \
  --output_dir=$OUTPUT_DIR \
  --with_prior_preservation --prior_loss_weight=1.0 \
  --seed=1337 \
  --resolution=512 \
  --train_batch_size=1 \
  --train_text_encoder \
  --mixed_precision="fp16" \
  --use_8bit_adam \
  --gradient_accumulation_steps=1 \
  --learning_rate=1e-6 \
  --lr_scheduler="constant" \
  --lr_warmup_steps=0 \
  --num_class_images=50 \
  --sample_batch_size=4 \
  --max_train_steps=800 \
  --save_interval=10000 \
  --save_sample_prompt="photo of narrow gate" \
  --concepts_list="concepts_list.json"

Logs

No response

System Info

I am running this in a Google Colab notebook with a Python 3 runtime, on a Google Compute Engine backend with a Tesla GPU.

Install details:

!wget -q https://github.com/ShivamShrirao/diffusers/raw/main/examples/dreambooth/train_dreambooth.py
!wget -q https://github.com/ShivamShrirao/diffusers/raw/main/scripts/convert_diffusers_to_original_stable_diffusion.py
%pip install -qq git+https://github.com/ShivamShrirao/diffusers
%pip install -q -U --pre triton
%pip install -q accelerate transformers ftfy bitsandbytes==0.35.0 gradio natsort safetensors xformers

randheerDas added the bug label on Apr 12, 2024
@mahaboobkhan29

Any update? I'm facing the same issue.

@The-Ramosian

This seems to work:

!pip install "jax[cuda12_local]==0.4.23" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

@JossCamp

Google always ends up breaking something with each update, so you need to pin a specific version:
!pip install jax==0.4.19 jaxlib==0.4.19 -f https://storage.googleapis.com/jax-releases/jax_releases.html

This solves the problem for now.
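
After running the pin and restarting the runtime, it can help to confirm that the versions actually installed are the pinned ones, since a later `pip install` in the notebook may quietly upgrade them again. The following is only a sketch: the `PINS` table and the helper names `installed_version` and `check_pins` are made up for illustration, and the versions are simply the ones quoted in the comment above.

```python
from importlib import metadata

# Pins copied from the comment above; change them if you pin a different pair.
PINS = {"jax": "0.4.19", "jaxlib": "0.4.19"}


def installed_version(package):
    """Return the installed version without its local suffix, or None if absent."""
    try:
        # CUDA wheels may report e.g. "0.4.19+cuda12.cudnn89"; keep only "0.4.19".
        return metadata.version(package).split("+")[0]
    except metadata.PackageNotFoundError:
        return None


def check_pins(pins, lookup=installed_version):
    """Return {package: (wanted, found)} for every package that misses its pin."""
    mismatches = {}
    for package, wanted in pins.items():
        found = lookup(package)
        if found != wanted:
            mismatches[package] = (wanted, found)
    return mismatches


if __name__ == "__main__":
    problems = check_pins(PINS)
    for package, (wanted, found) in problems.items():
        print(f"{package}: wanted {wanted}, found {found}")
    if not problems:
        print("jax and jaxlib match the pinned versions")
```

The check reads package metadata instead of importing `jax`, so it still works when the import itself is what fails.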

@mirodil-ml

Indeed, !pip install jax==0.4.19 jaxlib==0.4.19 -f https://storage.googleapis.com/jax-releases/jax_releases.html solves this issue, but then another error appears, RuntimeError: operator torchvision::nms does not exist:

Traceback (most recent call last):
  File "/content/train_dreambooth.py", line 26, in <module>
    from torchvision import transforms
  File "/usr/local/lib/python3.10/dist-packages/torchvision/__init__.py", line 6, in <module>
    from torchvision import _meta_registrations, datasets, io, models, ops, transforms, utils
  File "/usr/local/lib/python3.10/dist-packages/torchvision/_meta_registrations.py", line 164, in <module>
    def meta_nms(dets, scores, iou_threshold):
  File "/usr/local/lib/python3.10/dist-packages/torch/library.py", line 467, in inner
    handle = entry.abstract_impl.register(func_to_register, source)
  File "/usr/local/lib/python3.10/dist-packages/torch/_library/abstract_impl.py", line 30, in register
    if torch._C._dispatch_has_kernel_for_dispatch_key(self.qualname, "Meta"):
RuntimeError: operator torchvision::nms does not exist

Probably the PyTorch version should be pinned too, but which version?
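
One way to narrow this down, on the assumption (not confirmed anywhere in this thread) that the error comes from a torch and torchvision pair built for different releases, is to compare the two installed versions without importing torchvision. The sketch below uses a small, deliberately partial pairing table; `EXPECTED_TORCHVISION`, `minor` and `pairing_status` are illustrative names, and the torchvision README has the full list of matching releases.

```python
from importlib import metadata

# torch minor release -> matching torchvision minor release.
# Partial table for illustration; consult the torchvision README for the rest.
EXPECTED_TORCHVISION = {
    "2.1": "0.16",
    "2.2": "0.17",
    "2.3": "0.18",
}


def minor(version):
    """Reduce a version such as '2.2.1+cu121' to its 'major.minor' part, '2.2'."""
    return ".".join(version.split("+")[0].split(".")[:2])


def pairing_status(torch_version, torchvision_version):
    """Return 'ok', 'mismatch', or 'unknown' if the torch release is not in the table."""
    expected = EXPECTED_TORCHVISION.get(minor(torch_version))
    if expected is None:
        return "unknown"
    return "ok" if minor(torchvision_version) == expected else "mismatch"


if __name__ == "__main__":
    try:
        # Read metadata only; importing torchvision is what raises the error.
        torch_v = metadata.version("torch")
        vision_v = metadata.version("torchvision")
    except metadata.PackageNotFoundError as exc:
        print(f"not installed: {exc}")
    else:
        print(torch_v, vision_v, pairing_status(torch_v, vision_v))
```

If this prints "mismatch", reinstalling torch and torchvision together as a matching pair is the usual thing to try; which exact pair suits this training script is still an open question here.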

@roman19932024

I have the same problem. Did anyone find a solution?

@mirodil-ml

@roman19932024 try updating the Python version to 3.10.
