Got it running on MBP
SujeethJinesh committed Nov 1, 2022
1 parent 32c5e38 commit 63ec96a
Showing 18 changed files with 211 additions and 217 deletions.
Binary file modified __pycache__/main.cpython-39.pyc
39 changes: 19 additions & 20 deletions environment.yaml
@@ -5,27 +5,26 @@ channels:
 dependencies:
   - python=3.8.10
   - pip=20.3
-  - cudatoolkit=11.3
   - pytorch=1.10.2
   - torchvision=0.11.3
   - numpy=1.22.3
   - pip:
-    - albumentations==1.1.0
-    - opencv-python==4.2.0.34
-    - pudb==2019.2
-    - imageio==2.14.1
-    - imageio-ffmpeg==0.4.7
-    - pytorch-lightning==1.5.9
-    - omegaconf==2.1.1
-    - test-tube>=0.7.5
-    - streamlit>=0.73.1
-    - setuptools==59.5.0
-    - pillow==9.0.1
-    - einops==0.4.1
-    - torch-fidelity==0.3.0
-    - transformers==4.18.0
-    - torchmetrics==0.6.0
-    - kornia==0.6
-    - -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers
-    - -e git+https://github.com/openai/CLIP.git@main#egg=clip
-    - -e .
+    - albumentations==1.1.0
+    - opencv-python==4.2.0.34
+    - pudb==2019.2
+    - imageio==2.14.1
+    - imageio-ffmpeg==0.4.7
+    - pytorch-lightning==1.5.9
+    - omegaconf==2.1.1
+    - test-tube>=0.7.5
+    - streamlit>=0.73.1
+    - setuptools==59.5.0
+    - pillow==9.0.1
+    - einops==0.4.1
+    - torch-fidelity==0.3.0
+    - transformers==4.18.0
+    - torchmetrics==0.6.0
+    - kornia==0.6
+    - -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers
+    - -e git+https://github.com/openai/CLIP.git@main#egg=clip
+    - -e .
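For context on the environment change above: the MPS backend only ships with PyTorch 1.12 and later, so the conda pin of pytorch=1.10.2 kept here cannot provide it by itself; the torch build actually installed on the MacBook has to expose MPS. A minimal pre-flight check along these lines (a hypothetical sketch, not part of this commit) confirms that before launching anything:

# Hypothetical pre-flight check, not part of this commit: verify that the
# installed torch build exposes the MPS backend the changed code relies on.
import torch

assert hasattr(torch.backends, "mps"), "this torch build predates MPS (need >= 1.12)"
print("MPS built:    ", torch.backends.mps.is_built())      # compiled with MPS support?
print("MPS available:", torch.backends.mps.is_available())  # usable on this machine?

# Fall back to CPU so the scripts still run when MPS is missing.
device = torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu")
print("Using device:", device)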
Binary file modified ldm/models/diffusion/__pycache__/ddim.cpython-39.pyc
Binary file modified ldm/models/diffusion/__pycache__/ddpm.cpython-39.pyc
4 changes: 2 additions & 2 deletions ldm/models/diffusion/ddim.py
@@ -18,8 +18,8 @@ def __init__(self, model, schedule="linear", **kwargs):

     def register_buffer(self, name, attr):
         if type(attr) == torch.Tensor:
-            if attr.device != torch.device("cuda"):
-                attr = attr.to(torch.device("cuda"))
+            if attr.device != torch.device("mps"):
+                attr = attr.to(torch.device("mps"))
         setattr(self, name, attr)

     def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
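The register_buffer change above, and the identical one in plms.py below, hardcodes torch.device("mps"). A device-agnostic variant is sketched here under the assumption that one would want the samplers to keep working off Apple hardware; it is not what the commit does:

import torch

def _pick_device() -> torch.device:
    # Hypothetical helper (not in this commit): prefer MPS, then CUDA, then CPU.
    if torch.backends.mps.is_available():
        return torch.device("mps")
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")

def register_buffer(self, name, attr):
    # Same shape as the sampler method in the diff, minus the hardcoded device string.
    if isinstance(attr, torch.Tensor):
        target = _pick_device()
        if attr.device != target:
            attr = attr.to(target)
    setattr(self, name, attr)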
2 changes: 1 addition & 1 deletion ldm/models/diffusion/ddpm.py
@@ -506,7 +506,7 @@ def make_cond_schedule(self, ):

     @rank_zero_only
     @torch.no_grad()
-    def on_train_batch_start(self, batch, batch_idx, dataloader_idx):
+    def on_train_batch_start(self, batch, batch_idx):
         # only for very first batch
         if self.scale_by_std and self.current_epoch == 0 and self.global_step == 0 and batch_idx == 0 and not self.restarted_from_ckpt:
             assert self.scale_factor == 1., 'rather not use custom rescaling and std-rescaling simultaneously'
4 changes: 2 additions & 2 deletions ldm/models/diffusion/plms.py
@@ -17,8 +17,8 @@ def __init__(self, model, schedule="linear", **kwargs):

     def register_buffer(self, name, attr):
         if type(attr) == torch.Tensor:
-            if attr.device != torch.device("cuda"):
-                attr = attr.to(torch.device("cuda"))
+            if attr.device != torch.device("mps"):
+                attr = attr.to(torch.device("mps"))
         setattr(self, name, attr)

     def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
Binary file modified ldm/modules/encoders/__pycache__/modules.cpython-39.pyc
12 changes: 6 additions & 6 deletions ldm/modules/encoders/modules.py
@@ -56,7 +56,7 @@ def forward(self, batch, key=None):

 class TransformerEmbedder(AbstractEncoder):
     """Some transformer encoder layers"""
-    def __init__(self, n_embed, n_layer, vocab_size, max_seq_len=77, device="cuda"):
+    def __init__(self, n_embed, n_layer, vocab_size, max_seq_len=77, device="mps"):
         super().__init__()
         self.device = device
         self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len,
@@ -73,7 +73,7 @@ def encode(self, x):

 class BERTTokenizer(AbstractEncoder):
     """ Uses a pretrained BERT tokenizer by huggingface. Vocab size: 30522 (?)"""
-    def __init__(self, device="cuda", vq_interface=True, max_length=77):
+    def __init__(self, device="mps", vq_interface=True, max_length=77):
         super().__init__()
         from transformers import BertTokenizerFast # TODO: add to reuquirements
         self.tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
@@ -101,7 +101,7 @@ def decode(self, text):
 class BERTEmbedder(AbstractEncoder):
     """Uses the BERT tokenizr model and add some transformer encoder layers"""
     def __init__(self, n_embed, n_layer, vocab_size=30522, max_seq_len=77,
-                 device="cuda",use_tokenizer=True, embedding_dropout=0.0):
+                 device="mps",use_tokenizer=True, embedding_dropout=0.0):
         super().__init__()
         self.use_tknz_fn = use_tokenizer
         if self.use_tknz_fn:
@@ -156,7 +156,7 @@ def encode(self, x):

 class FrozenCLIPEmbedder(AbstractEncoder):
     """Uses the CLIP transformer encoder for text (from Hugging Face)"""
-    def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77):
+    def __init__(self, version="openai/clip-vit-large-patch14", device="mps", max_length=77):
         super().__init__()
         self.tokenizer = CLIPTokenizer.from_pretrained(version)
         self.transformer = CLIPTextModel.from_pretrained(version)
@@ -328,7 +328,7 @@ class FrozenCLIPTextEmbedder(nn.Module):
     """
     Uses the CLIP transformer encoder for text.
     """
-    def __init__(self, version='ViT-L/14', device="cuda", max_length=77, n_repeat=1, normalize=True):
+    def __init__(self, version='ViT-L/14', device="mps", max_length=77, n_repeat=1, normalize=True):
         super().__init__()
         self.model, _ = clip.load(version, jit=False, device="cpu")
         self.device = device
@@ -364,7 +364,7 @@ def __init__(
             self,
             model,
             jit=False,
-            device='cuda' if torch.cuda.is_available() else 'cpu',
+            device='mps',
             antialias=False,
         ):
         super().__init__()
12 changes: 6 additions & 6 deletions ldm/modules/encoders/modules_bak.py
@@ -56,7 +56,7 @@ def forward(self, batch, key=None):

 class TransformerEmbedder(AbstractEncoder):
     """Some transformer encoder layers"""
-    def __init__(self, n_embed, n_layer, vocab_size, max_seq_len=77, device="cuda"):
+    def __init__(self, n_embed, n_layer, vocab_size, max_seq_len=77, device="mps"):
         super().__init__()
         self.device = device
         self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len,
@@ -73,7 +73,7 @@ def encode(self, x):

 class BERTTokenizer(AbstractEncoder):
     """ Uses a pretrained BERT tokenizer by huggingface. Vocab size: 30522 (?)"""
-    def __init__(self, device="cuda", vq_interface=True, max_length=77):
+    def __init__(self, device="mps", vq_interface=True, max_length=77):
         super().__init__()
         from transformers import BertTokenizerFast # TODO: add to reuquirements
         self.tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
@@ -101,7 +101,7 @@ def decode(self, text):
 class BERTEmbedder(AbstractEncoder):
     """Uses the BERT tokenizr model and add some transformer encoder layers"""
     def __init__(self, n_embed, n_layer, vocab_size=30522, max_seq_len=77,
-                 device="cuda",use_tokenizer=True, embedding_dropout=0.0):
+                 device="mps",use_tokenizer=True, embedding_dropout=0.0):
         super().__init__()
         self.use_tknz_fn = use_tokenizer
         if self.use_tknz_fn:
@@ -156,7 +156,7 @@ def encode(self, x):

 class FrozenCLIPEmbedder(AbstractEncoder):
     """Uses the CLIP transformer encoder for text (from Hugging Face)"""
-    def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77):
+    def __init__(self, version="openai/clip-vit-large-patch14", device="mps", max_length=77):
         super().__init__()
         self.tokenizer = CLIPTokenizer.from_pretrained(version)
         self.transformer = CLIPTextModel.from_pretrained(version)
@@ -428,7 +428,7 @@ class FrozenCLIPTextEmbedder(nn.Module):
     """
     Uses the CLIP transformer encoder for text.
     """
-    def __init__(self, version='ViT-L/14', device="cuda", max_length=77, n_repeat=1, normalize=True):
+    def __init__(self, version='ViT-L/14', device="mps", max_length=77, n_repeat=1, normalize=True):
         super().__init__()
         self.model, _ = clip.load(version, jit=False, device="cpu")
         self.device = device
@@ -464,7 +464,7 @@ def __init__(
             self,
             model,
             jit=False,
-            device='cuda' if torch.cuda.is_available() else 'cpu',
+            device='mps',
             antialias=False,
         ):
         super().__init__()
15 changes: 5 additions & 10 deletions main.py
@@ -229,7 +229,7 @@ def __init__(self, batch_size, train=None, reg = None, validation=None, test=Non
         super().__init__()
         self.batch_size = batch_size
         self.dataset_configs = dict()
-        self.num_workers = num_workers if num_workers is not None else batch_size * 2
+        self.num_workers = 10
         self.use_worker_init_fn = use_worker_init_fn
         if train is not None:
             self.dataset_configs["train"] = train
@@ -588,12 +588,7 @@ def on_train_epoch_start(self, trainer, pl_module):
         trainer_config["strategy"] = "ddp"
         for k in nondefault_trainer_args(opt):
             trainer_config[k] = getattr(opt, k)
-        if torch.cuda.is_available():
-            device = torch.device("cuda")
-            gpuinfo = trainer_config["gpus"]
-            print(f"Running on GPUs {gpuinfo}")
-            cpu = False
-        elif torch.backends.mps.is_built():
+        if torch.backends.mps.is_built():
             device = torch.device("mps")
             trainer_config["accelerator"] = "mps"
             print(f"Running on MPS")
@@ -749,8 +744,6 @@ def on_train_epoch_start(self, trainer, pl_module):
         config.data.params.train.params.data_root = opt.data_root
         config.data.params.reg.params.data_root = opt.reg_data_root
         config.data.params.validation.params.data_root = opt.data_root
-        data = instantiate_from_config(config.data)
-
         data = instantiate_from_config(config.data)
         # NOTE according to https://pytorch-lightning.readthedocs.io/en/latest/datamodules.html
         # calling these ourselves should not be necessary but it is.
@@ -809,7 +802,9 @@ def divein(*args, **kwargs):
         # run
         if opt.train:
             try:
-                trainer.fit(model, data)
+                print("type of data: ", type(data))
+                trainer.fit(model, data)
+                # trainer.fit(model, datamodule=DataModuleFromConfig(10))
             except Exception:
                 melk()
                 raise
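The main.py hunks above drop the CUDA branch and key everything off torch.backends.mps.is_built(); note that is_built() only reports whether the binary was compiled with MPS support, while is_available() also checks the running machine. A sketch of the same selection with a CPU fallback follows; it is an assumption, not the commit's code, and the "mps" accelerator string presumes a Lightning release that actually ships an MPS accelerator:

import torch

def pick_device_and_accelerator(trainer_config: dict):
    # Hypothetical variant of the branch in main.py, extended with a CPU fallback.
    if torch.backends.mps.is_available():
        device = torch.device("mps")
        trainer_config["accelerator"] = "mps"  # assumes a Lightning version with an MPS accelerator
        print("Running on MPS")
        cpu = False
    else:
        device = torch.device("cpu")
        trainer_config.pop("accelerator", None)
        print("Running on CPU")
        cpu = True
    return device, cpu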
4 changes: 2 additions & 2 deletions merge_embeddings.py
@@ -67,9 +67,9 @@ def get_bert_token_for_string(tokenizer, string):
 args = parser.parse_args()

 if args.stable_diffusion:
-    embedder = FrozenCLIPEmbedder().cuda()
+    embedder = FrozenCLIPEmbedder().to(torch.device("mps"))
 else:
-    embedder = BERTEmbedder(n_embed=1280, n_layer=32).cuda()
+    embedder = BERTEmbedder(n_embed=1280, n_layer=32).to(torch.device("mps"))

 EmbeddingManager = partial(EmbeddingManager, embedder, ["*"])

3 changes: 1 addition & 2 deletions scripts/evaluate_model.py
@@ -29,7 +29,6 @@ def load_model_from_config(config, ckpt, verbose=False):
         print("unexpected keys:")
         print(u)

-    # model.cuda()
     mps_device = torch.device("mps")
     model.to(mps_device)
     model.eval()
@@ -71,7 +70,7 @@ def load_model_from_config(config, ckpt, verbose=False):
 model = load_model_from_config(config, opt.ckpt_path) # TODO: check path
 model.embedding_manager.load(opt.embedding_path)

-device = torch.device("cuda") if torch.cuda.is_available() else torch.device("mps")
+device = torch.device("mps")
 model = model.to(device)

 evaluator = LDMCLIPEvaluator(device)
2 changes: 1 addition & 1 deletion scripts/inpaint.py
@@ -61,7 +61,7 @@ def make_batch(image, mask, device):
     model.load_state_dict(torch.load("models/ldm/inpainting_big/last.ckpt")["state_dict"],
                           strict=False)

-    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("mps")
+    device = torch.device("mps")
     model = model.to(device)
     sampler = DDIMSampler(model)
