Skip to content

Commit 8be6a65

Browse files
committed
quick diffusion ddim sampler fix
1 parent c5f6f8d commit 8be6a65

File tree

1 file changed

+20
-13
lines changed

1 file changed

+20
-13
lines changed

examples/ddpm.py

Lines changed: 20 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -127,15 +127,15 @@ def ddpm_denoise_sample(
127127
else:
128128
x_t = np.random.normal(size=orig_x.shape)
129129

130-
x_t = nnet.tensor(x_t, requires_grad=False, device=device)
130+
x_t = nnet.tensor(x_t, device=device)
131131
x_ts = []
132132
for t in tqdm(
133133
reversed(range(0, self.timesteps)),
134134
desc="ddpm denoisinig samples",
135135
total=self.timesteps,
136136
):
137137
noise = (
138-
nnet.tensor(np.random.normal(size=x_t.shape), requires_grad=False, device=device)
138+
nnet.tensor(np.random.normal(size=x_t.shape), device=device)
139139
if t > 1
140140
else 0
141141
)
@@ -152,7 +152,6 @@ def ddpm_denoise_sample(
152152
if mask is not None:
153153
orig_x_noise = nnet.tensor(
154154
np.random.normal(size=orig_x.shape),
155-
requires_grad=False,
156155
device=device,
157156
)
158157

@@ -196,32 +195,40 @@ def ddim_denoise_sample(
196195
else:
197196
x_t = np.random.normal(size=orig_x.shape)
198197

198+
x_t = nnet.tensor(x_t, device=device)
199199
x_ts = []
200200
for t in tqdm(
201201
reversed(range(1, self.timesteps)[:perform_steps]),
202202
desc="ddim denoisinig samples",
203203
total=perform_steps,
204204
):
205-
noise = np.random.normal(size=x_t.shape) if t > 1 else 0
206-
eps = self.model.forward(x_t, np.array([t]) / self.timesteps, training=False).reshape(
207-
x_t.shape
205+
noise = (
206+
nnet.tensor(np.random.normal(size=x_t.shape), device=device)
207+
if t > 1
208+
else 0
208209
)
210+
eps = self.model.forward(x_t, np.array([t]) / self.timesteps).reshape(
211+
x_t.shape
212+
).detach()
209213

210-
x0_t = (x_t - eps * np.sqrt(1 - self.alphas_cumprod[t])) / np.sqrt(
214+
x0_t = (x_t - eps * nnet.sqrt(1 - self.alphas_cumprod[t])) / nnet.sqrt(
211215
self.alphas_cumprod[t]
212216
)
213217

214-
sigma = eta * np.sqrt(
218+
sigma = eta * nnet.sqrt(
215219
(1 - self.alphas_cumprod[t - 1])
216220
/ (1 - self.alphas_cumprod[t])
217221
* (1 - self.alphas_cumprod[t] / self.alphas_cumprod[t - 1])
218222
)
219-
c = np.sqrt((1 - self.alphas_cumprod[t - 1]) - sigma**2)
223+
c = nnet.sqrt((1 - self.alphas_cumprod[t - 1]) - sigma**2)
220224

221-
x_t = np.sqrt(self.alphas_cumprod[t - 1]) * x0_t - c * eps + sigma * noise
225+
x_t = nnet.sqrt(self.alphas_cumprod[t - 1]) * x0_t - c * eps + sigma * noise
222226

223227
if mask is not None:
224-
orig_x_noise = np.random.normal(size=orig_x.shape)
228+
orig_x_noise = nnet.tensor(
229+
np.random.normal(size=orig_x.shape),
230+
device=device,
231+
)
225232

226233
orig_x_t = (
227234
self.sqrt_alphas_cumprod[t] * orig_x
@@ -230,9 +237,9 @@ def ddim_denoise_sample(
230237
x_t = orig_x_t * mask + x_t * (1 - mask)
231238

232239
if t % states_step_size == 0:
233-
x_ts.append(x_t)
240+
x_ts.append(x_t.cpu().detach().numpy())
234241

235-
return x_t, x_ts
242+
return x_t.to("cpu").detach().numpy(), x_ts
236243

237244
def get_images_set(
238245
self,

0 commit comments

Comments
 (0)