Skip to content

Commit 9c8dfc0

Browse files
committed
Also force tiled VAE encode where possible #1658
1 parent 6f1ab28 commit 9c8dfc0

File tree

2 files changed

+15
-6
lines changed

2 files changed

+15
-6
lines changed

ai_diffusion/comfy_workflow.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -751,6 +751,9 @@ def vae_encode(self, vae: Output, image: Output):
751751
def vae_encode_inpaint(self, vae: Output, image: Output, mask: Output):
752752
return self.add("VAEEncodeForInpaint", 1, vae=vae, pixels=image, mask=mask, grow_mask_by=0)
753753

754+
def vae_encode_tiled(self, vae: Output, image: Output):
755+
return self.add("VAEEncodeTiled", 1, vae=vae, pixels=image, tile_size=512, overlap=64)
756+
754757
def vae_decode(self, vae: Output, latent_image: Output):
755758
return self.add("VAEDecode", 1, vae=vae, samples=latent_image)
756759

ai_diffusion/workflow.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,12 @@ def load_checkpoint_with_lora(w: ComfyWorkflow, checkpoint: CheckpointInput, mod
156156
return model, Clip(clip, arch), vae
157157

158158

159+
def vae_encode(w: ComfyWorkflow, vae: Output, image: Output, tiled: bool):
160+
if tiled:
161+
return w.vae_encode_tiled(vae, image)
162+
return w.vae_encode(vae, image)
163+
164+
159165
def vae_decode(w: ComfyWorkflow, vae: Output, latent: Output, tiled: bool):
160166
if tiled:
161167
return w.vae_decode_tiled(vae, latent)
@@ -678,7 +684,7 @@ def scale_refine_and_decode(
678684
decoded = vae_decode(w, vae, latent, tiled_vae)
679685
upscale = w.upscale_image(upscale_model, decoded)
680686
upscale = w.scale_image(upscale, extent.desired)
681-
latent = w.vae_encode(vae, upscale)
687+
latent = vae_encode(w, vae, upscale, tiled_vae)
682688
params = _sampler_params(sampling, strength=0.4)
683689

684690
positive, negative = encode_text_prompt(w, cond, clip, regions)
@@ -874,7 +880,7 @@ def inpaint(
874880
)
875881
inpaint_model = model
876882
else:
877-
latent = w.vae_encode(vae, in_image)
883+
latent = vae_encode(w, vae, in_image, checkpoint.tiled_vae)
878884
latent = w.set_latent_noise_mask(latent, initial_mask)
879885
inpaint_model = model
880886

@@ -899,7 +905,7 @@ def inpaint(
899905
upscale = ensure_minimum_extent(w, upscale, initial_bounds.extent, 32)
900906
upscale = w.upscale_image(upscale_model, upscale)
901907
upscale = w.scale_image(upscale, upscale_extent.desired)
902-
latent = w.vae_encode(vae, upscale)
908+
latent = vae_encode(w, vae, upscale, checkpoint.tiled_vae)
903909
latent = w.set_latent_noise_mask(latent, upscale_mask)
904910

905911
cond_upscale = cond.copy()
@@ -956,7 +962,7 @@ def refine(
956962
model = apply_regional_ip_adapter(w, model, cond.regions, extent.initial, models)
957963
in_image = w.load_image(image)
958964
in_image = scale_to_initial(extent, w, in_image, models)
959-
latent = w.vae_encode(vae, in_image)
965+
latent = vae_encode(w, vae, in_image, checkpoint.tiled_vae)
960966
latent = w.batch_latent(latent, misc.batch_count)
961967
positive, negative = encode_text_prompt(w, cond, clip, regions)
962968
model, positive, negative = apply_control(
@@ -1010,7 +1016,7 @@ def refine_region(
10101016
inpaint_patch = w.load_fooocus_inpaint(**models.fooocus_inpaint)
10111017
inpaint_model = w.apply_fooocus_inpaint(model, inpaint_patch, latent_inpaint)
10121018
else:
1013-
latent = w.vae_encode(vae, in_image)
1019+
latent = vae_encode(w, vae, in_image, checkpoint.tiled_vae)
10141020
latent = w.set_latent_noise_mask(latent, initial_mask)
10151021
inpaint_model = model
10161022

@@ -1179,7 +1185,7 @@ def tiled_region(region: Region, index: int, tile_bounds: Bounds):
11791185
w, tile_model, positive, negative, control, no_reshape, vae, models
11801186
)
11811187

1182-
latent = w.vae_encode(vae, tile_image)
1188+
latent = vae_encode(w, vae, tile_image, checkpoint.tiled_vae)
11831189
latent = w.set_latent_noise_mask(latent, tile_mask)
11841190
sampler = w.sampler_custom_advanced(
11851191
tile_model, positive, negative, latent, models.arch, **_sampler_params(sampling)

0 commit comments

Comments
 (0)