Skip to content

Commit efa83c6

Browse files
update
1 parent 820ed65 commit efa83c6

File tree

3 files changed

+4
-9
lines changed

3 files changed

+4
-9
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ https://user-images.githubusercontent.com/27779063/206553305-e01009f7-3131-4a6b-
1515
# Installation
1616
We recommend using [`conda`](https://www.anaconda.com/products/distribution) to install the required python packages. You might need to change the `cudatoolkit` version to match with your GPU driver.
1717
```
18-
conda create -n sdf python=3.8 -y && conda activate sdf
18+
conda create -n sdfusion python=3.8 -y && conda activate sdfusion
1919
conda install pytorch==1.9.0 torchvision==0.10.0 torchaudio==0.9.0 cudatoolkit=11.3 -c pytorch -c conda-forge -y
2020
conda install -c fvcore -c iopath -c conda-forge fvcore iopath -y
2121
conda install pytorch3d -c pytorch3d

datasets/snet_mm2shape_dataset.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,6 @@ def initialize(self, opt, phase='train', cat='all', res=64):
6262
render_img_list = [os.path.join(render_img_dir, f) for f in os.listdir(render_img_dir) if '.png' in f]
6363

6464
if not os.path.exists(sdf_path):
65-
# import pdb; pdb.set_trace()
6665
continue
6766

6867
self.model_list.append(sdf_path)

models/sdfusion_mm_model.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -411,7 +411,7 @@ def forward(self):
411411
# check: ddpm.py, log_images(). line 1317~1327
412412
@torch.no_grad()
413413
# def inference(self, data, sample=True, ddim_steps=None, ddim_eta=0., quantize_denoised=True, infer_all=False):
414-
def naive_inference(self, data, ddim_steps=None, ddim_eta=0., uc_scale=None,
414+
def inference(self, data, ddim_steps=None, ddim_eta=0., uc_scale=None,
415415
infer_all=False, max_sample=16):
416416

417417
self.switch_eval()
@@ -444,10 +444,6 @@ def naive_inference(self, data, ddim_steps=None, ddim_eta=0., uc_scale=None,
444444
shape = self.z_shape
445445

446446
# get noise, denoise, and decode with vqvae
447-
uc = self.cond_model(self.uc_img).float() # img shape
448-
c_img = self.cond_model(self.img).float()
449-
B = c_img.shape[0]
450-
shape = self.z_shape
451447
samples, intermediates = self.ddim_sampler.sample(S=ddim_steps,
452448
batch_size=B,
453449
shape=shape,
@@ -465,7 +461,7 @@ def naive_inference(self, data, ddim_steps=None, ddim_eta=0., uc_scale=None,
465461
# txt_scale=1.0, img_scale=1.0, mask_mode='1', mask_x=False,
466462
# mm_cls_free=False,
467463
# ):
468-
def inference(self, data, mask_mode=None, ddim_steps=None, ddim_eta=0., uc_scale=None,
464+
def mm_inference(self, data, mask_mode=None, ddim_steps=None, ddim_eta=0., uc_scale=None,
469465
txt_scale=1.0, img_scale=1.0, mm_cls_free=False, infer_all=False, max_sample=16):
470466

471467
self.switch_eval()
@@ -639,7 +635,7 @@ def get_current_visuals(self):
639635
b, c, h, w = self.img_gt.shape
640636
img_shape = (3, h, w)
641637
# write text as img
642-
self.img_text = self.get_img_text(self.txt, bs=b, img_shape=img_shape)
638+
self.img_text = self.write_text_on_img(self.txt, bs=b, img_shape=img_shape)
643639
self.img_text = rearrange(torch.from_numpy(self.img_text), 'b h w c -> b c h w')
644640

645641
vis_tensor_names = [

0 commit comments

Comments (0)