Skip to content

Commit

Permalink
ready for config driven training
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed May 27, 2022
1 parent 939d10c commit 5bf7f1a
Showing 1 changed file with 13 additions and 7 deletions.
20 changes: 13 additions & 7 deletions imagen_pytorch/configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,22 @@ def default(val, d):
def ListOrTuple(inner_type):
return Union[List[inner_type], Tuple[inner_type]]

def SingleOrList(inner_type):
return Union[inner_type, ListOrTuple(inner_type)]

# noise schedule

class BetaSchedule(Enum):
cosine = 'cosine'
linear = 'linear'

class AllowExtraBaseModel(BaseModel):
class Config:
extra = "allow"

# imagen pydantic classes

class UnetConfig(BaseModel):
class UnetConfig(AllowExtraBaseModel):
dim: int
dim_mults: ListOrTuple(int)
text_embed_dim: int = get_encoded_dim(DEFAULT_T5_NAME)
Expand All @@ -34,14 +41,13 @@ class UnetConfig(BaseModel):
attn_dim_head: int = 32
attn_heads: int = 16

class Config:
extra = "allow"

class ImagenConfig(BaseModel):
class ImagenConfig(AllowExtraBaseModel):
unets: ListOrTuple(UnetConfig)
image_sizes: ListOrTuple(int)
timesteps: Union[int, ListOrTuple(int)] = 1000
beta_schedules: Union[BetaSchedule, ListOrTuple(BetaSchedule)] = 'cosine'
timesteps: SingleOrList(int) = 1000
beta_schedules: SingleOrList(BetaSchedule) = 'cosine'
warmup_steps: SingleOrList(int) = None
cosine_decay_max_steps: SingleOrList(int) = None
text_encoder_name: str = DEFAULT_T5_NAME
channels: int = 3
loss_type: str = 'l2'
Expand Down

0 comments on commit 5bf7f1a

Please sign in to comment.