From 48361f6671a4950545f5582e5bc9c317af78a27b Mon Sep 17 00:00:00 2001
From: Phil Wang
Date: Thu, 25 Aug 2022 10:31:54 -0700
Subject: [PATCH] update with the best type of downsample

---
 README.md                        | 10 ++++++++++
 imagen_pytorch/imagen_pytorch.py |  7 ++++++-
 imagen_pytorch/version.py        |  2 +-
 3 files changed, 17 insertions(+), 2 deletions(-)

diff --git a/README.md b/README.md
index 0bd4642..ffd3dd0 100644
--- a/README.md
+++ b/README.md
@@ -768,3 +768,13 @@ Anything! It is MIT licensed. In other words, you can freely copy / paste for yo
     primaryClass = {cs.CV}
 }
 ```
+
+```bibtex
+@article{Sunkara2022NoMS,
+    title = {No More Strided Convolutions or Pooling: A New CNN Building Block for Low-Resolution Images and Small Objects},
+    author = {Raja Sunkara and Tie Luo},
+    journal = {ArXiv},
+    year = {2022},
+    volume = {abs/2208.03641}
+}
+```
diff --git a/imagen_pytorch/imagen_pytorch.py b/imagen_pytorch/imagen_pytorch.py
index d5477fb..5460abe 100644
--- a/imagen_pytorch/imagen_pytorch.py
+++ b/imagen_pytorch/imagen_pytorch.py
@@ -583,8 +583,13 @@ def forward(self, x):
         return self.net(x)
 
 def Downsample(dim, dim_out = None):
+    # https://arxiv.org/abs/2208.03641 shows this is the most optimal way to downsample
+    # named SP-conv in the paper, but basically a pixel unshuffle
     dim_out = default(dim_out, dim)
-    return nn.Conv2d(dim, dim_out, 4, 2, 1)
+    return nn.Sequential(
+        Rearrange('b c (h s1) (w s2) -> b (c s1 s2) h w', s1 = 2, s2 = 2),
+        nn.Conv2d(dim * 4, dim_out, 1)
+    )
 
 class SinusoidalPosEmb(nn.Module):
     def __init__(self, dim):
diff --git a/imagen_pytorch/version.py b/imagen_pytorch/version.py
index ff987d2..da77e85 100644
--- a/imagen_pytorch/version.py
+++ b/imagen_pytorch/version.py
@@ -1 +1 @@
-__version__ = '1.10.1'
+__version__ = '1.11.0'