diff --git a/README.md b/README.md
index e3261799..8c0e8278 100644
--- a/README.md
+++ b/README.md
@@ -751,4 +751,14 @@ $ python generate.py --chinese --text '追老鼠的猫'
 }
 ```
 
+```bibtex
+@article{Liu2023BridgingDA,
+    title = {Bridging Discrete and Backpropagation: Straight-Through and Beyond},
+    author = {Liyuan Liu and Chengyu Dong and Xiaodong Liu and Bin Yu and Jianfeng Gao},
+    journal = {ArXiv},
+    year = {2023},
+    volume = {abs/2304.08612}
+}
+```
+
 *Those who do not want to imitate anything, produce nothing.* - Dali
diff --git a/dalle_pytorch/dalle_pytorch.py b/dalle_pytorch/dalle_pytorch.py
index 1998a9fe..45becc33 100644
--- a/dalle_pytorch/dalle_pytorch.py
+++ b/dalle_pytorch/dalle_pytorch.py
@@ -111,6 +111,7 @@ def __init__(
         smooth_l1_loss = False,
         temperature = 0.9,
         straight_through = False,
+        reinmax = False,
         kl_div_loss_weight = 0.,
         normalization = ((*((0.5,) * 3), 0), (*((0.5,) * 3), 1))
     ):
@@ -125,6 +126,8 @@ def __init__(
         self.num_layers = num_layers
         self.temperature = temperature
         self.straight_through = straight_through
+        self.reinmax = reinmax
+
         self.codebook = nn.Embedding(num_tokens, codebook_dim)
 
         hdim = hidden_dim
@@ -227,8 +230,20 @@ def forward(
             return logits # return logits for getting hard image indices for DALL-E training
 
         temp = default(temp, self.temperature)
-        soft_one_hot = F.gumbel_softmax(logits, tau = temp, dim = 1, hard = self.straight_through)
-        sampled = einsum('b n h w, n d -> b d h w', soft_one_hot, self.codebook.weight)
+
+        one_hot = F.gumbel_softmax(logits, tau = temp, dim = 1, hard = self.straight_through)
+
+        if self.straight_through and self.reinmax:
+            # use reinmax for better second-order accuracy - https://arxiv.org/abs/2304.08612
+            # algorithm 2
+            one_hot = one_hot.detach()
+            π0 = logits.softmax(dim = 1)
+            π1 = (one_hot + (logits / temp).softmax(dim = 1)) / 2
+            π1 = ((π1.log() - logits).detach() + logits).softmax(dim = 1)
+            π2 = 2 * π1 - 0.5 * π0
+            one_hot = π2 - π2.detach() + one_hot
+
+        sampled = einsum('b n h w, n d -> b d h w', one_hot, self.codebook.weight)
         out = self.decoder(sampled)
 
         if not return_loss:
diff --git a/dalle_pytorch/version.py b/dalle_pytorch/version.py
index d07785c5..f3df7f04 100644
--- a/dalle_pytorch/version.py
+++ b/dalle_pytorch/version.py
@@ -1 +1 @@
-__version__ = '1.6.4'
+__version__ = '1.6.5'