I noticed this in the original repo as well. I can fix it by sending the embedding tensors to the CPU and then back to MPS at the end, but I'm not sure this is the best way to handle it, and it may create issues with the output quality.
My hack:
def __call__(self, ids: torch.Tensor):
    # gather/repeat on complex tensors fail on MPS, so do the lookup on the CPU
    self.freqs_cis = [freqs_cis.to("cpu") for freqs_cis in self.freqs_cis]
    result = []
    for i in range(len(self.axes_dims)):
        # Build the gather index on the CPU as well
        index = ids[:, :, i:i+1].repeat(1, 1, self.freqs_cis[i].shape[-1]).to(torch.int64).to("cpu")
        # Gather on the CPU, then move the result back to the original (MPS) device
        result.append(torch.gather(self.freqs_cis[i].unsqueeze(0).repeat(index.shape[0], 1, 1), dim=1, index=index).to(torch.complex32).to(ids.device))
    return torch.cat(result, dim=-1)
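An alternative that avoids the CPU round-trip is to gather the real and imaginary parts separately and recombine them with torch.complex afterwards, so torch.gather never has to operate on a complex tensor: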
def __call__(self, ids: torch.Tensor):
    # Move freqs_cis to the same device as ids
    self.freqs_cis = [freqs_cis.to(ids.device) for freqs_cis in self.freqs_cis]
    result = []
    for i in range(len(self.axes_dims)):
        # Extract the real and imaginary parts of the complex tensor
        freqs_cis_real = self.freqs_cis[i].real
        freqs_cis_imag = self.freqs_cis[i].imag
        # Repeat the indices to match the dimensions of freqs_cis
        index = ids[:, :, i:i+1].repeat(1, 1, freqs_cis_real.shape[-1]).to(torch.int64)
        # Gather the real and imaginary parts separately
        gathered_real = torch.gather(freqs_cis_real.unsqueeze(0).repeat(index.shape[0], 1, 1), dim=1, index=index)
        gathered_imag = torch.gather(freqs_cis_imag.unsqueeze(0).repeat(index.shape[0], 1, 1), dim=1, index=index)
        # Combine the real and imaginary parts back into a complex tensor
        result.append(torch.complex(gathered_real, gathered_imag))
    # Concatenate the results along the last dimension
    return torch.cat(result, dim=-1)
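As a quick sanity check (a standalone sketch, not part of the model code; the tensor shapes below are made up), the split gather should produce exactly the same result as a direct complex gather on the CPU:

import torch

# Illustrative shapes only: 8 positions, 4 frequency dims, batch of 2, sequence of 3
freqs_cis = torch.polar(torch.ones(8, 4), torch.rand(8, 4))   # complex64 on CPU
index = torch.randint(0, 8, (2, 3, 1)).repeat(1, 1, 4)        # int64 gather index

# Direct gather on the complex tensor (supported on CPU)
direct = torch.gather(freqs_cis.unsqueeze(0).repeat(2, 1, 1), dim=1, index=index)

# Split gather: real and imaginary parts gathered separately, then recombined
real = torch.gather(freqs_cis.real.unsqueeze(0).repeat(2, 1, 1), dim=1, index=index)
imag = torch.gather(freqs_cis.imag.unsqueeze(0).repeat(2, 1, 1), dim=1, index=index)
split = torch.complex(real, imag)

assert torch.equal(direct, split)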
Expected Behavior
The Lumina model begins rendering.
Actual Behavior
The rendering fails as the repeat/gather operators are not supported for complex numbers on the MPS backend.
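For reference, the limitation can presumably be reproduced outside the model with a few lines of PyTorch on an Apple-silicon machine (a sketch; the exact error and where it surfaces depend on the PyTorch version):

import torch

# Requires an Apple-silicon Mac with the MPS backend available.
# Depending on the PyTorch build, either moving the complex tensor to MPS
# or the repeat/gather call below raises the unsupported-operator error.
if torch.backends.mps.is_available():
    try:
        freqs_cis = torch.polar(torch.ones(8, 4), torch.rand(8, 4)).to("mps")  # complex64
        index = torch.zeros(1, 3, 4, dtype=torch.int64, device="mps")
        out = torch.gather(freqs_cis.unsqueeze(0).repeat(1, 1, 1), dim=1, index=index)
        print("gather succeeded:", out.shape)
    except (RuntimeError, NotImplementedError) as err:
        print("complex op failed on MPS:", err)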
Steps to Reproduce
Download the Lumina model and run it on any Mac with an M1–M3 chip.
Debug Logs
Other