While working through the examples, I ran into a problem.
```python
import torch
from dalle_pytorch import DiscreteVAE, DALLE

vae = DiscreteVAE(
    image_size = 256,
    num_layers = 3,
    num_tokens = 8192,
    codebook_dim = 1024,
    hidden_dim = 64,
    num_resnet_blocks = 1,
    temperature = 0.9
)

dalle = DALLE(
    dim = 1024,
    vae = vae,
    num_text_tokens = 10000,
    text_seq_len = 256,
    depth = 12,
    heads = 16,
    dim_head = 64,
    attn_dropout = 0.1,
    ff_dropout = 0.1
)

text = torch.randint(0, 10000, (4, 256))
images = torch.randn(4, 3, 256, 256)

loss = dalle(
    text,
    images,
    return_loss = True,
    null_cond_prob = 0.2  # firstly, set this to the probability of dropping out the condition, 20% is recommended as a default
)

loss.backward()

# do the above for a long time with a lot of data ... then

images = dalle.generate_images(
    text,
    cond_scale = 3.  # secondly, set this to a value greater than 1 to increase the conditioning beyond average
)

images.shape  # (4, 3, 256, 256)
```
While going through this example, I noticed that `images` is a tensor, so I tried to convert it to an image, but I can't seem to get it to work. Any ideas?
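One possible approach, sketched here under some assumptions and not tested against dalle-pytorch itself: treat `images` as a batch of `(3, H, W)` float tensors, rescale each one to `[0, 1]`, reorder it to `(H, W, 3)`, and hand it to PIL. The min-max rescale assumes the decoder's output is not guaranteed to lie in `[0, 1]`. `tensor_to_pil` is a hypothetical helper name, not part of the library.

```python
import torch
from PIL import Image


def tensor_to_pil(image: torch.Tensor) -> Image.Image:
    """Convert a single (3, H, W) float tensor into a PIL RGB image (hypothetical helper)."""
    # Detach from the autograd graph and move to CPU before converting.
    image = image.detach().float().cpu()
    # Assumption: the output range is unbounded, so min-max rescale to [0, 1].
    lo, hi = image.min(), image.max()
    image = (image - lo) / (hi - lo).clamp(min=1e-8)
    # (C, H, W) -> (H, W, C), scaled to 0..255 uint8, which PIL reads as RGB.
    array = (image * 255).round().to(torch.uint8).permute(1, 2, 0).contiguous().numpy()
    return Image.fromarray(array)


if __name__ == "__main__":
    # Stand-in for the output of dalle.generate_images(...) with shape (4, 3, 256, 256).
    images = torch.rand(4, 3, 256, 256)
    for i, img in enumerate(images):
        tensor_to_pil(img).save(f"{i}.png")
```

If torchvision is installed, `torchvision.utils.save_image(images, "grid.png", normalize=True)` is a shorter alternative. It writes the whole batch as a single grid image, and `normalize=True` performs a similar min-max rescale.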