Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
liutaocode committed Aug 10, 2024
1 parent 67efd2c commit 7256000
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion code/diffusion/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -961,7 +961,7 @@ def _extract_into_tensor(arr, timesteps, broadcast_shape):
:return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.
"""

if th.backends.mps.is_available():
if hasattr(th.backends, 'mps') and th.backends.mps.is_available():
arr = arr.astype(np.float32)
# Convert the numpy array to a tensor and then move to the device
res = th.from_numpy(arr).to(device=timesteps.device)[timesteps]
Expand Down

0 comments on commit 7256000

Please sign in to comment.