@@ -150,6 +150,8 @@ def ddim_sampling(self, cond, shape,
150
150
assert x0 is not None
151
151
img_orig = self .model .q_sample (x0 , ts ) # TODO: deterministic forward pass?
152
152
img = img_orig * mask + (1. - mask ) * img
153
+ #tmp_mask = (mask > (1 - (step / 1000))) * 1
154
+ #img = img_orig_with_noise * tmp_mask + (1. - tmp_mask) * img
153
155
154
156
outs = self .p_sample_ddim (img , cond , ts , index = index , use_original_steps = ddim_use_original_steps ,
155
157
quantize_denoised = quantize_denoised , temperature = temperature ,
@@ -159,6 +161,7 @@ def ddim_sampling(self, cond, shape,
159
161
unconditional_conditioning = unconditional_conditioning ,
160
162
dynamic_threshold = dynamic_threshold )
161
163
img , pred_x0 = outs
164
+ #if callback: callback(i)
162
165
if callback :
163
166
img = callback (i , img , pred_x0 )
164
167
if img_callback : img_callback (pred_x0 , i )
@@ -292,7 +295,7 @@ def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
292
295
293
296
@torch .no_grad ()
294
297
def decode (self , x_latent , cond , t_start , unconditional_guidance_scale = 1.0 , unconditional_conditioning = None ,
295
- use_original_steps = False ):
298
+ use_original_steps = False , img_callback = None ):
296
299
297
300
timesteps = np .arange (self .ddpm_num_timesteps ) if use_original_steps else self .ddim_timesteps
298
301
timesteps = timesteps [:t_start ]
@@ -309,4 +312,7 @@ def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unco
309
312
x_dec , _ = self .p_sample_ddim (x_dec , cond , ts , index = index , use_original_steps = use_original_steps ,
310
313
unconditional_guidance_scale = unconditional_guidance_scale ,
311
314
unconditional_conditioning = unconditional_conditioning )
315
+
316
+ if img_callback : img_callback (x_dec , i )
317
+
312
318
return x_dec
0 commit comments