diff --git a/README.md b/README.md
index ea1686b..c55d6b4 100644
--- a/README.md
+++ b/README.md
@@ -21,7 +21,7 @@ git clone https://github.com/mh-amani/neural_discrete_reasoning
 cd neural_discrete_reasoning
 
 # [OPTIONAL] create conda environment
-conda create -n myenv python=3.11
+conda create -n ndr python=3.11
 conda activate ndr
 
 # install pytorch according to instructions
diff --git a/configs/experiment/pvr.yaml b/configs/experiment/pvr.yaml
new file mode 100644
index 0000000..18a8957
--- /dev/null
+++ b/configs/experiment/pvr.yaml
@@ -0,0 +1,38 @@
+# @package _global_
+
+# to execute this experiment run:
+# python train.py experiment=pvr
+
+name: "pvr"
+run_name: "${model_key}-${discretizer_key}"
+
+defaults:
+  - override /data: mnist
+  - override /model: mnist
+  - override /callbacks: default
+  - override /trainer: default
+
+# all parameters below will be merged with parameters from default configurations set above
+# this allows you to overwrite only specified parameters
+
+tags: ["mnist", "simple_dense_net"]
+
+seed: 12345
+
+trainer:
+  min_epochs: 10
+  max_epochs: 10
+  gradient_clip_val: 0.5
+
+model:
+  optimizer:
+    lr: 0.002
+
+
+data:
+  batch_size: 64
+
+logger:
+  wandb:
+    tags: ${tags}
+    group: "mnist"
\ No newline at end of file
diff --git a/configs/logger/wandb.yaml b/configs/logger/wandb.yaml
index ece1658..cc7cdcc 100644
--- a/configs/logger/wandb.yaml
+++ b/configs/logger/wandb.yaml
@@ -7,7 +7,8 @@ wandb:
   offline: False
   id: null # pass correct id to resume experiment!
   anonymous: null # enable anonymous logging
-  project: "lightning-hydra-template"
+  project: ${name}
+  name: ${run_name}
   log_model: False # upload lightning ckpts
   prefix: "" # a string to put at the beginning of metric keys
   # entity: "" # set to name of your wandb team
diff --git a/configs/model/transformer_dbn_classifier.yaml b/configs/model/transformer_dbn_classifier.yaml
new file mode 100644
index 0000000..2d65901
--- /dev/null
+++ b/configs/model/transformer_dbn_classifier.yaml
@@ -0,0 +1,45 @@
+_target_: src.models.transformer_dbn_classifier.TransformerDBNClassifier
+
+optimizer:
+  _target_: torch.optim.Adam
+  _partial_: true
+  lr: 0.001
+  weight_decay: 0.0
+
+scheduler:
+  _target_: torch.optim.lr_scheduler.ReduceLROnPlateau
+  _partial_: true
+  mode: min
+  factor: 0.1
+  patience: 10
+
+####################################################
+# compile model for faster training with pytorch 2.0
+compile: false
+embedding_dim: 256
+dbn_after_each_layer: True
+num_transformer_layers: 3
+
+discrete_layer:
+  _target_: src.models.components.discrete_layers.vqvae.VQVAEDiscreteLayer
+  key: 'vqvae'
+  temperature: 1.0
+  label_smoothing_scale: 0.0
+  dist_ord: 2
+  dictionary_dim: ${model.embedding_dim}
+  hard: True
+  projection_method: "layer norm" # "unit-sphere", "scale", "layer norm" or "None"
+  beta: 0.25
+
+transformer_layer:
+  _target_: src.models.components.transformer.TransformerLayer
+  num_heads: 8
+  dim_feedforward: ${model.embedding_dim}
+  dropout: 0.1
+  activation: "relu"
+  dim: ${model.embedding_dim}
+  norm: "layer_norm"
+  batch_first: True
+
+
+
diff --git a/configs/train.yaml b/configs/train.yaml
index ef7bdab..d32f78c 100644
--- a/configs/train.yaml
+++ b/configs/train.yaml
@@ -7,7 +7,7 @@ defaults:
   - data: mnist
   - model: mnist
   - callbacks: default
-  - logger: null # set logger here or use command line (e.g. `python train.py logger=tensorboard`)
+  - logger: wandb # set logger here or use command line (e.g. 
`python train.py logger=tensorboard`) - trainer: default - paths: default - extras: default @@ -27,6 +27,10 @@ defaults: # debugging config (enable through command line, e.g. `python train.py debug=default) - debug: null +# determines the log directory's identifier +run_name: ??? +name: ??? + # task name, determines output directory path task_name: "train" @@ -46,4 +50,4 @@ test: True ckpt_path: null # seed for random number generators in pytorch, numpy and python.random -seed: null +seed: 42 diff --git a/environment.yaml b/environment.yaml index f74ee8c..101e4f7 100644 --- a/environment.yaml +++ b/environment.yaml @@ -5,7 +5,7 @@ # - conda allows for installing packages without requiring certain compilers or # libraries to be available in the system, since it installs precompiled binaries -name: myenv +name: ndr channels: - pytorch @@ -21,7 +21,7 @@ channels: # compatibility is usually guaranteed dependencies: - - python=3.10 + - python=3.11 - pytorch=2.* - torchvision=0.* - lightning=2.* @@ -32,7 +32,7 @@ dependencies: - pytest=7.* # --------- loggers --------- # - # - wandb + - wandb # - neptune-client # - mlflow # - comet-ml diff --git a/requirements.txt b/requirements.txt index d837268..ab16611 100644 --- a/requirements.txt +++ b/requirements.txt @@ -10,7 +10,7 @@ hydra-colorlog==1.2.0 hydra-optuna-sweeper==1.2.0 # --------- loggers --------- # -# wandb +wandb # neptune-client # mlflow # comet-ml diff --git a/src/models/components/discrete_layers/abstract_discrete_layer.py b/src/models/components/discrete_layers/abstract_discrete_layer.py new file mode 100644 index 0000000..3a49556 --- /dev/null +++ b/src/models/components/discrete_layers/abstract_discrete_layer.py @@ -0,0 +1,60 @@ +from abc import ABC, abstractmethod +import torch.nn as nn +import torch +from torch.nn import LayerNorm +class AbstractDiscreteLayer(nn.Module): + def __init__(self, dims, **kwargs) -> None: + super().__init__() + self.input_dim = dims['input_dim'] # fed by the model, after x->z and z->x models are instantiated + self.output_dim = dims['output_dim'] # fed by the model, after x->z and z->x models are instantiated + self.vocab_size = dims['vocab_size'] + self.dictionary_dim = kwargs['dictionary_dim'] + + self.temperature = kwargs.get('temperature', 1) + self.label_smoothing_scale = kwargs.get('label_smoothing_scale', 0.001) + + self.out_layer_norm = LayerNorm(self.dictionary_dim) + + self.dictionary = nn.Embedding(self.vocab_size, self.dictionary_dim) + + self.output_embedding = nn.Linear(self.output_dim, self.dictionary_dim) + self.encoder_embedding = nn.Linear(self.dictionary_dim, self.input_dim) + self.decoder_embedding = nn.Linear(self.dictionary_dim, self.output_dim) + + def decoder_to_discrete_embedding(self, x): + out_x = self.output_embedding(x) + return out_x + + def discrete_embedding_to_decoder(self, x): + return self.decoder_embedding(x) + + def discrete_embedding_to_encoder(self, x): + return self.encoder_embedding(x) + + def project_matrix(self,x,**kwargs): + return x + + def project_embedding_matrix(self): + self.dictionary.weight = torch.nn.Parameter(self.project_matrix(self.dictionary.weight)) + + def forward(self, x,**kwargs): + continous_vector = self.decoder_to_discrete_embedding(x) + + # scores are between 0 and 1, and sum to 1 over the vocab dimension. 
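+        # discretize() is implemented by the concrete subclasses (gumbel, vqvae) and, for a
+        # (batch, length, ...) input, returns:
+        #   id: (batch, length) indices of the selected dictionary entries
+        #   score: (batch, length, vocab_size) soft assignment over the dictionary
+        #   quantized_vector: (batch, length, dictionary_dim) embedding passed on to the next block
+        #   quantization_loss: auxiliary loss (the VQ commitment/embedding loss; 0 for gumbel)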
+        id, score, quantized_vector, quantization_loss = self.discretize(continous_vector, **kwargs)
+        return id, score, quantized_vector, quantization_loss
+
+    def embed_enc_from_id(self, x):
+        embeds = self.dictionary(x)
+        return self.discrete_embedding_to_encoder(embeds)
+
+    def embed_dec_from_id(self, x):
+        embeds = self.dictionary(x)
+        return self.discrete_embedding_to_decoder(embeds)
+
+    @abstractmethod
+    def discretize(self, x, **kwargs) -> tuple:
+        pass
\ No newline at end of file
diff --git a/src/models/components/discrete_layers/gumbel.py b/src/models/components/discrete_layers/gumbel.py
new file mode 100644
index 0000000..5307594
--- /dev/null
+++ b/src/models/components/discrete_layers/gumbel.py
@@ -0,0 +1,18 @@
+from .abstract_discrete_layer import AbstractDiscreteLayer
+import torch
+from torch.nn.functional import gumbel_softmax
+
+
+class GumbelDiscreteLayer(AbstractDiscreteLayer):
+    def __init__(self, dims, **kwargs) -> None:
+        super().__init__(dims, **kwargs)
+        # if True, use straight-through argmax in the forward pass; otherwise use the soft
+        # gumbel-softmax sample. The backward pass is the same in both cases.
+        self.hard = kwargs['hard']
+        self.output_embedding = torch.nn.Linear(self.output_dim, self.vocab_size)
+
+    def discretize(self, x, **kwargs) -> tuple:
+        score = gumbel_softmax(x, tau=self.temperature, hard=self.hard, dim=-1)
+        x_quantized = torch.matmul(score, self.dictionary.weight)
+        id = torch.argmax(score, dim=-1)
+        quantization_loss = 0
+        return id, score, x_quantized, quantization_loss
\ No newline at end of file
diff --git a/src/models/components/discrete_layers/vqvae.py b/src/models/components/discrete_layers/vqvae.py
new file mode 100644
index 0000000..d0453ac
--- /dev/null
+++ b/src/models/components/discrete_layers/vqvae.py
@@ -0,0 +1,86 @@
+from .abstract_discrete_layer import AbstractDiscreteLayer
+import torch
+from torch import nn
+# from vector_quantize_pytorch import VectorQuantize
+from entmax import sparsemax
+
+class VQVAEDiscreteLayer(AbstractDiscreteLayer):
+    def __init__(self, dims, **kwargs) -> None:
+        super().__init__(dims, **kwargs)
+
+        self.projection_method = kwargs.get("projection_method", None)
+
+        self.dictionary = nn.Embedding(self.vocab_size, self.dictionary_dim)
+        self.dictionary.weight = torch.nn.Parameter(self.project_matrix(self.dictionary.weight))
+
+        self.dist_ord = kwargs.get('dist_ord', 2)
+        self.embedding_loss = torch.nn.functional.mse_loss  # torch.nn.CosineSimilarity(dim=-1)
+        self.hard = kwargs['hard']
+        self.kernel = nn.Softmax(dim=-1)
+        self.beta = kwargs.get("beta", 0.25)  # 0.25 is the beta used in the vq-vae paper
+
+    ###################
+    # Probably can remove these as we are using the matrix projection now
+    # def fetch_embeddings_by_index(self,indices):
+    #     if self.normalize_embeddings:
+    #         return nn.functional.normalize(self.dictionary(indices),dim=-1)
+    #     #~else
+    #     return self.dictionary(indices)
+
+    # def fetch_embedding_matrix(self):
+    #     if self.normalize_embeddings:
+    #         return nn.functional.normalize(self.dictionary.weight,dim=-1)
+    #     #~else
+    #     return self.dictionary.weight
+    ###################
+
+    def project_matrix(self, x):
+        if self.projection_method == "unit-sphere":
+            return torch.nn.functional.normalize(x, dim=-1)
+        if self.projection_method == "scale":
+            # divide the vectors by the square root of the dimension
+            return x / (self.dictionary_dim ** 0.5)
+        if self.projection_method == "layer norm":
+            return self.out_layer_norm(x)
+        return x
+
+    def discretize(self, x, **kwargs) -> tuple:
+        probs = self.kernel( - self.codebook_distances(x) / 
self.temperature) + x = self.project_matrix(x) + indices = torch.argmax(probs, dim=-1) + + if self.hard: + # Apply STE for hard quantization + quantized = self.dictionary(indices)#self.fetch_embeddings_by_index(indices) + quantized = quantized + x - (x).detach() + else: + quantized = torch.matmul(probs, self.dictionary.weight) + + if kwargs.get("supervision",False): + true_quantized = self.dictionary(kwargs.get("true_ids",None)) + commitment_loss = self.embedding_loss(true_quantized.detach(),x) + embedding_loss = self.embedding_loss(true_quantized,x.detach()) + + else: + commitment_loss = self.embedding_loss(quantized.detach(),x) + embedding_loss = self.embedding_loss(quantized,x.detach()) + + vq_loss = self.beta * commitment_loss + embedding_loss + + return indices, probs, quantized, vq_loss + + def codebook_distances(self, x): + + #dictionary_expanded = self.fetch_embedding_matrix().unsqueeze(0).unsqueeze(1) # Shape: (batch, 1, vocab, dim) + dictionary_expanded = self.dictionary.weight.unsqueeze(0).unsqueeze(1) + x_expanded = x.unsqueeze(2) + # if self.normalize_embeddings: + # x_expanded = nn.functional.normalize(x,dim=-1).unsqueeze(2) # Shape: (batch, length, 1, dim) + # else: + # x_expanded = x.unsqueeze(2) # Shape: (batch, length, 1, dim) + + # Compute the squared differences + dist = torch.linalg.vector_norm(x_expanded - dictionary_expanded, ord=self.dist_ord, dim=-1) + return dist + + \ No newline at end of file diff --git a/src/models/components/transformer.py b/src/models/components/transformer.py new file mode 100644 index 0000000..75750fb --- /dev/null +++ b/src/models/components/transformer.py @@ -0,0 +1,162 @@ +# code adapted from lucidrains/vit-pytorch/vit.py (https://github.com/lucidrains/vit-pytorch) +import torch +from torch import nn + +from einops import rearrange, repeat + + +class PreNorm(nn.Module): + def __init__(self, dim, fn): + super().__init__() + self.norm = nn.LayerNorm(dim) + self.fn = fn + + def forward(self, x, **kwargs): + return self.fn(self.norm(x), **kwargs) + + +class FeedForward(nn.Module): + def __init__(self, dim, hidden_dim, dropout = 0.): + super().__init__() + self.net = nn.Sequential( + nn.Linear(dim, hidden_dim), + nn.GELU(), + nn.Dropout(dropout), + nn.Linear(hidden_dim, dim), + nn.Dropout(dropout) + ) + + def forward(self, x): + return self.net(x) + + +class Attention(nn.Module): + def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.): + super().__init__() + inner_dim = dim_head * heads + project_out = not (heads == 1 and dim_head == dim) + + self.heads = heads + self.scale = dim_head ** -0.5 + + self.attend = nn.Softmax(dim = -1) + self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False) + + self.to_out = nn.Sequential( + nn.Linear(inner_dim, dim), + nn.Dropout(dropout) + ) if project_out else nn.Identity() + + def forward(self, x): + qkv = self.to_qkv(x).chunk(3, dim = -1) + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv) + + dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale + + attn = self.attend(dots) + + out = torch.matmul(attn, v) + out = rearrange(out, 'b h n d -> b n (h d)') + return self.to_out(out) + + +class TransformerLayer(nn.Module): + def __init__(self, dim, heads, dim_head, mlp_dim, dropout = 0.): + super().__init__() + self.attn = PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)) + self.ff = PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout)) + + def forward(self, x): + x = self.attn(x) + x + x = self.ff(x) + x + return x + + +class 
Transformer(nn.Module):
+    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
+        super().__init__()
+        self.layers = nn.ModuleList([])
+        for _ in range(depth):
+            self.layers.append(nn.ModuleList([
+                PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
+                PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
+            ]))
+
+    def forward(self, x):
+        for attn, ff in self.layers:
+            x = attn(x) + x
+            x = ff(x) + x
+        return x
+
+
+class TransformerDBN(nn.Module):
+    """Transformer encoder interleaved with discrete bottleneck (DBN) layers.
+
+    Sub-layers are instantiated from the hydra config carried in `hparams`
+    (see `configs/model/transformer_dbn_classifier.yaml`); `num_embeddings`,
+    `seq_len` and `emb_dropout` are likewise expected to be provided there.
+    """
+
+    def __init__(self, hparams):
+        super().__init__()
+        import hydra  # local import: hydra is only needed when building the stack from a config
+
+        self.hparams = hparams
+        # token embedding; its weight table can be shared with the discrete bottleneck dictionaries
+        self.token_embedding = nn.Embedding(hparams.num_embeddings, embedding_dim=hparams.embedding_dim)
+        self.pos_embedding = nn.Parameter(torch.randn(1, hparams.seq_len + 1, hparams.embedding_dim))
+        self.cls_token = nn.Parameter(torch.randn(1, 1, hparams.embedding_dim))
+        self.dropout = nn.Dropout(hparams.emb_dropout)
+
+        self.model = nn.Sequential()
+        for i in range(hparams.num_transformer_layers):
+            transformer_layer = hydra.utils.instantiate(hparams.transformer_layer)
+            self.model.add_module(f"transformer_layer_{i}", transformer_layer)
+
+            # add a discrete bottleneck after every layer, or only after the last one
+            if hparams.dbn_after_each_layer or i == hparams.num_transformer_layers - 1:
+                discrete_layer = hydra.utils.instantiate(hparams.discrete_layer)
+                if getattr(hparams, "shared_embedding_dbn", False):
+                    discrete_layer.dictionary = self.token_embedding
+                self.model.add_module(f"discrete_layer_{i}", discrete_layer)
+
+
+class TokenTransformer(nn.Module):
+    def __init__(self, *, seq_len, output_dim, dim, depth, heads, mlp_dim, pool='cls',
+                 dim_head=64, dropout=0., emb_dropout=0., inputs_are_pos_neg_ones=True, albert=False):
+        super().__init__()
+
+        self._inputs_are_pos_neg_ones = inputs_are_pos_neg_ones
+        self.token_embedding = nn.Embedding(num_embeddings=2, embedding_dim=dim)
+        self.pos_embedding = nn.Parameter(torch.randn(1, seq_len + 1, dim))
+        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
+        self.dropout = nn.Dropout(emb_dropout)
+        self.albert = albert
+
+        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
+
+        self.pool = pool
+        self.to_latent = nn.Identity()
+
+        self.mlp_head = nn.Sequential(
+            nn.LayerNorm(dim),
+            nn.Linear(dim, output_dim)
+        )
+
+    def forward(self, inputs):
+        if self._inputs_are_pos_neg_ones:
+            # convert from +1/-1 to 0/1
+            inputs = (inputs + 1) / 2
+            inputs = inputs.int()
+        x = self.token_embedding(inputs)
+        b, n, _ = x.shape
+
+        cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
+        x = torch.cat((cls_tokens, x), dim=1)
+        x += self.pos_embedding[:, :(n + 1)]
+        x = self.dropout(x)
+
+        x = self.transformer(x)
+
+        x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]
+
+        x = self.to_latent(x)
+        return self.mlp_head(x)
\ No newline at end of file
diff --git a/src/models/transformer_dbn_classifier.py b/src/models/transformer_dbn_classifier.py
new file mode 100644
index 0000000..abe9d55
--- /dev/null
+++ b/src/models/transformer_dbn_classifier.py
@@ -0,0 +1,213 @@
+from typing import Any, Dict, Tuple
+
+import hydra
+import torch
+from lightning import LightningModule
+from torch import nn
+from torchmetrics import MaxMetric, MeanMetric
+from torchmetrics.classification.accuracy import Accuracy
+
+class TransformerDBNClassifier(LightningModule):
+    """Transformer classifier with discrete bottleneck (DBN) layers.
+
+    Docs:
+        https://lightning.ai/docs/pytorch/latest/common/lightning_module.html
+    """
+
+    def __init__(
+        self,
+        optimizer: torch.optim.Optimizer,
+        scheduler: torch.optim.lr_scheduler,
+        **kwargs,
+    ) -> None:
+        """Initialize a `TransformerDBNClassifier`.
+
+        :param optimizer: The optimizer to use for training.
+        :param scheduler: The learning rate scheduler to use for training.
+        :param kwargs: Architecture options coming from the model config
+            (e.g. `embedding_dim`, `num_transformer_layers`, `dbn_after_each_layer`,
+            `transformer_layer`, `discrete_layer`, `compile`).
+        """
+        super().__init__()
+
+        # this line allows to access init params with 'self.hparams' attribute
+        # also ensures init params will be stored in ckpt
+        self.save_hyperparameters(logger=False)
+
+        self.__init_model()
+
+        # loss function
+        self.criterion = torch.nn.CrossEntropyLoss()
+
+        # metric objects for calculating and averaging accuracy across batches
+        self.train_acc = Accuracy(task="multiclass", num_classes=10)
+        self.val_acc = Accuracy(task="multiclass", num_classes=10)
+        self.test_acc = Accuracy(task="multiclass", num_classes=10)
+
+        # for averaging loss across batches
+        self.train_loss = MeanMetric()
+        self.val_loss = MeanMetric()
+        self.test_loss = MeanMetric()
+
+        # for tracking best so far validation accuracy
+        self.val_acc_best = MaxMetric()
+
+    def __init_model(self):
+        # token embedding; its weight table can be shared with the discrete bottleneck dictionaries
+        # (`num_embeddings` and, optionally, `shared_embedding_dbn` must be supplied via the model config)
+        self.embedding = nn.Embedding(self.hparams.num_embeddings, self.hparams.embedding_dim)
+
+        self.model = nn.Sequential()
+        for i in range(self.hparams.num_transformer_layers):
+            transformer_layer = hydra.utils.instantiate(self.hparams.transformer_layer)
+            self.model.add_module(f"transformer_layer_{i}", transformer_layer)
+
+            # add a discrete bottleneck after every layer, or only after the last one
+            if self.hparams.dbn_after_each_layer or i == self.hparams.num_transformer_layers - 1:
+                discrete_layer = hydra.utils.instantiate(self.hparams.discrete_layer)
+                if self.hparams.get("shared_embedding_dbn", False):
+                    discrete_layer.dictionary = self.embedding
+                self.model.add_module(f"discrete_layer_{i}", discrete_layer)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Perform a forward pass through the model `self.model`.
+
+        :param x: A tensor of inputs.
+        :return: A tensor of logits.
+        """
+        return self.model(x)
+
+    def on_train_start(self) -> None:
+        """Lightning hook that is called when training begins."""
+        # by default lightning executes validation step sanity checks before training starts,
+        # so it's worth to make sure validation metrics don't store results from these checks
+        self.val_loss.reset()
+        self.val_acc.reset()
+        self.val_acc_best.reset()
+
+    def model_step(
+        self, batch: Tuple[torch.Tensor, torch.Tensor]
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        """Perform a single model step on a batch of data.
+
+        :param batch: A batch of data (a tuple) containing the input tensor of images and target labels.
+
+        :return: A tuple containing (in order):
+            - A tensor of losses.
+            - A tensor of predictions.
+            - A tensor of target labels.
+        """
+        x, y = batch
+        logits = self.forward(x)
+        loss = self.criterion(logits, y)
+        preds = torch.argmax(logits, dim=1)
+        return loss, preds, y
+
+    def training_step(
+        self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int
+    ) -> torch.Tensor:
+        """Perform a single training step on a batch of data from the training set.
+
+        :param batch: A batch of data (a tuple) containing the input tensor of images and target
+            labels.
+        :param batch_idx: The index of the current batch.
+        :return: A tensor of losses between model predictions and targets.
+        """
+        loss, preds, targets = self.model_step(batch)
+
+        # update and log metrics
+        self.train_loss(loss)
+        self.train_acc(preds, targets)
+        self.log("train/loss", self.train_loss, on_step=False, on_epoch=True, prog_bar=True)
+        self.log("train/acc", self.train_acc, on_step=False, on_epoch=True, prog_bar=True)
+
+        # return loss or backpropagation will fail
+        return loss
+
+    def on_train_epoch_end(self) -> None:
+        "Lightning hook that is called when a training epoch ends."
+        pass
+
+    def validation_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> None:
+        """Perform a single validation step on a batch of data from the validation set.
+
+        :param batch: A batch of data (a tuple) containing the input tensor of images and target
+            labels.
+        :param batch_idx: The index of the current batch.
+        """
+        loss, preds, targets = self.model_step(batch)
+
+        # update and log metrics
+        self.val_loss(loss)
+        self.val_acc(preds, targets)
+        self.log("val/loss", self.val_loss, on_step=False, on_epoch=True, prog_bar=True)
+        self.log("val/acc", self.val_acc, on_step=False, on_epoch=True, prog_bar=True)
+
+    def on_validation_epoch_end(self) -> None:
+        "Lightning hook that is called when a validation epoch ends."
+        acc = self.val_acc.compute()  # get current val acc
+        self.val_acc_best(acc)  # update best so far val acc
+        # log `val_acc_best` as a value through `.compute()` method, instead of as a metric object
+        # otherwise metric would be reset by lightning after each epoch
+        self.log("val/acc_best", self.val_acc_best.compute(), sync_dist=True, prog_bar=True)
+
+    def test_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> None:
+        """Perform a single test step on a batch of data from the test set.
+
+        :param batch: A batch of data (a tuple) containing the input tensor of images and target
+            labels.
+        :param batch_idx: The index of the current batch.
+        """
+        loss, preds, targets = self.model_step(batch)
+
+        # update and log metrics
+        self.test_loss(loss)
+        self.test_acc(preds, targets)
+        self.log("test/loss", self.test_loss, on_step=False, on_epoch=True, prog_bar=True)
+        self.log("test/acc", self.test_acc, on_step=False, on_epoch=True, prog_bar=True)
+
+    def on_test_epoch_end(self) -> None:
+        """Lightning hook that is called when a test epoch ends."""
+        pass
+
+    def setup(self, stage: str) -> None:
+        """Lightning hook that is called at the beginning of fit (train + validate), validate,
+        test, or predict.
+
+        This is a good hook when you need to build models dynamically or adjust something about
+        them. This hook is called on every process when using DDP.
+
+        :param stage: Either `"fit"`, `"validate"`, `"test"`, or `"predict"`.
+        """
+        if self.hparams.compile and stage == "fit":
+            self.model = torch.compile(self.model)
+
+    def configure_optimizers(self) -> Dict[str, Any]:
+        """Choose what optimizers and learning-rate schedulers to use in your optimization.
+        Normally you'd need one. But in the case of GANs or similar you might have multiple.
+
+        Examples:
+            https://lightning.ai/docs/pytorch/latest/common/lightning_module.html#configure-optimizers
+
+        :return: A dict containing the configured optimizers and learning-rate schedulers to be used for training.
+        """
+        optimizer = self.hparams.optimizer(params=self.trainer.model.parameters())
+        if self.hparams.scheduler is not None:
+            scheduler = self.hparams.scheduler(optimizer=optimizer)
+            return {
+                "optimizer": optimizer,
+                "lr_scheduler": {
+                    "scheduler": scheduler,
+                    "monitor": "val/loss",
+                    "interval": "epoch",
+                    "frequency": 1,
+                },
+            }
+        return {"optimizer": optimizer}
+
+
+if __name__ == "__main__":
+    _ = TransformerDBNClassifier(None, None)
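With the configs above in place, a run can be launched roughly as follows. This is a hypothetical invocation, not something pinned down by this patch: the entry point follows the `python train.py` comments used in these configs, and `run_name` is overridden by hand because the `${model_key}-${discretizer_key}` interpolation in `pvr.yaml` is not defined anywhere in this diff.

```bash
# hypothetical launch of the new experiment/model configs; adjust the script path to this repo
python train.py experiment=pvr model=transformer_dbn_classifier logger=wandb \
    run_name=transformer-vqvae
```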