@@ -127,15 +127,15 @@ def ddpm_denoise_sample(
127
127
else :
128
128
x_t = np .random .normal (size = orig_x .shape )
129
129
130
- x_t = nnet .tensor (x_t , requires_grad = False , device = device )
130
+ x_t = nnet .tensor (x_t , device = device )
131
131
x_ts = []
132
132
for t in tqdm (
133
133
reversed (range (0 , self .timesteps )),
134
134
desc = "ddpm denoisinig samples" ,
135
135
total = self .timesteps ,
136
136
):
137
137
noise = (
138
- nnet .tensor (np .random .normal (size = x_t .shape ), requires_grad = False , device = device )
138
+ nnet .tensor (np .random .normal (size = x_t .shape ), device = device )
139
139
if t > 1
140
140
else 0
141
141
)
@@ -152,7 +152,6 @@ def ddpm_denoise_sample(
152
152
if mask is not None :
153
153
orig_x_noise = nnet .tensor (
154
154
np .random .normal (size = orig_x .shape ),
155
- requires_grad = False ,
156
155
device = device ,
157
156
)
158
157
@@ -196,32 +195,40 @@ def ddim_denoise_sample(
196
195
else :
197
196
x_t = np .random .normal (size = orig_x .shape )
198
197
198
+ x_t = nnet .tensor (x_t , device = device )
199
199
x_ts = []
200
200
for t in tqdm (
201
201
reversed (range (1 , self .timesteps )[:perform_steps ]),
202
202
desc = "ddim denoisinig samples" ,
203
203
total = perform_steps ,
204
204
):
205
- noise = np .random .normal (size = x_t .shape ) if t > 1 else 0
206
- eps = self .model .forward (x_t , np .array ([t ]) / self .timesteps , training = False ).reshape (
207
- x_t .shape
205
+ noise = (
206
+ nnet .tensor (np .random .normal (size = x_t .shape ), device = device )
207
+ if t > 1
208
+ else 0
208
209
)
210
+ eps = self .model .forward (x_t , np .array ([t ]) / self .timesteps ).reshape (
211
+ x_t .shape
212
+ ).detach ()
209
213
210
- x0_t = (x_t - eps * np .sqrt (1 - self .alphas_cumprod [t ])) / np .sqrt (
214
+ x0_t = (x_t - eps * nnet .sqrt (1 - self .alphas_cumprod [t ])) / nnet .sqrt (
211
215
self .alphas_cumprod [t ]
212
216
)
213
217
214
- sigma = eta * np .sqrt (
218
+ sigma = eta * nnet .sqrt (
215
219
(1 - self .alphas_cumprod [t - 1 ])
216
220
/ (1 - self .alphas_cumprod [t ])
217
221
* (1 - self .alphas_cumprod [t ] / self .alphas_cumprod [t - 1 ])
218
222
)
219
- c = np .sqrt ((1 - self .alphas_cumprod [t - 1 ]) - sigma ** 2 )
223
+ c = nnet .sqrt ((1 - self .alphas_cumprod [t - 1 ]) - sigma ** 2 )
220
224
221
- x_t = np .sqrt (self .alphas_cumprod [t - 1 ]) * x0_t - c * eps + sigma * noise
225
+ x_t = nnet .sqrt (self .alphas_cumprod [t - 1 ]) * x0_t - c * eps + sigma * noise
222
226
223
227
if mask is not None :
224
- orig_x_noise = np .random .normal (size = orig_x .shape )
228
+ orig_x_noise = nnet .tensor (
229
+ np .random .normal (size = orig_x .shape ),
230
+ device = device ,
231
+ )
225
232
226
233
orig_x_t = (
227
234
self .sqrt_alphas_cumprod [t ] * orig_x
@@ -230,9 +237,9 @@ def ddim_denoise_sample(
230
237
x_t = orig_x_t * mask + x_t * (1 - mask )
231
238
232
239
if t % states_step_size == 0 :
233
- x_ts .append (x_t )
240
+ x_ts .append (x_t . cpu (). detach (). numpy () )
234
241
235
- return x_t , x_ts
242
+ return x_t . to ( "cpu" ). detach (). numpy () , x_ts
236
243
237
244
def get_images_set (
238
245
self ,
0 commit comments